1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/fp32/matmul_fp32.h"
18 #include "nnacl/fp32/pack_fp32.h"
19 #ifdef ENABLE_SSE
20 #ifdef _MSC_VER
21 #include <immintrin.h>
22 #else
23 #include <x86intrin.h>
24 #endif
25 #endif
26 #ifdef ENABLE_ARM64
27 #include <arm_neon.h>
28 #endif
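// Plain transpose: repack a row-major (row x col) matrix into column-major order, with no padding.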
void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col) {
30 for (int r = 0; r < row; ++r) {
31 for (int c = 0; c < col; ++c) {
32 dst_ptr[c * row + r] = src_ptr[r * col + c];
33 }
34 }
35 }
36
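// RowMajor2RowXMajor (X = 4, 6, 8, 12, 16): pack a row-major matrix by blocks of X columns.
// Within each block the X values of every row are stored contiguously, and the trailing
// block is zero-padded out to a full X columns.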
void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col) {
38 for (int r = 0; r < row; r++) {
39 const float *src = src_ptr + r * col;
40 int c = 0;
41 for (; c < col; c++) {
42 int cd4 = c / C4NUM;
43 int cm4 = c % C4NUM;
44 dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = src[c];
45 }
46 for (; c < UP_ROUND(col, C4NUM); c++) {
47 int cd4 = c / C4NUM;
48 int cm4 = c % C4NUM;
49 dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = 0;
50 }
51 }
52 return;
53 }
54
void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col) {
56 for (int r = 0; r < row; r++) {
57 const float *src = src_ptr + r * col;
58 int c = 0;
59 for (; c < col; c++) {
60 int cd6 = c / C6NUM;
61 int cm6 = c % C6NUM;
62 dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = src[c];
63 }
64 for (; c < UP_ROUND(col, C6NUM); c++) {
65 int cd6 = c / C6NUM;
66 int cm6 = c % C6NUM;
67 dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = 0;
68 }
69 }
70 return;
71 }
72
void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col) {
74 for (int r = 0; r < row; r++) {
75 const float *src = src_ptr + r * col;
76 int c = 0;
77 for (; c < col; c++) {
78 int cd8 = c / C8NUM;
79 int cm8 = c % C8NUM;
80 dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[c];
81 }
82 for (; c < UP_ROUND(col, C8NUM); c++) {
83 int cd8 = c / C8NUM;
84 int cm8 = c % C8NUM;
85 dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = 0;
86 }
87 }
88 return;
89 }
90
void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col) {
92 for (int r = 0; r < row; r++) {
93 const float *src = src_ptr + r * col;
94 int c = 0;
95 for (; c < col; c++) {
96 int cd12 = c / C12NUM;
97 int cm12 = c % C12NUM;
98 dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = src[c];
99 }
100 for (; c < UP_ROUND(col, C12NUM); c++) {
101 int cd12 = c / C12NUM;
102 int cm12 = c % C12NUM;
103 dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = 0;
104 }
105 }
106 return;
107 }
108
void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col) {
110 for (int r = 0; r < row; r++) {
111 const float *src = src_ptr + r * col;
112 int c = 0;
113 for (; c < col; c++) {
114 int cd16 = c / C16NUM;
115 int cm16 = c % C16NUM;
116 dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = src[c];
117 }
118 for (; c < UP_ROUND(col, C16NUM); c++) {
119 int cd16 = c / C16NUM;
120 int cm16 = c % C16NUM;
121 dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = 0;
122 }
123 }
124 return;
125 }
126
void RowMajor2Row32Major(const float *src_ptr, float *dst_ptr, int row, int col) {
  // Rows are processed in groups of up to 32 (4 blocks of 8); a shorter tail group of
  // 24, 16, or 8 rows is used when fewer than 32 rows remain.
129 int row_block_num = UP_DIV(row, C8NUM);
130 int row_block = C4NUM;
131 for (int i = 0; i < row_block_num; i += row_block) {
132 row_block = MSMIN(C4NUM, row_block_num - i); // max_tile = 4
133 int row_remainder = MSMIN(row_block * C8NUM, row - i * C8NUM);
134 for (int oc = 0; oc < col; ++oc) {
135 memcpy(dst_ptr, src_ptr + oc * row + i * C8NUM, row_remainder * sizeof(float));
136 dst_ptr += row_block * C8NUM;
137 }
138 }
139 }
140
141 #ifdef ENABLE_ARM64
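// NEON helper for RowMajor2Col12Major: transposes one 12x4 tile whose source rows are
// `col` floats apart, producing the same layout as the scalar fallback below
// (dst[tc * 12 + tr] = src[tr * col + tc]).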
void RowMajor2Col12Major_arm64(const float *src_c, float *dst_c, size_t col) {
143 size_t stride = col * sizeof(float);
144 asm volatile(
145 "mov x10, %[src_c]\n"
146 "mov x11, %[dst_c]\n"
147
148 "ld1 {v0.4s}, [x10], %[stride]\n"
149 "ld1 {v1.4s}, [x10], %[stride]\n"
150 "ld1 {v2.4s}, [x10], %[stride]\n"
151 "ld1 {v3.4s}, [x10], %[stride]\n"
152
153 "ld1 {v4.4s}, [x10], %[stride]\n"
154 "ld1 {v5.4s}, [x10], %[stride]\n"
155 "ld1 {v6.4s}, [x10], %[stride]\n"
156 "ld1 {v7.4s}, [x10], %[stride]\n"
157
158 "zip1 v12.4s, v0.4s, v1.4s\n"
159 "zip2 v13.4s, v0.4s, v1.4s\n"
160 "zip1 v14.4s, v2.4s, v3.4s\n"
161 "zip2 v15.4s, v2.4s, v3.4s\n"
162
163 "ld1 {v8.4s}, [x10], %[stride]\n"
164 "ld1 {v9.4s}, [x10], %[stride]\n"
165 "ld1 {v10.4s}, [x10], %[stride]\n"
166 "ld1 {v11.4s}, [x10], %[stride]\n"
167
168 "zip1 v16.4s, v4.4s, v5.4s\n"
169 "zip2 v17.4s, v4.4s, v5.4s\n"
170 "zip1 v18.4s, v6.4s, v7.4s\n"
171 "zip2 v19.4s, v6.4s, v7.4s\n"
172
173 "trn1 v20.2d, v12.2d, v14.2d\n"
174 "trn2 v23.2d, v12.2d, v14.2d\n"
175 "trn1 v26.2d, v13.2d, v15.2d\n"
176 "trn2 v29.2d, v13.2d, v15.2d\n"
177
178 "trn1 v21.2d, v16.2d, v18.2d\n"
179 "trn2 v24.2d, v16.2d, v18.2d\n"
180 "trn1 v27.2d, v17.2d, v19.2d\n"
181 "trn2 v30.2d, v17.2d, v19.2d\n"
182
183 "zip1 v12.4s, v8.4s, v9.4s\n"
184 "zip2 v13.4s, v8.4s, v9.4s\n"
185 "zip1 v14.4s, v10.4s, v11.4s\n"
186 "zip2 v15.4s, v10.4s, v11.4s\n"
187
188 "trn1 v22.2d, v12.2d, v14.2d\n"
189 "trn2 v25.2d, v12.2d, v14.2d\n"
190 "trn1 v28.2d, v13.2d, v15.2d\n"
191 "trn2 v31.2d, v13.2d, v15.2d\n"
192
193 "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x11], #64\n"
194 "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], #64\n"
195 "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x11], #64\n"
196
197 :
198 : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
199 : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
200 "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
201 "v31");
202 return;
203 }
204 #endif
205 #ifdef ENABLE_ARM32
void RowMajor2Col12Major_arm32(const float *src_c, float *dst_c, size_t col) {
207 size_t stride = col * sizeof(float);
208 asm volatile(
209 "mov r10, %[src_c]\n"
210 "mov r12, %[dst_c]\n"
211
212 "vld1.32 {q0}, [r10], %[stride]\n"
213 "vld1.32 {q3}, [r10], %[stride]\n"
214 "vld1.32 {q10}, [r10], %[stride]\n"
215 "vld1.32 {q13}, [r10], %[stride]\n"
216
217 "vtrn.32 d0, d6\n"
218 "vtrn.32 d1, d7\n"
219 "vtrn.32 d20, d26\n"
220 "vtrn.32 d21, d27\n"
221
222 "vld1.32 {q1}, [r10], %[stride]\n"
223 "vld1.32 {q8}, [r10], %[stride]\n"
224 "vld1.32 {q11}, [r10], %[stride]\n"
225 "vld1.32 {q14}, [r10], %[stride]\n"
226
227 "vswp d1, d20\n"
228 "vswp d7, d26\n"
229
230 "vld1.32 {q2}, [r10], %[stride]\n"
231 "vld1.32 {q9}, [r10], %[stride]\n"
232 "vld1.32 {q12}, [r10], %[stride]\n"
233 "vld1.32 {q15}, [r10], %[stride]\n"
234
235 "vtrn.32 d2, d16\n"
236 "vtrn.32 d3, d17\n"
237 "vtrn.32 d22, d28\n"
238 "vtrn.32 d23, d29\n"
239
240 "vswp d3, d22\n"
241 "vswp d17, d28\n"
242
243 "vtrn.32 d4, d18\n"
244 "vtrn.32 d5, d19\n"
245 "vtrn.32 d24, d30\n"
246 "vtrn.32 d25, d31\n"
247
248 "vswp d5, d24\n"
249 "vswp d19, d30\n"
250
251 "vst1.32 {q0, q1}, [r12]!\n"
252 "vst1.32 {q2, q3}, [r12]!\n"
253 "vst1.32 {q8, q9}, [r12]!\n"
254 "vst1.32 {q10, q11}, [r12]!\n"
255 "vst1.32 {q12, q13}, [r12]!\n"
256 "vst1.32 {q14, q15}, [r12]!\n"
257
258 :
259 : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
260 : "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
261 return;
262 }
263 #endif
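// Pack a row-major (row x col) matrix into 12-row tiles stored column by column
// (within a tile, dst[c * 12 + r] = src[r][c]); the row count is zero-padded up to a
// multiple of 12. This is the layout consumed as `a` by MatMul12x8 below.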
void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, int row, int col) {
265 const float *src_r = src_ptr;
266 float *dst_r = dst_ptr;
267 int ri = 0;
268 for (; ri < (row / C12NUM * C12NUM); ri += C12NUM) {
269 int ci = 0;
270 for (; ci < (col / C4NUM * C4NUM); ci += C4NUM) {
271 const float *src_c = src_r + ci;
272 float *dst_c = dst_r + ci * C12NUM;
273 #ifdef ENABLE_ARM64
274 RowMajor2Col12Major_arm64(src_c, dst_c, col);
275 #elif ENABLE_ARM32
276 RowMajor2Col12Major_arm32(src_c, dst_c, col);
277 #elif ENABLE_SSE
278 __m128 src1 = _mm_loadu_ps(src_c);
279 __m128 src2 = _mm_loadu_ps(src_c + col);
280 __m128 src3 = _mm_loadu_ps(src_c + 2 * col);
281 __m128 src4 = _mm_loadu_ps(src_c + 3 * col);
282 src_c += 4 * col;
283 __m128 src12L = _mm_unpacklo_ps(src1, src2);
284 __m128 src12H = _mm_unpackhi_ps(src1, src2);
285 __m128 src34L = _mm_unpacklo_ps(src3, src4);
286 __m128 src34H = _mm_unpackhi_ps(src3, src4);
287
288 __m128 dst0 = _mm_movelh_ps(src12L, src34L);
289 __m128 dst3 = _mm_movehl_ps(src34L, src12L);
290 __m128 dst6 = _mm_movelh_ps(src12H, src34H);
291 __m128 dst9 = _mm_movehl_ps(src34H, src12H);
292
293 __m128 src5 = _mm_loadu_ps(src_c);
294 __m128 src6 = _mm_loadu_ps(src_c + col);
295 __m128 src7 = _mm_loadu_ps(src_c + 2 * col);
296 __m128 src8 = _mm_loadu_ps(src_c + 3 * col);
297 src_c += 4 * col;
298 __m128 src56L = _mm_unpacklo_ps(src5, src6);
299 __m128 src56H = _mm_unpackhi_ps(src5, src6);
300 __m128 src78L = _mm_unpacklo_ps(src7, src8);
301 __m128 src78H = _mm_unpackhi_ps(src7, src8);
302 __m128 dst1 = _mm_movelh_ps(src56L, src78L);
303 __m128 dst4 = _mm_movehl_ps(src78L, src56L);
304 __m128 dst7 = _mm_movelh_ps(src56H, src78H);
305 __m128 dst10 = _mm_movehl_ps(src78H, src56H);
306
307 __m128 src9 = _mm_loadu_ps(src_c);
308 __m128 src10 = _mm_loadu_ps(src_c + col);
309 __m128 src11 = _mm_loadu_ps(src_c + 2 * col);
310 __m128 src12 = _mm_loadu_ps(src_c + 3 * col);
311 src_c += 4 * col;
312 __m128 src910L = _mm_unpacklo_ps(src9, src10);
313 __m128 src910H = _mm_unpackhi_ps(src9, src10);
314 __m128 src1112L = _mm_unpacklo_ps(src11, src12);
315 __m128 src1112H = _mm_unpackhi_ps(src11, src12);
316 __m128 dst2 = _mm_movelh_ps(src910L, src1112L);
317 __m128 dst5 = _mm_movehl_ps(src1112L, src910L);
318 __m128 dst8 = _mm_movelh_ps(src910H, src1112H);
319 __m128 dst11 = _mm_movehl_ps(src1112H, src910H);
320
321 _mm_storeu_ps(dst_c, dst0);
322 _mm_storeu_ps(dst_c + 4, dst1);
323 _mm_storeu_ps(dst_c + 8, dst2);
324 _mm_storeu_ps(dst_c + 12, dst3);
325 _mm_storeu_ps(dst_c + 16, dst4);
326 _mm_storeu_ps(dst_c + 20, dst5);
327 _mm_storeu_ps(dst_c + 24, dst6);
328 _mm_storeu_ps(dst_c + 28, dst7);
329 _mm_storeu_ps(dst_c + 32, dst8);
330 _mm_storeu_ps(dst_c + 36, dst9);
331 _mm_storeu_ps(dst_c + 40, dst10);
332 _mm_storeu_ps(dst_c + 44, dst11);
333 #else
334 for (int tr = 0; tr < C12NUM; tr++) {
335 for (int tc = 0; tc < C4NUM; tc++) {
336 dst_c[tc * C12NUM + tr] = src_c[tr * col + tc];
337 }
338 }
339 #endif
340 }
341 for (; ci < col; ci++) {
342 const float *src_c = src_r + ci;
343 float *dst_c = dst_r + ci * C12NUM;
344 for (int i = 0; i < C12NUM; i++) {
345 dst_c[i] = src_c[i * col];
346 }
347 }
348 src_r += C12NUM * col;
349 dst_r += C12NUM * col;
350 }
351 for (; ri < row; ri++, dst_r++, src_r += col) {
352 for (int i = 0; i < col; i++) {
353 dst_r[i * C12NUM] = src_r[i];
354 }
355 }
356 for (; ri < UP_ROUND(row, C12NUM); ri++, dst_r++) {
357 for (int i = 0; i < col; i++) {
358 dst_r[i * C12NUM] = 0;
359 }
360 }
361 }
362
363 #ifdef ENABLE_ARM64
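// NEON helper for RowMajor2Col8Major: transposes one 8x8 tile whose source rows are
// `col` floats apart (dst[tc * 8 + tr] = src[tr * col + tc]).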
void RowMajor2Col8Major_arm64(const float *src_c, float *dst_c, size_t col) {
365 size_t stride = col * sizeof(float);
366 asm volatile(
367 "mov x10, %[src_c]\n"
368 "mov x11, %[dst_c]\n"
369
370 "ld1 {v0.4s, v1.4s}, [x10], %[stride]\n"
371 "ld1 {v2.4s, v3.4s}, [x10], %[stride]\n"
372 "ld1 {v4.4s, v5.4s}, [x10], %[stride]\n"
373 "ld1 {v6.4s, v7.4s}, [x10], %[stride]\n"
374
375 "zip1 v8.4s, v0.4s, v2.4s\n"
376 "zip2 v9.4s, v0.4s, v2.4s\n"
377 "zip1 v10.4s, v4.4s, v6.4s\n"
378 "zip2 v11.4s, v4.4s, v6.4s\n"
379
380 "ld1 {v16.4s, v17.4s}, [x10], %[stride]\n"
381 "ld1 {v18.4s, v19.4s}, [x10], %[stride]\n"
382 "ld1 {v20.4s, v21.4s}, [x10], %[stride]\n"
383 "ld1 {v22.4s, v23.4s}, [x10], %[stride]\n"
384
385 "zip1 v12.4s, v1.4s, v3.4s\n"
386 "zip2 v13.4s, v1.4s, v3.4s\n"
387 "zip1 v14.4s, v5.4s, v7.4s\n"
388 "zip2 v15.4s, v5.4s, v7.4s\n"
389
390 "trn1 v0.2d, v8.2d, v10.2d\n"
391 "trn2 v1.2d, v8.2d, v10.2d\n"
392 "trn1 v2.2d, v9.2d, v11.2d\n"
393 "trn2 v3.2d, v9.2d, v11.2d\n"
394
395 "zip1 v24.4s, v16.4s, v18.4s\n"
396 "zip2 v25.4s, v16.4s, v18.4s\n"
397 "zip1 v26.4s, v20.4s, v22.4s\n"
398 "zip2 v27.4s, v20.4s, v22.4s\n"
399
400 "trn1 v4.2d, v12.2d, v14.2d\n"
401 "trn2 v5.2d, v12.2d, v14.2d\n"
402 "trn1 v6.2d, v13.2d, v15.2d\n"
403 "trn2 v7.2d, v13.2d, v15.2d\n"
404
405 "zip1 v28.4s, v17.4s, v19.4s\n"
406 "zip2 v29.4s, v17.4s, v19.4s\n"
407 "zip1 v30.4s, v21.4s, v23.4s\n"
408 "zip2 v31.4s, v21.4s, v23.4s\n"
409
410 "trn1 v16.2d, v24.2d, v26.2d\n"
411 "trn2 v17.2d, v24.2d, v26.2d\n"
412 "trn1 v18.2d, v25.2d, v27.2d\n"
413 "trn2 v19.2d, v25.2d, v27.2d\n"
414
415 "trn1 v20.2d, v28.2d, v30.2d\n"
416 "trn2 v21.2d, v28.2d, v30.2d\n"
417 "trn1 v22.2d, v29.2d, v31.2d\n"
418 "trn2 v23.2d, v29.2d, v31.2d\n"
419
420 "st1 {v0.4s}, [x11], #16\n"
421 "st1 {v16.4s}, [x11], #16\n"
422 "st1 {v1.4s}, [x11], #16\n"
423 "st1 {v17.4s}, [x11], #16\n"
424 "st1 {v2.4s}, [x11], #16\n"
425 "st1 {v18.4s}, [x11], #16\n"
426 "st1 {v3.4s}, [x11], #16\n"
427 "st1 {v19.4s}, [x11], #16\n"
428 "st1 {v4.4s}, [x11], #16\n"
429 "st1 {v20.4s}, [x11], #16\n"
430 "st1 {v5.4s}, [x11], #16\n"
431 "st1 {v21.4s}, [x11], #16\n"
432 "st1 {v6.4s}, [x11], #16\n"
433 "st1 {v22.4s}, [x11], #16\n"
434 "st1 {v7.4s}, [x11], #16\n"
435 "st1 {v23.4s}, [x11], #16\n"
436
437 :
438 : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
439 : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
440 "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
441 "v31");
442 return;
443 }
444 #endif
445 #ifdef ENABLE_ARM32
446 #ifndef SUPPORT_NNIE
void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) {
448 size_t stride = col * sizeof(float);
449 asm volatile(
450 "mov r10, %[src_c]\n"
451 "mov r11, %[dst_c]\n"
452
453 "vld1.32 {q0}, [r10], %[stride]\n"
454 "vld1.32 {q2}, [r10], %[stride]\n"
455 "vld1.32 {q4}, [r10], %[stride]\n"
456 "vld1.32 {q6}, [r10], %[stride]\n"
457
458 "vtrn.32 d0, d4\n"
459 "vtrn.32 d1, d5\n"
460 "vtrn.32 d8, d12\n"
461 "vtrn.32 d9, d13\n"
462
463 "vld1.32 {q1}, [r10], %[stride]\n"
464 "vld1.32 {q3}, [r10], %[stride]\n"
465 "vld1.32 {q5}, [r10], %[stride]\n"
466 "vld1.32 {q7}, [r10], %[stride]\n"
467
468 "vswp d1, d8\n"
469 "vswp d5, d12\n"
470
471 "vtrn.32 d2, d6\n"
472 "vtrn.32 d3, d7\n"
473 "vtrn.32 d10, d14\n"
474 "vtrn.32 d11, d15\n"
475
476 "vswp d3, d10\n"
477 "vswp d7, d14\n"
478
479 "vst1.32 {q0, q1}, [r11]!\n"
480 "vst1.32 {q2, q3}, [r11]!\n"
481 "vst1.32 {q4, q5}, [r11]!\n"
482 "vst1.32 {q6, q7}, [r11]!\n"
483
484 :
485 : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
486 : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
487 return;
488 }
489 #else
void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) {
491 size_t stride = col * sizeof(float);
492 asm volatile(
493 "mov r10, %[src_c]\n"
494 "mov r7, %[dst_c]\n"
495
496 "vld1.32 {q0}, [r10], %[stride]\n"
497 "vld1.32 {q2}, [r10], %[stride]\n"
498 "vld1.32 {q4}, [r10], %[stride]\n"
499 "vld1.32 {q6}, [r10], %[stride]\n"
500
501 "vtrn.32 d0, d4\n"
502 "vtrn.32 d1, d5\n"
503 "vtrn.32 d8, d12\n"
504 "vtrn.32 d9, d13\n"
505
506 "vld1.32 {q1}, [r10], %[stride]\n"
507 "vld1.32 {q3}, [r10], %[stride]\n"
508 "vld1.32 {q5}, [r10], %[stride]\n"
509 "vld1.32 {q7}, [r10], %[stride]\n"
510
511 "vswp d1, d8\n"
512 "vswp d5, d12\n"
513
514 "vtrn.32 d2, d6\n"
515 "vtrn.32 d3, d7\n"
516 "vtrn.32 d10, d14\n"
517 "vtrn.32 d11, d15\n"
518
519 "vswp d3, d10\n"
520 "vswp d7, d14\n"
521
522 "vst1.32 {q0, q1}, [r7]!\n"
523 "vst1.32 {q2, q3}, [r7]!\n"
524 "vst1.32 {q4, q5}, [r7]!\n"
525 "vst1.32 {q6, q7}, [r7]!\n"
526
527 :
528 : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
529 : "r10", "r7", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
530 return;
531 }
532 #endif
533 #endif
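// Pack a row-major (row x col) matrix into 8-row tiles stored column by column
// (within a tile, dst[c * 8 + r] = src[r][c]); rows are zero-padded up to a multiple
// of 8. The inner column step is 8 on ARM64 and 4 on the other paths.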
void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, int row, int col) {
535 int row8 = row / C8NUM * C8NUM;
536 #ifdef ENABLE_ARM64
537 int col_skip = col / C8NUM * C8NUM;
538 int skip_size = C8NUM;
539 #else
540 int col_skip = col / C4NUM * C4NUM;
541 int skip_size = C4NUM;
542 #endif
543 const float *src_r = src_ptr;
544 float *dst_r = dst_ptr;
545
546 int ri = 0;
547 for (; ri < row8; ri += C8NUM) {
548 int ci = 0;
549 for (; ci < col_skip; ci += skip_size) {
550 const float *src_c = src_r + ci;
551 float *dst_c = dst_r + ci * C8NUM;
552
553 #ifdef ENABLE_ARM64
554 RowMajor2Col8Major_arm64(src_c, dst_c, col);
555 #elif ENABLE_ARM32
556 RowMajor2Col8Major_arm32(src_c, dst_c, col);
557 #elif ENABLE_SSE
558 __m128 src1 = _mm_loadu_ps(src_c);
559 __m128 src2 = _mm_loadu_ps(src_c + col);
560 __m128 src3 = _mm_loadu_ps(src_c + 2 * col);
561 __m128 src4 = _mm_loadu_ps(src_c + 3 * col);
562 src_c += 4 * col;
563 __m128 src12L = _mm_unpacklo_ps(src1, src2); // x5
564 __m128 src12H = _mm_unpackhi_ps(src1, src2); // x1
565 __m128 src34L = _mm_unpacklo_ps(src3, src4); // x
566 __m128 src34H = _mm_unpackhi_ps(src3, src4);
567 _mm_storeu_ps(dst_c, _mm_movelh_ps(src12L, src34L));
568 _mm_storeu_ps(dst_c + 8, _mm_movehl_ps(src34L, src12L));
569 _mm_storeu_ps(dst_c + 16, _mm_movelh_ps(src12H, src34H));
570 _mm_storeu_ps(dst_c + 24, _mm_movehl_ps(src34H, src12H));
571
572 __m128 src5 = _mm_loadu_ps(src_c);
573 __m128 src6 = _mm_loadu_ps(src_c + col);
574 __m128 src7 = _mm_loadu_ps(src_c + 2 * col);
575 __m128 src8 = _mm_loadu_ps(src_c + 3 * col);
576 src_c += 4 * col;
577 __m128 src56L = _mm_unpacklo_ps(src5, src6);
578 __m128 src56H = _mm_unpackhi_ps(src5, src6);
579 __m128 src78L = _mm_unpacklo_ps(src7, src8);
580 __m128 src78H = _mm_unpackhi_ps(src7, src8);
581 _mm_storeu_ps(dst_c + 4, _mm_movelh_ps(src56L, src78L));
582 _mm_storeu_ps(dst_c + 12, _mm_movehl_ps(src78L, src56L));
583 _mm_storeu_ps(dst_c + 20, _mm_movelh_ps(src56H, src78H));
584 _mm_storeu_ps(dst_c + 28, _mm_movehl_ps(src78H, src56H));
585 #else
586 for (int tr = 0; tr < 8; tr++) {
587 for (int tc = 0; tc < 4; tc++) {
588 dst_c[tc * 8 + tr] = src_c[tr * col + tc];
589 }
590 }
591 #endif
592 }
593 for (; ci < col; ci++) {
594 const float *src_c = src_r + ci;
595 float *dst_c = dst_r + ci * C8NUM;
596 for (int i = 0; i < C8NUM; i++) {
597 dst_c[i] = src_c[i * col];
598 }
599 }
600 src_r += C8NUM * col;
601 dst_r += C8NUM * col;
602 }
603 for (; ri < row; ri++, src_r += col, dst_r++) {
604 for (int i = 0; i < col; i++) {
605 dst_r[i * C8NUM] = src_r[i];
606 }
607 }
608
609 for (; ri < UP_ROUND(row, C8NUM); ri++, dst_r++) {
610 for (int i = 0; i < col; i++) {
611 dst_r[i * C8NUM] = 0;
612 }
613 }
614 }
615
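// Pack a row-major matrix into 16-row tiles stored column by column
// (dst[c * 16 + r] within a tile); rows are zero-padded up to a multiple of 16.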
void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, int row, int col) {
617 int row16 = row / C16NUM * C16NUM;
618 int col8 = col / C8NUM * C8NUM;
619 const float *src_r = src_ptr;
620 float *dst_r = dst_ptr;
621
622 int ri = 0;
623 for (; ri < row16; ri += C16NUM) {
624 int ci = 0;
625 for (; ci < col8; ci += C8NUM) {
626 const float *src_c = src_r + ci;
627 float *dst_c = dst_r + ci * C16NUM;
628 #ifdef ENABLE_AVX
629 Transpose8X8Fp32Avx(src_c, dst_c, col, C16NUM);
630 Transpose8X8Fp32Avx(src_c + C8NUM * col, dst_c + C8NUM, col, C16NUM);
#else
      for (int tr = 0; tr < C16NUM; tr++) {
        for (int tc = 0; tc < C8NUM; tc++) {
          dst_c[tc * C16NUM + tr] = src_c[tr * col + tc];
        }
      }
#endif
637 }
638 for (; ci < col; ci++) {
639 const float *src_c = src_r + ci;
640 float *dst_c = dst_r + ci * C16NUM;
641 for (int i = 0; i < C16NUM; i++) {
642 dst_c[i] = src_c[i * col];
643 }
644 }
645 src_r += C16NUM * col;
646 dst_r += C16NUM * col;
647 }
648 for (; ri < row; ri++) {
649 for (int i = 0; i < col; i++) {
650 dst_r[i * C16NUM] = src_r[i];
651 }
652 src_r += col;
653 dst_r += 1;
654 }
655
656 int total_row = UP_ROUND(row, C16NUM);
657 for (; ri < total_row; ri++) {
658 for (int i = 0; i < col; i++) {
659 dst_r[i * C16NUM] = 0;
660 }
661 dst_r += 1;
662 }
663 }
664
void RowMajor2Col32Major(const float *src_ptr, float *dst_ptr, int row, int col) {
  // Columns are processed in groups of up to 32 (4 blocks of 8); a shorter tail group of
  // 24, 16, or 8 columns is used when fewer than 32 columns remain.
667 int col_block_num = UP_DIV(col, C8NUM);
668 int col_block = C4NUM;
669 for (int i = 0; i < col_block_num; i += col_block) {
670 col_block = MSMIN(C4NUM, col_block_num - i); // max_tile = 4
671 int index = i * row * C8NUM;
672 int col_remainder = MSMIN(C8NUM * col_block, col - i * C8NUM);
673 for (int ir = 0; ir < row; ++ir) {
674 for (int oc = 0; oc < col_remainder; ++oc) {
675 int oc_index = oc * row + ir + index;
676 dst_ptr[oc] = src_ptr[oc_index];
677 }
678 dst_ptr += col_block * C8NUM;
679 }
680 }
681 }
682
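// Pack a row-major matrix into 6-row tiles stored column by column
// (dst[c * 6 + r] within a tile); rows are zero-padded up to a multiple of 6.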
void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, int row, int col) {
684 int totalRow = UP_ROUND(row, C6NUM);
685 int row6 = row / C6NUM * C6NUM;
686 int col8 = col / C8NUM * C8NUM;
687 const float *src_r = src_ptr;
688 float *dst_r = dst_ptr;
689
690 int ri = 0;
691 for (; ri < row6; ri += C6NUM) {
692 int ci = 0;
693 for (; ci < col8; ci += C8NUM) {
694 const float *src_c = src_r + ci;
695 float *dst_c = dst_r + ci * C6NUM;
696
697 #ifdef ENABLE_AVX
698 __m256 src0 = _mm256_loadu_ps(src_c);
699 __m256 src1 = _mm256_loadu_ps(src_c + col);
700 __m256 src2 = _mm256_loadu_ps(src_c + 2 * col);
701 __m256 src3 = _mm256_loadu_ps(src_c + 3 * col);
702 __m256 src4 = _mm256_loadu_ps(src_c + 4 * col);
703 __m256 src5 = _mm256_loadu_ps(src_c + 5 * col);
704 __m256 trans0 = _mm256_unpacklo_ps(src0, src1);
705 __m256 trans1 = _mm256_unpacklo_ps(src2, src3);
706 __m256 trans2 = _mm256_unpacklo_ps(src4, src5);
707 __m256 trans3 = _mm256_unpackhi_ps(src0, src1);
708 __m256 trans4 = _mm256_unpackhi_ps(src2, src3);
709 __m256 trans5 = _mm256_unpackhi_ps(src4, src5);
710 __m128 lo0 = _mm256_castps256_ps128(trans0);
711 __m128 lo1 = _mm256_castps256_ps128(trans1);
712 __m128 lo2 = _mm256_castps256_ps128(trans2);
713 __m128 lo3 = _mm256_castps256_ps128(trans3);
714 __m128 lo4 = _mm256_castps256_ps128(trans4);
715 __m128 lo5 = _mm256_castps256_ps128(trans5);
716 __m128 hi0 = _mm256_extractf128_ps(trans0, 1);
717 __m128 hi1 = _mm256_extractf128_ps(trans1, 1);
718 __m128 hi2 = _mm256_extractf128_ps(trans2, 1);
719 __m128 hi3 = _mm256_extractf128_ps(trans3, 1);
720 __m128 hi4 = _mm256_extractf128_ps(trans4, 1);
721 __m128 hi5 = _mm256_extractf128_ps(trans5, 1);
722 __m128 res0 = _mm_shuffle_ps(lo0, lo1, _MM_SHUFFLE(1, 0, 1, 0));
723 __m128 res1 = _mm_shuffle_ps(lo2, lo0, _MM_SHUFFLE(3, 2, 1, 0));
724 __m128 res2 = _mm_shuffle_ps(lo1, lo2, _MM_SHUFFLE(3, 2, 3, 2));
725 __m128 res3 = _mm_shuffle_ps(lo3, lo4, _MM_SHUFFLE(1, 0, 1, 0));
726 __m128 res4 = _mm_shuffle_ps(lo5, lo3, _MM_SHUFFLE(3, 2, 1, 0));
727 __m128 res5 = _mm_shuffle_ps(lo4, lo5, _MM_SHUFFLE(3, 2, 3, 2));
728 __m128 res6 = _mm_shuffle_ps(hi0, hi1, _MM_SHUFFLE(1, 0, 1, 0));
729 __m128 res7 = _mm_shuffle_ps(hi2, hi0, _MM_SHUFFLE(3, 2, 1, 0));
730 __m128 res8 = _mm_shuffle_ps(hi1, hi2, _MM_SHUFFLE(3, 2, 3, 2));
731 __m128 res9 = _mm_shuffle_ps(hi3, hi4, _MM_SHUFFLE(1, 0, 1, 0));
732 __m128 res10 = _mm_shuffle_ps(hi5, hi3, _MM_SHUFFLE(3, 2, 1, 0));
733 __m128 res11 = _mm_shuffle_ps(hi4, hi5, _MM_SHUFFLE(3, 2, 3, 2));
734 _mm_storeu_ps(dst_c, res0);
735 _mm_storeu_ps(dst_c + 4, res1);
736 _mm_storeu_ps(dst_c + 8, res2);
737 _mm_storeu_ps(dst_c + 12, res3);
738 _mm_storeu_ps(dst_c + 16, res4);
739 _mm_storeu_ps(dst_c + 20, res5);
740 _mm_storeu_ps(dst_c + 24, res6);
741 _mm_storeu_ps(dst_c + 28, res7);
742 _mm_storeu_ps(dst_c + 32, res8);
743 _mm_storeu_ps(dst_c + 36, res9);
744 _mm_storeu_ps(dst_c + 40, res10);
745 _mm_storeu_ps(dst_c + 44, res11);
746 #else
747 for (int tr = 0; tr < C6NUM; tr++) {
748 for (int tc = 0; tc < C8NUM; tc++) {
749 dst_c[tc * C6NUM + tr] = src_c[tr * col + tc];
750 }
751 }
752 #endif
753 }
754 for (; ci < col; ci++) {
755 const float *src_c = src_r + ci;
756 float *dst_c = dst_r + ci * C6NUM;
757 for (int i = 0; i < C6NUM; i++) {
758 dst_c[i] = src_c[i * col];
759 }
760 }
761 src_r += C6NUM * col;
762 dst_r += C6NUM * col;
763 }
764
765 for (; ri < row; ri++) {
766 for (int i = 0; i < col; i++) {
767 dst_r[i * C6NUM] = src_r[i];
768 }
769 src_r += col;
770 dst_r += 1;
771 }
772
773 for (; ri < totalRow; ri++) {
774 for (int i = 0; i < col; i++) {
775 dst_r[i * C6NUM] = 0;
776 }
777 dst_r += 1;
778 }
779 }
780
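// Pack a row-major matrix into 4-row tiles stored column by column
// (dst[c * 4 + r] within a tile); rows are zero-padded up to a multiple of 4.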
void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, int row, int col) {
782 int total_row = UP_ROUND(row, C4NUM);
783 int row4 = row / C4NUM * C4NUM;
784 int col4 = col / C4NUM * C4NUM;
785 const float *src_r = src_ptr;
786 float *dst_r = dst_ptr;
787
788 int ri = 0;
789 for (; ri < row4; ri += C4NUM) {
790 int ci = 0;
791 for (; ci < col4; ci += C4NUM) {
792 const float *src_c = src_r + ci;
793 float *dst_c = dst_r + ci * C4NUM;
794
795 #ifdef ENABLE_ARM32
796 int stride = col * 4;
797 asm volatile(
798 "mov r10, %[src_c]\n"
799 "mov r12, %[dst_c]\n"
800
801 "vld1.32 {q0}, [r10], %[stride]\n"
802 "vld1.32 {q1}, [r10], %[stride]\n"
803 "vld1.32 {q2}, [r10], %[stride]\n"
804 "vld1.32 {q3}, [r10], %[stride]\n"
805
806 "vtrn.32 d0, d2\n"
807 "vtrn.32 d1, d3\n"
808 "vtrn.32 d4, d6\n"
809 "vtrn.32 d5, d7\n"
810
811 "vswp d1, d4\n"
812 "vswp d3, d6\n"
813
814 "vst1.32 {q0}, [r12]!\n"
815 "vst1.32 {q1}, [r12]!\n"
816 "vst1.32 {q2}, [r12]!\n"
817 "vst1.32 {q3}, [r12]!\n"
818
819 :
820 : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
821 : "r10", "r12", "q0", "q1", "q2", "q3");
822 #elif ENABLE_SSE
823 __m128 src1 = _mm_loadu_ps(src_c);
824 __m128 src2 = _mm_loadu_ps(src_c + col);
825 __m128 src3 = _mm_loadu_ps(src_c + 2 * col);
826 __m128 src4 = _mm_loadu_ps(src_c + 3 * col);
827 src_c += 4 * col;
828 __m128 src12L = _mm_unpacklo_ps(src1, src2);
829 __m128 src12H = _mm_unpackhi_ps(src1, src2);
830 __m128 src34L = _mm_unpacklo_ps(src3, src4);
831 __m128 src34H = _mm_unpackhi_ps(src3, src4);
832
833 __m128 dst0 = _mm_movelh_ps(src12L, src34L);
834 __m128 dst1 = _mm_movehl_ps(src34L, src12L);
835 __m128 dst2 = _mm_movelh_ps(src12H, src34H);
836 __m128 dst3 = _mm_movehl_ps(src34H, src12H);
837
838 _mm_storeu_ps(dst_c, dst0);
839 _mm_storeu_ps(dst_c + 4, dst1);
840 _mm_storeu_ps(dst_c + 8, dst2);
841 _mm_storeu_ps(dst_c + 12, dst3);
842 #else
843 for (size_t tr = 0; tr < C4NUM; tr++) {
844 for (size_t tc = 0; tc < C4NUM; tc++) {
845 dst_c[tc * C4NUM + tr] = src_c[tr * col + tc];
846 }
847 }
848 #endif
849 }
850 for (; ci < col; ci++) {
851 const float *src_c = src_r + ci;
852 float *dst_c = dst_r + ci * C4NUM;
853 for (int i = 0; i < C4NUM; i++) {
854 dst_c[i] = src_c[i * col];
855 }
856 }
857 src_r += C4NUM * col;
858 dst_r += C4NUM * col;
859 }
860 for (; ri < row; ri++) {
861 for (int i = 0; i < col; i++) {
862 dst_r[i * C4NUM] = src_r[i];
863 }
864 src_r += col;
865 dst_r += 1;
866 }
867
868 for (; ri < total_row; ri++) {
869 for (int i = 0; i < col; i++) {
870 dst_r[i * C4NUM] = 0;
871 }
872 dst_r += 1;
873 }
874 }
875
876 #ifndef ENABLE_ARM
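// Scalar reference vector-matrix multiply: c[ci] = dot(a, b + ci * depth) + bias[ci] with
// optional ReLU/ReLU6. The `depth` weights of each output channel are expected to be
// contiguous in b.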
void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) {
878 for (int ci = 0; ci < col; ci++) {
879 float value = 0;
880 for (int di = 0; di < depth; di++) {
881 value += a[di] * b[ci * depth + di];
882 }
883 if (bias != NULL) value += bias[ci];
884 if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
885 if (act_type == ActType_Relu || act_type == ActType_Relu6) value = MSMAX(0.0f, value);
886 c[ci] = value;
887 }
888 }
889 #endif
890
891 #ifdef ENABLE_ARM64
// Processes 8 output channels per column step, with the depth loop unrolled by 4 (the "4x8" tiling).
void MatVecMulFp32Neon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col,
                         int align_col) {
895 int ci = 0;
896 for (; ci < align_col - C8NUM + 1; ci += C8NUM) {
897 float32x4_t acc_0;
898 float32x4_t acc_1;
899 if (bias != NULL) {
900 acc_0 = vld1q_f32(bias + ci);
901 acc_1 = vld1q_f32(bias + ci + C4NUM);
902 } else {
903 acc_0 = vdupq_n_f32(0.0f);
904 acc_1 = vdupq_n_f32(0.0f);
905 }
906 const float *bv_base = b + ci * depth;
907 int di = 0;
908 for (; di < depth - C4NUM + 1; di += C4NUM) {
909 float32x4_t av = vld1q_f32(a + di);
910 float32x4_t bv_00 = vld1q_f32(bv_base);
911 float32x4_t bv_10 = vld1q_f32(bv_base + C4NUM);
912 bv_base += C8NUM;
913 float32x4_t bv_01 = vld1q_f32(bv_base);
914 float32x4_t bv_11 = vld1q_f32(bv_base + C4NUM);
915 bv_base += C8NUM;
916 float32x4_t bv_02 = vld1q_f32(bv_base);
917 float32x4_t bv_12 = vld1q_f32(bv_base + C4NUM);
918 bv_base += C8NUM;
919 float32x4_t bv_03 = vld1q_f32(bv_base);
920 float32x4_t bv_13 = vld1q_f32(bv_base + C4NUM);
921 bv_base += C8NUM;
922 acc_0 = vmlaq_n_f32(acc_0, bv_00, av[0]);
923 acc_1 = vmlaq_n_f32(acc_1, bv_10, av[0]);
924 acc_0 = vmlaq_n_f32(acc_0, bv_01, av[1]);
925 acc_1 = vmlaq_n_f32(acc_1, bv_11, av[1]);
926 acc_0 = vmlaq_n_f32(acc_0, bv_02, av[2]);
927 acc_1 = vmlaq_n_f32(acc_1, bv_12, av[2]);
928 acc_0 = vmlaq_n_f32(acc_0, bv_03, av[3]);
929 acc_1 = vmlaq_n_f32(acc_1, bv_13, av[3]);
930 }
931 if (di < depth) {
932 for (; di < depth; ++di) {
933 float ai = a[di];
934 float32x4_t bv0 = vld1q_f32(bv_base);
935 float32x4_t bv1 = vld1q_f32(bv_base + C4NUM);
936 acc_0 = vmlaq_n_f32(acc_0, bv0, ai);
937 acc_1 = vmlaq_n_f32(acc_1, bv1, ai);
938 bv_base += C8NUM;
939 }
940 } // only save actual col num data
941 if (ci + C4NUM - 1 >= col) {
942 int c_remain = col - ci;
943 for (int i = 0; i < c_remain; ++i) {
944 if (act_type == ActType_Relu) {
945 c[i] = MSMAX(acc_0[i], 0.0f);
946 } else if (act_type == ActType_Relu6) {
947 c[i] = MSMIN(MSMAX(acc_0[i], 0.0f), 6.0f);
948 } else {
949 c[i] = acc_0[i];
950 }
951 }
952 return;
953 }
954 if (act_type == ActType_Relu) {
955 acc_0 = vmaxq_f32(acc_0, vdupq_n_f32(0.0f));
956 } else if (act_type == ActType_Relu6) {
957 acc_0 = vminq_f32(vmaxq_f32(acc_0, vdupq_n_f32(0.0f)), vdupq_n_f32(6.0f));
958 }
959 vst1q_f32(c, acc_0);
960 if (ci + C8NUM - 1 >= col) {
961 int c_remain = col - ci - C4NUM;
962 for (int i = 0; i < c_remain; ++i) {
963 if (act_type == ActType_Relu) {
964 c[C4NUM + i] = MSMAX(acc_1[i], 0.0f);
965 } else if (act_type == ActType_Relu6) {
966 c[C4NUM + i] = MSMIN(MSMAX(acc_1[i], 0.0f), 6.0f);
967 } else {
968 c[C4NUM + i] = acc_1[i];
969 }
970 }
971 return;
972 }
973 if (act_type == ActType_Relu) {
974 acc_1 = vmaxq_f32(acc_1, vdupq_n_f32(0.0f));
975 } else if (act_type == ActType_Relu6) {
976 acc_1 = vminq_f32(vmaxq_f32(acc_1, vdupq_n_f32(0.0f)), vdupq_n_f32(6.0f));
977 }
978 vst1q_f32(c + C4NUM, acc_1);
979 c += C8NUM;
980 }
981 }
982 #endif
983
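// C reference kernel: multiplies `a` packed into 12-row tiles (RowMajor2Col12Major) by
// `b` packed into blocks of 8 output channels (e.g. RowMajor2Row8Major), and writes the
// result in NHWC, C8-tiled, or strided tile layout depending on out_type.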
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
                int col, int stride, int out_type) {
986 if (out_type == OutType_Nhwc) {
987 for (int r = 0; r < row; r++) {
988 for (int c = 0; c < col; c++) {
989 int r12div = r / 12, r12mod = r % 12;
990 int c8div = c / 8, c8mod = c % 8;
991 size_t ci = r * stride + c;
992 float value = 0;
993 for (int d = 0; d < deep; d++) {
994 size_t ai = r12div * deep * 12 + d * 12 + r12mod;
995 size_t bi = c8div * deep * 8 + d * 8 + c8mod;
996 value = value + a[ai] * b[bi];
997 }
998 ADD_BIAS(value, bias, c)
999 DO_RELU(value, act_type)
1000 DO_RELU6(value, act_type)
1001 dst[ci] = value;
1002 }
1003 }
1004 } else if (out_type == OutType_C8) {
1005 int col_8 = UP_ROUND(col, C8NUM);
1006 int row_12 = UP_ROUND(row, C12NUM);
1007 for (int r = 0; r < row_12; r++) {
1008 for (int c = 0; c < col_8; c++) {
1009 int r12div = r / C12NUM, r12mod = r % C12NUM;
1010 int c8div = c / C8NUM, c8mod = c % C8NUM;
1011 size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod);
1012 float value = 0;
1013 for (int d = 0; d < deep; d++) {
1014 size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
1015 size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
1016 value = value + a[ai] * b[bi];
1017 }
1018 ADD_BIAS(value, bias, c)
1019 DO_RELU(value, act_type)
1020 DO_RELU6(value, act_type)
1021 dst[ci] = value;
1022 }
1023 }
1024 } else {
1025 for (int i = 0; i < row; ++i) {
1026 int src_r_offset = i;
1027 int dst_r_offset = i * col * stride;
1028 for (int j = 0; j < col; ++j) {
1029 int c8div = j / 8, c8mod = j % 8;
1030 size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
1031 float value = 0;
1032 for (int d = 0; d < deep; ++d) {
1033 size_t ai = src_r_offset + d * C12NUM;
1034 size_t bi = c8div * deep * 8 + d * 8 + c8mod;
1035 value = value + a[ai] * b[bi];
1036 }
1037 ADD_BIAS(value, bias, j)
1038 DO_RELU(value, act_type)
1039 DO_RELU6(value, act_type)
1040 dst[ci] = value;
1041 }
1042 }
1043 }
1044 }
1045
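// Top-level matmul entry point: dispatches to the platform assembly kernels
// (NEON64 / NEON32 / AVX / SSE) when available, and otherwise falls back to the
// C reference MatMul12x8.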
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
               int col, size_t stride, int out_type) {
1048 #ifdef ENABLE_ARM64
1049 if (out_type == OutType_C8) {
1050 MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
1051 } else {
1052 MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
1053 }
1054 #elif ENABLE_ARM32
1055 if (out_type == OutType_C8) {
1056 MatmulFloatNeon32(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
1057 } else if (out_type == OutType_Nhwc) {
1058 MatmulFloatNeon32Opt12x4(a, b, c, bias, (int)act_type, deep, row, col, stride, 1);
1059 } else {
1060 MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
1061 }
1062 #elif ENABLE_AVX
1063 MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type));
1064 #elif ENABLE_SSE
1065 MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
1066 #else
1067 MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
1068 #endif
1069 }
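/*
 * Illustrative usage sketch (not compiled; buffer names are hypothetical). A caller that
 * wants C(row x col) = A(row x deep) * B(deep x col) + bias would, roughly, pack A into
 * 12-row tiles and B into blocks of 8 output channels before calling MatMulOpt:
 *
 *   float *pack_a = malloc(UP_ROUND(row, C12NUM) * deep * sizeof(float));
 *   float *pack_b = malloc(UP_ROUND(col, C8NUM) * deep * sizeof(float));
 *   RowMajor2Col12Major(a, pack_a, row, deep);  // LHS: 12-row, column-stored tiles
 *   RowMajor2Row8Major(b, pack_b, deep, col);   // RHS: blocks of 8 output channels
 *   MatMulOpt(pack_a, pack_b, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc);
 *   free(pack_a);
 *   free(pack_b);
 *
 * The exact packing expected by each assembly kernel can differ per platform; this sketch
 * mirrors the layout consumed by the C fallback MatMul12x8 above.
 */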
1070
1071 #ifdef ENABLE_AVX
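// AVX vector-matrix multiply: walks the output channels in groups of up to 32 and
// dispatches to the 1x8/1x16/1x24/1x32 kernels below. act_flag bit 0 requests the
// ReLU6 upper clamp, bit 1 the ReLU lower clamp.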
void MatVecMulAvxFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col,
                      int col_align) {
  // process up to 32 output channels per iteration
1075 int col_block = C32NUM;
1076 int act_flag = 0;
1077 if (act_type == ActType_Relu6) {
1078 act_flag += 1;
1079 }
1080 if (act_type == ActType_Relu || act_type == ActType_Relu6) {
1081 act_flag += 2;
1082 }
1083 MatVecMulKernel kernel[4] = {MatVecMul1x8Kernel, MatVecMul1x16Kernel, MatVecMul1x24Kernel, MatVecMul1x32Kernel};
1084 const float *bias_data = bias;
1085 for (int col_index = 0; col_index < cur_col; col_index += col_block) {
1086 col_block = cur_col - col_index < col_block ? cur_col - col_index : col_block;
1087 kernel[(col_block >> 3) - 1](c + col_index, a, b + col_index * depth, bias_data, act_flag, 1, col_block >> 3,
1088 col_align, depth);
1089 if (bias_data != NULL) {
1090 bias_data += col_block;
1091 }
1092 }
1093 }
1094
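// MatVecMul1xNKernel (N = 8, 16, 24, 32): one input row against N output channels.
// Weights are expected packed N floats per depth step; the depth loop is unrolled by 8
// with a scalar remainder, and act_flag is decoded as described above.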
void MatVecMul1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
                         size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
1097 asm volatile(
1098 "cmpq $0, %2\n"
1099 "je 0f\n"
1100 "vmovups (%2), %%ymm0\n"
1101 "vmovups 0x20(%2), %%ymm1\n"
1102 "vmovups 0x40(%2), %%ymm2\n"
1103 "vmovups 0x60(%2), %%ymm3\n"
1104 "jmp 1f\n"
1105 "0:\n"
1106 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1107 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1108 "vxorps %%ymm2, %%ymm2, %%ymm2\n"
1109 "vxorps %%ymm3, %%ymm3, %%ymm3\n"
1110 "1:\n" // deep_c8
1111 "movq %3, %%rcx\n"
1112 "shr $3, %%ecx\n"
1113 "je 3f\n"
1114 "2:\n"
1115 "vbroadcastss (%0), %%ymm4\n"
1116 "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1117 "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
1118 "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n"
1119 "vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n"
1120
1121 "vbroadcastss 4(%0), %%ymm4\n"
1122 "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n"
1123 "vfmadd231ps 160(%1), %%ymm4, %%ymm1\n"
1124 "vfmadd231ps 192(%1), %%ymm4, %%ymm2\n"
1125 "vfmadd231ps 224(%1), %%ymm4, %%ymm3\n"
1126
1127 "vbroadcastss 8(%0), %%ymm4\n"
1128 "vfmadd231ps 256(%1), %%ymm4, %%ymm0\n"
1129 "vfmadd231ps 288(%1), %%ymm4, %%ymm1\n"
1130 "vfmadd231ps 320(%1), %%ymm4, %%ymm2\n"
1131 "vfmadd231ps 352(%1), %%ymm4, %%ymm3\n"
1132
1133 "vbroadcastss 12(%0), %%ymm4\n"
1134 "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n"
1135 "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n"
1136 "vfmadd231ps 448(%1), %%ymm4, %%ymm2\n"
1137 "vfmadd231ps 480(%1), %%ymm4, %%ymm3\n"
1138
1139 "vbroadcastss 16(%0), %%ymm4\n"
1140 "vfmadd231ps 512(%1), %%ymm4, %%ymm0\n"
1141 "vfmadd231ps 544(%1), %%ymm4, %%ymm1\n"
1142 "vfmadd231ps 576(%1), %%ymm4, %%ymm2\n"
1143 "vfmadd231ps 608(%1), %%ymm4, %%ymm3\n"
1144
1145 "vbroadcastss 20(%0), %%ymm4\n"
1146 "vfmadd231ps 640(%1), %%ymm4, %%ymm0\n"
1147 "vfmadd231ps 672(%1), %%ymm4, %%ymm1\n"
1148 "vfmadd231ps 704(%1), %%ymm4, %%ymm2\n"
1149 "vfmadd231ps 736(%1), %%ymm4, %%ymm3\n"
1150
1151 "vbroadcastss 24(%0), %%ymm4\n"
1152 "vfmadd231ps 768(%1), %%ymm4, %%ymm0\n"
1153 "vfmadd231ps 800(%1), %%ymm4, %%ymm1\n"
1154 "vfmadd231ps 832(%1), %%ymm4, %%ymm2\n"
1155 "vfmadd231ps 864(%1), %%ymm4, %%ymm3\n"
1156
1157 "vbroadcastss 28(%0), %%ymm4\n"
1158 "vfmadd231ps 896(%1), %%ymm4, %%ymm0\n"
1159 "vfmadd231ps 928(%1), %%ymm4, %%ymm1\n"
1160 "vfmadd231ps 960(%1), %%ymm4, %%ymm2\n"
1161 "vfmadd231ps 992(%1), %%ymm4, %%ymm3\n"
1162 "addq $1024, %1\n"
1163 "addq $32, %0\n"
1164 "dec %%ecx\n"
1165 "jg 2b\n"
1166
1167 "3:\n"
1168 "and $7, %3\n" // deep_remainder
1169 "je 5f\n"
1170 "4:\n"
1171 "vbroadcastss (%0), %%ymm4\n"
1172 "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1173 "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
1174 "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n"
1175 "vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n"
1176 "addq $128, %1\n"
1177 "addq $4, %0\n"
1178 "dec %3\n"
1179 "jg 4b\n"
1180
1181 "5:\n"
1182 "and $0x3, %%eax\n" // act_type
1183 "je 6f\n"
1184 // Relu
1185 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1186 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1187 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1188 "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
1189 "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
1190 "and $0x1, %%eax\n"
1191 "je 6f\n"
1192 // relu6
1193 "mov $0x40C00000, %%ecx\n"
1194 "vmovd %%ecx, %%xmm14\n"
1195 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1196 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1197 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1198 "vminps %%ymm14, %%ymm2, %%ymm2\n"
1199 "vminps %%ymm14, %%ymm3, %%ymm3\n"
1200 "6:\n"
1201 "vmovups %%ymm0, (%5)\n" // dst_0
1202 "vmovups %%ymm1, 0x20(%5)\n"
1203 "vmovups %%ymm2, 0x40(%5)\n"
1204 "vmovups %%ymm3, 0x60(%5)\n"
1205 :
1206 : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5
1207 : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm4", "%ymm14");
1208 }
1209
void MatVecMul1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
                         size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
1212 asm volatile(
1213 "cmpq $0, %2\n"
1214 "je 0f\n"
1215 "vmovups (%2), %%ymm0\n"
1216 "vmovups 0x20(%2), %%ymm1\n"
1217 "vmovups 0x40(%2), %%ymm2\n"
1218 "jmp 1f\n"
1219 "0:\n"
1220 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1221 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1222 "vxorps %%ymm2, %%ymm2, %%ymm2\n"
1223
1224 "1:\n" // deep
1225 "movq %3, %%rcx\n"
1226 "shr $3, %%ecx\n"
1227 "je 3f\n"
1228 "2:\n"
1229 "vbroadcastss (%0), %%ymm4\n"
1230 "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1231 "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
1232 "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n"
1233
1234 "vbroadcastss 4(%0), %%ymm4\n"
1235 "vfmadd231ps 96(%1), %%ymm4, %%ymm0\n"
1236 "vfmadd231ps 128(%1), %%ymm4, %%ymm1\n"
1237 "vfmadd231ps 160(%1), %%ymm4, %%ymm2\n"
1238
1239 "vbroadcastss 8(%0), %%ymm4\n"
1240 "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n"
1241 "vfmadd231ps 224(%1), %%ymm4, %%ymm1\n"
1242 "vfmadd231ps 256(%1), %%ymm4, %%ymm2\n"
1243
1244 "vbroadcastss 12(%0), %%ymm4\n"
1245 "vfmadd231ps 288(%1), %%ymm4, %%ymm0\n"
1246 "vfmadd231ps 320(%1), %%ymm4, %%ymm1\n"
1247 "vfmadd231ps 352(%1), %%ymm4, %%ymm2\n"
1248
1249 "vbroadcastss 16(%0), %%ymm4\n"
1250 "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n"
1251 "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n"
1252 "vfmadd231ps 448(%1), %%ymm4, %%ymm2\n"
1253
1254 "vbroadcastss 20(%0), %%ymm4\n"
1255 "vfmadd231ps 480(%1), %%ymm4, %%ymm0\n"
1256 "vfmadd231ps 512(%1), %%ymm4, %%ymm1\n"
1257 "vfmadd231ps 544(%1), %%ymm4, %%ymm2\n"
1258
1259 "vbroadcastss 24(%0), %%ymm4\n"
1260 "vfmadd231ps 576(%1), %%ymm4, %%ymm0\n"
1261 "vfmadd231ps 608(%1), %%ymm4, %%ymm1\n"
1262 "vfmadd231ps 640(%1), %%ymm4, %%ymm2\n"
1263
1264 "vbroadcastss 28(%0), %%ymm4\n"
1265 "vfmadd231ps 672(%1), %%ymm4, %%ymm0\n"
1266 "vfmadd231ps 704(%1), %%ymm4, %%ymm1\n"
1267 "vfmadd231ps 736(%1), %%ymm4, %%ymm2\n"
1268 "addq $768, %1\n"
1269 "addq $32, %0\n"
1270 "dec %%ecx\n"
1271 "jg 2b\n"
1272
1273 "3:\n"
1274 "and $7, %3\n" // deep_remainder
1275 "je 5f\n"
1276 "4:\n"
1277 "vbroadcastss (%0), %%ymm4\n"
1278 "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1279 "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
1280 "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n"
1281 "addq $96, %1\n"
1282 "addq $4, %0\n"
1283 "dec %3\n"
1284 "jg 4b\n"
1285
1286 "5:\n"
1287 "and $0x3, %%eax\n" // act_type
1288 "je 6f\n"
1289 // Relu
1290 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1291 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1292 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1293 "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
1294
1295 "and $0x1, %%eax\n"
1296 "je 6f\n"
1297 // relu6
1298 "mov $0x40C00000, %%ecx\n"
1299 "vmovd %%ecx, %%xmm14\n"
1300 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1301 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1302 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1303 "vminps %%ymm14, %%ymm2, %%ymm2\n"
1304
1305 "6:\n"
1306 "vmovups %%ymm0, (%5)\n" // dst_0
1307 "vmovups %%ymm1, 0x20(%5)\n"
1308 "vmovups %%ymm2, 0x40(%5)\n"
1309
1310 :
1311 : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5
1312 : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm4", "%ymm14");
1313 }
1314
void MatVecMul1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
                         size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
1317 asm volatile(
1318 "cmpq $0, %2\n"
1319 "je 0f\n"
1320 "vmovups (%2), %%ymm0\n"
1321 "vmovups 0x20(%2), %%ymm1\n"
1322 "jmp 1f\n"
1323 "0:\n"
1324 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1325 "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1326 "1:\n"
1327 "movq %3, %%rcx\n"
1328 "shr $3, %%ecx\n"
1329 "je 3f\n"
1330 "2:\n" // deep_c8
1331 "vbroadcastss (%0), %%ymm4\n"
1332 "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1333 "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
1334
1335 "vbroadcastss 4(%0), %%ymm4\n"
1336 "vfmadd231ps 64(%1), %%ymm4, %%ymm0\n"
1337 "vfmadd231ps 96(%1), %%ymm4, %%ymm1\n"
1338
1339 "vbroadcastss 8(%0), %%ymm4\n"
1340 "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n"
1341 "vfmadd231ps 160(%1), %%ymm4, %%ymm1\n"
1342
1343 "vbroadcastss 12(%0), %%ymm4\n"
1344 "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n"
1345 "vfmadd231ps 224(%1), %%ymm4, %%ymm1\n"
1346
1347 "vbroadcastss 16(%0), %%ymm4\n"
1348 "vfmadd231ps 256(%1), %%ymm4, %%ymm0\n"
1349 "vfmadd231ps 288(%1), %%ymm4, %%ymm1\n"
1350
1351 "vbroadcastss 20(%0), %%ymm4\n"
1352 "vfmadd231ps 320(%1), %%ymm4, %%ymm0\n"
1353 "vfmadd231ps 352(%1), %%ymm4, %%ymm1\n"
1354
1355 "vbroadcastss 24(%0), %%ymm4\n"
1356 "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n"
1357 "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n"
1358
1359 "vbroadcastss 28(%0), %%ymm4\n"
1360 "vfmadd231ps 448(%1), %%ymm4, %%ymm0\n"
1361 "vfmadd231ps 480(%1), %%ymm4, %%ymm1\n"
1362 "addq $512, %1\n"
1363 "addq $32, %0\n"
1364 "dec %%ecx\n"
1365 "jg 2b\n"
1366
1367 "3:\n"
1368 "and $7, %3\n"
1369 "je 5f\n"
1370 "4:\n"
1371 "vbroadcastss (%0), %%ymm4\n"
1372 "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1373 "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
1374 "addq $64, %1\n"
1375 "addq $4, %0\n"
1376 "dec %3\n"
1377 "jg 4b\n"
1378
1379 "5:\n"
1380 "and $0x3, %%eax\n" // act_type
1381 "je 6f\n"
1382 // Relu
1383 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1384 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1385 "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1386
1387 "and $0x1, %%eax\n"
1388 "je 6f\n"
1389 // relu6
1390 "mov $0x40C00000, %%ecx\n"
1391 "vmovd %%ecx, %%xmm14\n"
1392 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1393 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1394 "vminps %%ymm14, %%ymm1, %%ymm1\n"
1395
1396 "6:\n"
1397 "vmovups %%ymm0, (%5)\n" // dst_0
1398 "vmovups %%ymm1, 0x20(%5)\n"
1399
1400 :
1401 : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5
1402 : "%rcx", "%ymm0", "%ymm1", "%ymm12", "%ymm4", "%ymm14");
1403 }
1404
void MatVecMul1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
                        size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
1407 asm volatile(
1408 "cmpq $0, %2\n"
1409 "je 0f\n"
1410 "vmovups (%2), %%ymm0\n"
1411 "jmp 1f\n"
1412 "0:\n"
1413 "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1414 "1:\n"
1415 "movq %3, %%rcx\n"
1416 "shr $3, %%ecx\n"
1417 "je 3f\n"
1418 "2:\n" // deep_c8
1419 "vbroadcastss (%0), %%ymm4\n"
1420 "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1421 "vbroadcastss 4(%0), %%ymm4\n"
1422 "vfmadd231ps 32(%1), %%ymm4, %%ymm0\n"
1423 "vbroadcastss 8(%0), %%ymm4\n"
1424 "vfmadd231ps 64(%1), %%ymm4, %%ymm0\n"
1425 "vbroadcastss 12(%0), %%ymm4\n"
1426 "vfmadd231ps 96(%1), %%ymm4, %%ymm0\n"
1427 "vbroadcastss 16(%0), %%ymm4\n"
1428 "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n"
1429 "vbroadcastss 20(%0), %%ymm4\n"
1430 "vfmadd231ps 160(%1), %%ymm4, %%ymm0\n"
1431 "vbroadcastss 24(%0), %%ymm4\n"
1432 "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n"
1433 "vbroadcastss 28(%0), %%ymm4\n"
1434 "vfmadd231ps 224(%1), %%ymm4, %%ymm0\n"
1435 "addq $256, %1\n"
1436 "addq $32, %0\n"
1437 "dec %%ecx\n"
1438 "jg 2b\n"
1439
1440 "3:\n"
1441 "and $7, %3\n"
1442 "je 5f\n"
1443 "4:\n"
1444 "vbroadcastss (%0), %%ymm4\n"
1445 "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1446 "addq $32, %1\n"
1447 "addq $4, %0\n"
1448 "dec %3\n"
1449 "jg 4b\n"
1450
1451 "5:\n"
1452 "and $0x3, %%eax\n" // act_type
1453 "je 6f\n"
1454 // Relu
1455 "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1456 "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1457
1458 "and $0x1, %%eax\n"
1459 "je 6f\n"
1460 // relu6
1461 "mov $0x40C00000, %%ecx\n"
1462 "vmovd %%ecx, %%xmm14\n"
1463 "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1464 "vminps %%ymm14, %%ymm0, %%ymm0\n"
1465
1466 "6:\n"
1467 "vmovups %%ymm0, (%5)\n" // dst_0
1468
1469 :
1470 : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5
1471 : "%rcx", "%ymm0", "%ymm1", "%ymm12", "%ymm4", "%ymm14");
1472 }
1473
1474 #ifdef ENABLE_DEBUG
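// Debug-only intrinsics reference for the assembly kernels above: computes a
// row_block x (col_block * 8) output tile with plain AVX FMA operations.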
void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
                            size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
1477 __m256 dst_data[12];
1478 const float *src_sw[12];
1479 __m256 weight_data[4];
1480 for (int i = 0; i < 4; ++i) {
1481 weight_data[i] = _mm256_set1_ps(0.0f);
1482 }
1483 for (int i = 0; i < row_block; ++i) {
1484 if (bias != NULL) {
1485 for (int j = 0; j < col_block; ++j) {
1486 dst_data[i * col_block + j] = _mm256_loadu_ps(bias + j * 8);
1487 }
1488 } else {
1489 for (int j = 0; j < col_block; ++j) {
1490 dst_data[i * col_block + j] = _mm256_set1_ps(0.0f);
1491 }
1492 }
1493 src_sw[i] = src + i * deep;
1494 }
1495 const float *weight_kernel = weight;
1496 for (int ic = 0; ic < deep; ++ic) {
1497 for (int j = 0; j < col_block; ++j) {
1498 weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM);
1499 }
1500 for (int i = 0; i < row_block; ++i) {
1501 for (int j = 0; j < col_block; ++j) {
1502 dst_data[i * col_block + j] =
1503 _mm256_fmadd_ps(_mm256_set1_ps(src_sw[i][ic]), weight_data[j], dst_data[i * col_block + j]);
1504 }
1505 }
1506 weight_kernel += C8NUM * col_block;
1507 } // ic loop
  // apply the activation and store the output tile (bias was added at initialization)
1509 for (int i = 0; i < row_block; ++i) {
1510 for (int j = 0; j < col_block; ++j) {
1511 if (0x1 & act_flag) { // relu6
1512 dst_data[i * col_block + j] = _mm256_min_ps(dst_data[i * col_block + j], _mm256_set1_ps(6.0f));
1513 }
1514 if (0x2 & act_flag) { // relu
1515 dst_data[i * col_block + j] = _mm256_max_ps(dst_data[i * col_block + j], _mm256_set1_ps(0.0f));
1516 }
1517 _mm256_storeu_ps(dst + i * col_algin + j * C8NUM, dst_data[i * col_block + j]);
1518 }
1519 }
1520 }
1521 #endif
1522 #endif
1523