1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/fp32/matmul_fp32.h"
18 #include "nnacl/fp32/pack_fp32.h"
19 #ifdef ENABLE_SSE
20 #ifdef _MSC_VER
21 #include <immintrin.h>
22 #else
23 #include <x86intrin.h>
24 #endif
25 #endif
26 #ifdef ENABLE_ARM64
27 #include <arm_neon.h>
28 #endif
29 void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col) {
30   for (int r = 0; r < row; ++r) {
31     for (int c = 0; c < col; ++c) {
32       dst_ptr[c * row + r] = src_ptr[r * col + c];
33     }
34   }
35 }
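/*
 * Illustration (editorial, not part of the library): RowMajor2ColMajor above
 * is a plain row-major -> column-major transpose.  A hypothetical 2x3 input:
 *
 *   float src[6] = {1, 2, 3,
 *                   4, 5, 6};
 *   float dst[6];
 *   RowMajor2ColMajor(src, dst, 2, 3);
 *   // dst == {1, 4, 2, 5, 3, 6}: each column stored contiguously.
 */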
36 
37 void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col) {
38   for (int r = 0; r < row; r++) {
39     const float *src = src_ptr + r * col;
40     int c = 0;
41     for (; c < col; c++) {
42       int cd4 = c / C4NUM;
43       int cm4 = c % C4NUM;
44       dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = src[c];
45     }
46     for (; c < UP_ROUND(col, C4NUM); c++) {
47       int cd4 = c / C4NUM;
48       int cm4 = c % C4NUM;
49       dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = 0;
50     }
51   }
52   return;
53 }
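/*
 * Illustration (editorial, not part of the library): the RowMajor2RowXMajor
 * family groups columns into blocks of X and stores each block as a
 * row-major (row x X) tile, zero-padding the last block up to X.  With
 * X = 4 and a hypothetical 2x5 input {a0..a4; b0..b4}:
 *
 *   dst = {a0, a1, a2, a3,  b0, b1, b2, b3,    // columns 0..3
 *          a4,  0,  0,  0,  b4,  0,  0,  0};   // columns 4..7 (padded)
 *
 * RowMajor2Row6Major/8/12/16 below follow the same pattern with X = 6, 8,
 * 12 and 16.
 */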
54 
55 void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col) {
56   for (int r = 0; r < row; r++) {
57     const float *src = src_ptr + r * col;
58     int c = 0;
59     for (; c < col; c++) {
60       int cd6 = c / C6NUM;
61       int cm6 = c % C6NUM;
62       dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = src[c];
63     }
64     for (; c < UP_ROUND(col, C6NUM); c++) {
65       int cd6 = c / C6NUM;
66       int cm6 = c % C6NUM;
67       dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = 0;
68     }
69   }
70   return;
71 }
72 
73 void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col) {
74   for (int r = 0; r < row; r++) {
75     const float *src = src_ptr + r * col;
76     int c = 0;
77     for (; c < col; c++) {
78       int cd8 = c / C8NUM;
79       int cm8 = c % C8NUM;
80       dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[c];
81     }
82     for (; c < UP_ROUND(col, C8NUM); c++) {
83       int cd8 = c / C8NUM;
84       int cm8 = c % C8NUM;
85       dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = 0;
86     }
87   }
88   return;
89 }
90 
91 void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col) {
92   for (int r = 0; r < row; r++) {
93     const float *src = src_ptr + r * col;
94     int c = 0;
95     for (; c < col; c++) {
96       int cd12 = c / C12NUM;
97       int cm12 = c % C12NUM;
98       dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = src[c];
99     }
100     for (; c < UP_ROUND(col, C12NUM); c++) {
101       int cd12 = c / C12NUM;
102       int cm12 = c % C12NUM;
103       dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = 0;
104     }
105   }
106   return;
107 }
108 
109 void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col) {
110   for (int r = 0; r < row; r++) {
111     const float *src = src_ptr + r * col;
112     int c = 0;
113     for (; c < col; c++) {
114       int cd16 = c / C16NUM;
115       int cm16 = c % C16NUM;
116       dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = src[c];
117     }
118     for (; c < UP_ROUND(col, C16NUM); c++) {
119       int cd16 = c / C16NUM;
120       int cm16 = c % C16NUM;
121       dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = 0;
122     }
123   }
124   return;
125 }
126 
127 void RowMajor2Row32Major(const float *src_ptr, float *dst_ptr, int row, int col) {
128   // Rows are handled in blocks of 32; if fewer than 32 remain, the tail block is 24, 16 or 8 rows.
129   int row_block_num = UP_DIV(row, C8NUM);
130   int row_block = C4NUM;
131   for (int i = 0; i < row_block_num; i += row_block) {
132     row_block = MSMIN(C4NUM, row_block_num - i);  // max_tile = 4
133     int row_remainder = MSMIN(row_block * C8NUM, row - i * C8NUM);
134     for (int oc = 0; oc < col; ++oc) {
135       memcpy(dst_ptr, src_ptr + oc * row + i * C8NUM, row_remainder * sizeof(float));
136       dst_ptr += row_block * C8NUM;
137     }
138   }
139 }
140 
141 #ifdef ENABLE_ARM64
142 void RowMajor2Col12Major_arm64(const float *src_c, float *dst_c, size_t col) {
143   size_t stride = col * sizeof(float);
144   asm volatile(
145     "mov x10, %[src_c]\n"
146     "mov x11, %[dst_c]\n"
147 
148     "ld1 {v0.4s}, [x10], %[stride]\n"
149     "ld1 {v1.4s}, [x10], %[stride]\n"
150     "ld1 {v2.4s}, [x10], %[stride]\n"
151     "ld1 {v3.4s}, [x10], %[stride]\n"
152 
153     "ld1 {v4.4s}, [x10], %[stride]\n"
154     "ld1 {v5.4s}, [x10], %[stride]\n"
155     "ld1 {v6.4s}, [x10], %[stride]\n"
156     "ld1 {v7.4s}, [x10], %[stride]\n"
157 
158     "zip1 v12.4s, v0.4s, v1.4s\n"
159     "zip2 v13.4s, v0.4s, v1.4s\n"
160     "zip1 v14.4s, v2.4s, v3.4s\n"
161     "zip2 v15.4s, v2.4s, v3.4s\n"
162 
163     "ld1 {v8.4s}, [x10], %[stride]\n"
164     "ld1 {v9.4s}, [x10], %[stride]\n"
165     "ld1 {v10.4s}, [x10], %[stride]\n"
166     "ld1 {v11.4s}, [x10], %[stride]\n"
167 
168     "zip1 v16.4s, v4.4s, v5.4s\n"
169     "zip2 v17.4s, v4.4s, v5.4s\n"
170     "zip1 v18.4s, v6.4s, v7.4s\n"
171     "zip2 v19.4s, v6.4s, v7.4s\n"
172 
173     "trn1 v20.2d, v12.2d, v14.2d\n"
174     "trn2 v23.2d, v12.2d, v14.2d\n"
175     "trn1 v26.2d, v13.2d, v15.2d\n"
176     "trn2 v29.2d, v13.2d, v15.2d\n"
177 
178     "trn1 v21.2d, v16.2d, v18.2d\n"
179     "trn2 v24.2d, v16.2d, v18.2d\n"
180     "trn1 v27.2d, v17.2d, v19.2d\n"
181     "trn2 v30.2d, v17.2d, v19.2d\n"
182 
183     "zip1 v12.4s, v8.4s, v9.4s\n"
184     "zip2 v13.4s, v8.4s, v9.4s\n"
185     "zip1 v14.4s, v10.4s, v11.4s\n"
186     "zip2 v15.4s, v10.4s, v11.4s\n"
187 
188     "trn1 v22.2d, v12.2d, v14.2d\n"
189     "trn2 v25.2d, v12.2d, v14.2d\n"
190     "trn1 v28.2d, v13.2d, v15.2d\n"
191     "trn2 v31.2d, v13.2d, v15.2d\n"
192 
193     "st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x11], #64\n"
194     "st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x11], #64\n"
195     "st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x11], #64\n"
196 
197     :
198     : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
199     : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
200       "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
201       "v31");
202   return;
203 }
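/*
 * Note (editorial): the inline assembly above loads a 12x4 tile (12 rows,
 * each `stride` = col floats apart) and stores it transposed as four runs
 * of 12 contiguous values, i.e. the same result as the scalar fallback in
 * RowMajor2Col12Major (dst[c * 12 + r] = src[r * col + c]).
 */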
204 #endif
205 #ifdef ENABLE_ARM32
206 void RowMajor2Col12Major_arm32(const float *src_c, float *dst_c, size_t col) {
207   size_t stride = col * sizeof(float);
208   asm volatile(
209     "mov r10, %[src_c]\n"
210     "mov r12, %[dst_c]\n"
211 
212     "vld1.32 {q0}, [r10], %[stride]\n"
213     "vld1.32 {q3}, [r10], %[stride]\n"
214     "vld1.32 {q10}, [r10], %[stride]\n"
215     "vld1.32 {q13}, [r10], %[stride]\n"
216 
217     "vtrn.32 d0, d6\n"
218     "vtrn.32 d1, d7\n"
219     "vtrn.32 d20, d26\n"
220     "vtrn.32 d21, d27\n"
221 
222     "vld1.32 {q1}, [r10], %[stride]\n"
223     "vld1.32 {q8}, [r10], %[stride]\n"
224     "vld1.32 {q11}, [r10], %[stride]\n"
225     "vld1.32 {q14}, [r10], %[stride]\n"
226 
227     "vswp d1, d20\n"
228     "vswp d7, d26\n"
229 
230     "vld1.32 {q2}, [r10], %[stride]\n"
231     "vld1.32 {q9}, [r10], %[stride]\n"
232     "vld1.32 {q12}, [r10], %[stride]\n"
233     "vld1.32 {q15}, [r10], %[stride]\n"
234 
235     "vtrn.32 d2, d16\n"
236     "vtrn.32 d3, d17\n"
237     "vtrn.32 d22, d28\n"
238     "vtrn.32 d23, d29\n"
239 
240     "vswp d3, d22\n"
241     "vswp d17, d28\n"
242 
243     "vtrn.32 d4, d18\n"
244     "vtrn.32 d5, d19\n"
245     "vtrn.32 d24, d30\n"
246     "vtrn.32 d25, d31\n"
247 
248     "vswp d5, d24\n"
249     "vswp d19, d30\n"
250 
251     "vst1.32 {q0, q1}, [r12]!\n"
252     "vst1.32 {q2, q3}, [r12]!\n"
253     "vst1.32 {q8, q9}, [r12]!\n"
254     "vst1.32 {q10, q11}, [r12]!\n"
255     "vst1.32 {q12, q13}, [r12]!\n"
256     "vst1.32 {q14, q15}, [r12]!\n"
257 
258     :
259     : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
260     : "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
261   return;
262 }
263 #endif
264 void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, int row, int col) {
265   const float *src_r = src_ptr;
266   float *dst_r = dst_ptr;
267   int ri = 0;
268   for (; ri < (row / C12NUM * C12NUM); ri += C12NUM) {
269     int ci = 0;
270     for (; ci < (col / C4NUM * C4NUM); ci += C4NUM) {
271       const float *src_c = src_r + ci;
272       float *dst_c = dst_r + ci * C12NUM;
273 #ifdef ENABLE_ARM64
274       RowMajor2Col12Major_arm64(src_c, dst_c, col);
275 #elif ENABLE_ARM32
276       RowMajor2Col12Major_arm32(src_c, dst_c, col);
277 #elif ENABLE_SSE
278       __m128 src1 = _mm_loadu_ps(src_c);
279       __m128 src2 = _mm_loadu_ps(src_c + col);
280       __m128 src3 = _mm_loadu_ps(src_c + 2 * col);
281       __m128 src4 = _mm_loadu_ps(src_c + 3 * col);
282       src_c += 4 * col;
283       __m128 src12L = _mm_unpacklo_ps(src1, src2);
284       __m128 src12H = _mm_unpackhi_ps(src1, src2);
285       __m128 src34L = _mm_unpacklo_ps(src3, src4);
286       __m128 src34H = _mm_unpackhi_ps(src3, src4);
287 
288       __m128 dst0 = _mm_movelh_ps(src12L, src34L);
289       __m128 dst3 = _mm_movehl_ps(src34L, src12L);
290       __m128 dst6 = _mm_movelh_ps(src12H, src34H);
291       __m128 dst9 = _mm_movehl_ps(src34H, src12H);
292 
293       __m128 src5 = _mm_loadu_ps(src_c);
294       __m128 src6 = _mm_loadu_ps(src_c + col);
295       __m128 src7 = _mm_loadu_ps(src_c + 2 * col);
296       __m128 src8 = _mm_loadu_ps(src_c + 3 * col);
297       src_c += 4 * col;
298       __m128 src56L = _mm_unpacklo_ps(src5, src6);
299       __m128 src56H = _mm_unpackhi_ps(src5, src6);
300       __m128 src78L = _mm_unpacklo_ps(src7, src8);
301       __m128 src78H = _mm_unpackhi_ps(src7, src8);
302       __m128 dst1 = _mm_movelh_ps(src56L, src78L);
303       __m128 dst4 = _mm_movehl_ps(src78L, src56L);
304       __m128 dst7 = _mm_movelh_ps(src56H, src78H);
305       __m128 dst10 = _mm_movehl_ps(src78H, src56H);
306 
307       __m128 src9 = _mm_loadu_ps(src_c);
308       __m128 src10 = _mm_loadu_ps(src_c + col);
309       __m128 src11 = _mm_loadu_ps(src_c + 2 * col);
310       __m128 src12 = _mm_loadu_ps(src_c + 3 * col);
311       src_c += 4 * col;
312       __m128 src910L = _mm_unpacklo_ps(src9, src10);
313       __m128 src910H = _mm_unpackhi_ps(src9, src10);
314       __m128 src1112L = _mm_unpacklo_ps(src11, src12);
315       __m128 src1112H = _mm_unpackhi_ps(src11, src12);
316       __m128 dst2 = _mm_movelh_ps(src910L, src1112L);
317       __m128 dst5 = _mm_movehl_ps(src1112L, src910L);
318       __m128 dst8 = _mm_movelh_ps(src910H, src1112H);
319       __m128 dst11 = _mm_movehl_ps(src1112H, src910H);
320 
321       _mm_storeu_ps(dst_c, dst0);
322       _mm_storeu_ps(dst_c + 4, dst1);
323       _mm_storeu_ps(dst_c + 8, dst2);
324       _mm_storeu_ps(dst_c + 12, dst3);
325       _mm_storeu_ps(dst_c + 16, dst4);
326       _mm_storeu_ps(dst_c + 20, dst5);
327       _mm_storeu_ps(dst_c + 24, dst6);
328       _mm_storeu_ps(dst_c + 28, dst7);
329       _mm_storeu_ps(dst_c + 32, dst8);
330       _mm_storeu_ps(dst_c + 36, dst9);
331       _mm_storeu_ps(dst_c + 40, dst10);
332       _mm_storeu_ps(dst_c + 44, dst11);
333 #else
334       for (int tr = 0; tr < C12NUM; tr++) {
335         for (int tc = 0; tc < C4NUM; tc++) {
336           dst_c[tc * C12NUM + tr] = src_c[tr * col + tc];
337         }
338       }
339 #endif
340     }
341     for (; ci < col; ci++) {
342       const float *src_c = src_r + ci;
343       float *dst_c = dst_r + ci * C12NUM;
344       for (int i = 0; i < C12NUM; i++) {
345         dst_c[i] = src_c[i * col];
346       }
347     }
348     src_r += C12NUM * col;
349     dst_r += C12NUM * col;
350   }
351   for (; ri < row; ri++, dst_r++, src_r += col) {
352     for (int i = 0; i < col; i++) {
353       dst_r[i * C12NUM] = src_r[i];
354     }
355   }
356   for (; ri < UP_ROUND(row, C12NUM); ri++, dst_r++) {
357     for (int i = 0; i < col; i++) {
358       dst_r[i * C12NUM] = 0;
359     }
360   }
361 }
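/*
 * Illustration (editorial, not part of the library): RowMajor2Col12Major
 * packs rows in blocks of 12.  Inside a block each source column becomes 12
 * contiguous values (dst[c * 12 + r] = src[r][c], as in the scalar fallback
 * above), and rows beyond `row` are zero-filled up to UP_ROUND(row, 12).
 * The ARM64/ARM32/SSE branches are faster tile-wise transposes of exactly
 * this layout.
 */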
362 
363 #ifdef ENABLE_ARM64
364 void RowMajor2Col8Major_arm64(const float *src_c, float *dst_c, size_t col) {
365   size_t stride = col * sizeof(float);
366   asm volatile(
367     "mov x10, %[src_c]\n"
368     "mov x11, %[dst_c]\n"
369 
370     "ld1 {v0.4s, v1.4s}, [x10], %[stride]\n"
371     "ld1 {v2.4s, v3.4s}, [x10], %[stride]\n"
372     "ld1 {v4.4s, v5.4s}, [x10], %[stride]\n"
373     "ld1 {v6.4s, v7.4s}, [x10], %[stride]\n"
374 
375     "zip1 v8.4s, v0.4s, v2.4s\n"
376     "zip2 v9.4s, v0.4s, v2.4s\n"
377     "zip1 v10.4s, v4.4s, v6.4s\n"
378     "zip2 v11.4s, v4.4s, v6.4s\n"
379 
380     "ld1 {v16.4s, v17.4s}, [x10], %[stride]\n"
381     "ld1 {v18.4s, v19.4s}, [x10], %[stride]\n"
382     "ld1 {v20.4s, v21.4s}, [x10], %[stride]\n"
383     "ld1 {v22.4s, v23.4s}, [x10], %[stride]\n"
384 
385     "zip1 v12.4s, v1.4s, v3.4s\n"
386     "zip2 v13.4s, v1.4s, v3.4s\n"
387     "zip1 v14.4s, v5.4s, v7.4s\n"
388     "zip2 v15.4s, v5.4s, v7.4s\n"
389 
390     "trn1 v0.2d, v8.2d, v10.2d\n"
391     "trn2 v1.2d, v8.2d, v10.2d\n"
392     "trn1 v2.2d, v9.2d, v11.2d\n"
393     "trn2 v3.2d, v9.2d, v11.2d\n"
394 
395     "zip1 v24.4s, v16.4s, v18.4s\n"
396     "zip2 v25.4s, v16.4s, v18.4s\n"
397     "zip1 v26.4s, v20.4s, v22.4s\n"
398     "zip2 v27.4s, v20.4s, v22.4s\n"
399 
400     "trn1 v4.2d, v12.2d, v14.2d\n"
401     "trn2 v5.2d, v12.2d, v14.2d\n"
402     "trn1 v6.2d, v13.2d, v15.2d\n"
403     "trn2 v7.2d, v13.2d, v15.2d\n"
404 
405     "zip1 v28.4s, v17.4s, v19.4s\n"
406     "zip2 v29.4s, v17.4s, v19.4s\n"
407     "zip1 v30.4s, v21.4s, v23.4s\n"
408     "zip2 v31.4s, v21.4s, v23.4s\n"
409 
410     "trn1 v16.2d, v24.2d, v26.2d\n"
411     "trn2 v17.2d, v24.2d, v26.2d\n"
412     "trn1 v18.2d, v25.2d, v27.2d\n"
413     "trn2 v19.2d, v25.2d, v27.2d\n"
414 
415     "trn1 v20.2d, v28.2d, v30.2d\n"
416     "trn2 v21.2d, v28.2d, v30.2d\n"
417     "trn1 v22.2d, v29.2d, v31.2d\n"
418     "trn2 v23.2d, v29.2d, v31.2d\n"
419 
420     "st1 {v0.4s}, [x11], #16\n"
421     "st1 {v16.4s}, [x11], #16\n"
422     "st1 {v1.4s}, [x11], #16\n"
423     "st1 {v17.4s}, [x11], #16\n"
424     "st1 {v2.4s}, [x11], #16\n"
425     "st1 {v18.4s}, [x11], #16\n"
426     "st1 {v3.4s}, [x11], #16\n"
427     "st1 {v19.4s}, [x11], #16\n"
428     "st1 {v4.4s}, [x11], #16\n"
429     "st1 {v20.4s}, [x11], #16\n"
430     "st1 {v5.4s}, [x11], #16\n"
431     "st1 {v21.4s}, [x11], #16\n"
432     "st1 {v6.4s}, [x11], #16\n"
433     "st1 {v22.4s}, [x11], #16\n"
434     "st1 {v7.4s}, [x11], #16\n"
435     "st1 {v23.4s}, [x11], #16\n"
436 
437     :
438     : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
439     : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
440       "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
441       "v31");
442   return;
443 }
444 #endif
445 #ifdef ENABLE_ARM32
446 #ifndef SUPPORT_NNIE
447 void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) {
448   size_t stride = col * sizeof(float);
449   asm volatile(
450     "mov r10, %[src_c]\n"
451     "mov r11, %[dst_c]\n"
452 
453     "vld1.32 {q0}, [r10], %[stride]\n"
454     "vld1.32 {q2}, [r10], %[stride]\n"
455     "vld1.32 {q4}, [r10], %[stride]\n"
456     "vld1.32 {q6}, [r10], %[stride]\n"
457 
458     "vtrn.32 d0, d4\n"
459     "vtrn.32 d1, d5\n"
460     "vtrn.32 d8, d12\n"
461     "vtrn.32 d9, d13\n"
462 
463     "vld1.32 {q1}, [r10], %[stride]\n"
464     "vld1.32 {q3}, [r10], %[stride]\n"
465     "vld1.32 {q5}, [r10], %[stride]\n"
466     "vld1.32 {q7}, [r10], %[stride]\n"
467 
468     "vswp d1, d8\n"
469     "vswp d5, d12\n"
470 
471     "vtrn.32 d2, d6\n"
472     "vtrn.32 d3, d7\n"
473     "vtrn.32 d10, d14\n"
474     "vtrn.32 d11, d15\n"
475 
476     "vswp d3, d10\n"
477     "vswp d7, d14\n"
478 
479     "vst1.32 {q0, q1}, [r11]!\n"
480     "vst1.32 {q2, q3}, [r11]!\n"
481     "vst1.32 {q4, q5}, [r11]!\n"
482     "vst1.32 {q6, q7}, [r11]!\n"
483 
484     :
485     : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
486     : "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
487   return;
488 }
489 #else
490 void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) {
491   size_t stride = col * sizeof(float);
492   asm volatile(
493     "mov r10, %[src_c]\n"
494     "mov r7, %[dst_c]\n"
495 
496     "vld1.32 {q0}, [r10], %[stride]\n"
497     "vld1.32 {q2}, [r10], %[stride]\n"
498     "vld1.32 {q4}, [r10], %[stride]\n"
499     "vld1.32 {q6}, [r10], %[stride]\n"
500 
501     "vtrn.32 d0, d4\n"
502     "vtrn.32 d1, d5\n"
503     "vtrn.32 d8, d12\n"
504     "vtrn.32 d9, d13\n"
505 
506     "vld1.32 {q1}, [r10], %[stride]\n"
507     "vld1.32 {q3}, [r10], %[stride]\n"
508     "vld1.32 {q5}, [r10], %[stride]\n"
509     "vld1.32 {q7}, [r10], %[stride]\n"
510 
511     "vswp d1, d8\n"
512     "vswp d5, d12\n"
513 
514     "vtrn.32 d2, d6\n"
515     "vtrn.32 d3, d7\n"
516     "vtrn.32 d10, d14\n"
517     "vtrn.32 d11, d15\n"
518 
519     "vswp d3, d10\n"
520     "vswp d7, d14\n"
521 
522     "vst1.32 {q0, q1}, [r7]!\n"
523     "vst1.32 {q2, q3}, [r7]!\n"
524     "vst1.32 {q4, q5}, [r7]!\n"
525     "vst1.32 {q6, q7}, [r7]!\n"
526 
527     :
528     : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
529     : "r10", "r7", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
530   return;
531 }
532 #endif
533 #endif
534 void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, int row, int col) {
535   int row8 = row / C8NUM * C8NUM;
536 #ifdef ENABLE_ARM64
537   int col_skip = col / C8NUM * C8NUM;
538   int skip_size = C8NUM;
539 #else
540   int col_skip = col / C4NUM * C4NUM;
541   int skip_size = C4NUM;
542 #endif
543   const float *src_r = src_ptr;
544   float *dst_r = dst_ptr;
545 
546   int ri = 0;
547   for (; ri < row8; ri += C8NUM) {
548     int ci = 0;
549     for (; ci < col_skip; ci += skip_size) {
550       const float *src_c = src_r + ci;
551       float *dst_c = dst_r + ci * C8NUM;
552 
553 #ifdef ENABLE_ARM64
554       RowMajor2Col8Major_arm64(src_c, dst_c, col);
555 #elif ENABLE_ARM32
556       RowMajor2Col8Major_arm32(src_c, dst_c, col);
557 #elif ENABLE_SSE
558       __m128 src1 = _mm_loadu_ps(src_c);
559       __m128 src2 = _mm_loadu_ps(src_c + col);
560       __m128 src3 = _mm_loadu_ps(src_c + 2 * col);
561       __m128 src4 = _mm_loadu_ps(src_c + 3 * col);
562       src_c += 4 * col;
563       __m128 src12L = _mm_unpacklo_ps(src1, src2);
564       __m128 src12H = _mm_unpackhi_ps(src1, src2);
565       __m128 src34L = _mm_unpacklo_ps(src3, src4);
566       __m128 src34H = _mm_unpackhi_ps(src3, src4);
567       _mm_storeu_ps(dst_c, _mm_movelh_ps(src12L, src34L));
568       _mm_storeu_ps(dst_c + 8, _mm_movehl_ps(src34L, src12L));
569       _mm_storeu_ps(dst_c + 16, _mm_movelh_ps(src12H, src34H));
570       _mm_storeu_ps(dst_c + 24, _mm_movehl_ps(src34H, src12H));
571 
572       __m128 src5 = _mm_loadu_ps(src_c);
573       __m128 src6 = _mm_loadu_ps(src_c + col);
574       __m128 src7 = _mm_loadu_ps(src_c + 2 * col);
575       __m128 src8 = _mm_loadu_ps(src_c + 3 * col);
576       src_c += 4 * col;
577       __m128 src56L = _mm_unpacklo_ps(src5, src6);
578       __m128 src56H = _mm_unpackhi_ps(src5, src6);
579       __m128 src78L = _mm_unpacklo_ps(src7, src8);
580       __m128 src78H = _mm_unpackhi_ps(src7, src8);
581       _mm_storeu_ps(dst_c + 4, _mm_movelh_ps(src56L, src78L));
582       _mm_storeu_ps(dst_c + 12, _mm_movehl_ps(src78L, src56L));
583       _mm_storeu_ps(dst_c + 20, _mm_movelh_ps(src56H, src78H));
584       _mm_storeu_ps(dst_c + 28, _mm_movehl_ps(src78H, src56H));
585 #else
586       for (int tr = 0; tr < 8; tr++) {
587         for (int tc = 0; tc < 4; tc++) {
588           dst_c[tc * 8 + tr] = src_c[tr * col + tc];
589         }
590       }
591 #endif
592     }
593     for (; ci < col; ci++) {
594       const float *src_c = src_r + ci;
595       float *dst_c = dst_r + ci * C8NUM;
596       for (int i = 0; i < C8NUM; i++) {
597         dst_c[i] = src_c[i * col];
598       }
599     }
600     src_r += C8NUM * col;
601     dst_r += C8NUM * col;
602   }
603   for (; ri < row; ri++, src_r += col, dst_r++) {
604     for (int i = 0; i < col; i++) {
605       dst_r[i * C8NUM] = src_r[i];
606     }
607   }
608 
609   for (; ri < UP_ROUND(row, C8NUM); ri++, dst_r++) {
610     for (int i = 0; i < col; i++) {
611       dst_r[i * C8NUM] = 0;
612     }
613   }
614 }
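/*
 * Note (editorial): RowMajor2Col8Major produces the same
 * column-within-block layout as RowMajor2Col12Major, just with 8-row
 * blocks.  On ARM64 the fast path transposes 8x8 tiles per step
 * (skip_size = C8NUM); the other SIMD paths work on 8x4 tiles.
 */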
615 
616 void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, int row, int col) {
617   int row16 = row / C16NUM * C16NUM;
618   int col8 = col / C8NUM * C8NUM;
619   const float *src_r = src_ptr;
620   float *dst_r = dst_ptr;
621 
622   int ri = 0;
623   for (; ri < row16; ri += C16NUM) {
624     int ci = 0;
625     for (; ci < col8; ci += C8NUM) {
626       const float *src_c = src_r + ci;
627       float *dst_c = dst_r + ci * C16NUM;
628 #ifdef ENABLE_AVX
629       Transpose8X8Fp32Avx(src_c, dst_c, col, C16NUM);
630       Transpose8X8Fp32Avx(src_c + C8NUM * col, dst_c + C8NUM, col, C16NUM);
631 #endif
632       for (int tr = 0; tr < C16NUM; tr++) {
633         for (int tc = 0; tc < C8NUM; tc++) {
634           dst_c[tc * C16NUM + tr] = src_c[tr * col + tc];
635         }
636       }
637     }
638     for (; ci < col; ci++) {
639       const float *src_c = src_r + ci;
640       float *dst_c = dst_r + ci * C16NUM;
641       for (int i = 0; i < C16NUM; i++) {
642         dst_c[i] = src_c[i * col];
643       }
644     }
645     src_r += C16NUM * col;
646     dst_r += C16NUM * col;
647   }
648   for (; ri < row; ri++) {
649     for (int i = 0; i < col; i++) {
650       dst_r[i * C16NUM] = src_r[i];
651     }
652     src_r += col;
653     dst_r += 1;
654   }
655 
656   int total_row = UP_ROUND(row, C16NUM);
657   for (; ri < total_row; ri++) {
658     for (int i = 0; i < col; i++) {
659       dst_r[i * C16NUM] = 0;
660     }
661     dst_r += 1;
662   }
663 }
664 
665 void RowMajor2Col32Major(const float *src_ptr, float *dst_ptr, int row, int col) {
666   // Columns are handled in blocks of 32; if fewer than 32 remain, the tail block is 24, 16 or 8 columns.
667   int col_block_num = UP_DIV(col, C8NUM);
668   int col_block = C4NUM;
669   for (int i = 0; i < col_block_num; i += col_block) {
670     col_block = MSMIN(C4NUM, col_block_num - i);  // max_tile = 4
671     int index = i * row * C8NUM;
672     int col_remainder = MSMIN(C8NUM * col_block, col - i * C8NUM);
673     for (int ir = 0; ir < row; ++ir) {
674       for (int oc = 0; oc < col_remainder; ++oc) {
675         int oc_index = oc * row + ir + index;
676         dst_ptr[oc] = src_ptr[oc_index];
677       }
678       dst_ptr += col_block * C8NUM;
679     }
680   }
681 }
682 
683 void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, int row, int col) {
684   int totalRow = UP_ROUND(row, C6NUM);
685   int row6 = row / C6NUM * C6NUM;
686   int col8 = col / C8NUM * C8NUM;
687   const float *src_r = src_ptr;
688   float *dst_r = dst_ptr;
689 
690   int ri = 0;
691   for (; ri < row6; ri += C6NUM) {
692     int ci = 0;
693     for (; ci < col8; ci += C8NUM) {
694       const float *src_c = src_r + ci;
695       float *dst_c = dst_r + ci * C6NUM;
696 
697 #ifdef ENABLE_AVX
698       __m256 src0 = _mm256_loadu_ps(src_c);
699       __m256 src1 = _mm256_loadu_ps(src_c + col);
700       __m256 src2 = _mm256_loadu_ps(src_c + 2 * col);
701       __m256 src3 = _mm256_loadu_ps(src_c + 3 * col);
702       __m256 src4 = _mm256_loadu_ps(src_c + 4 * col);
703       __m256 src5 = _mm256_loadu_ps(src_c + 5 * col);
704       __m256 trans0 = _mm256_unpacklo_ps(src0, src1);
705       __m256 trans1 = _mm256_unpacklo_ps(src2, src3);
706       __m256 trans2 = _mm256_unpacklo_ps(src4, src5);
707       __m256 trans3 = _mm256_unpackhi_ps(src0, src1);
708       __m256 trans4 = _mm256_unpackhi_ps(src2, src3);
709       __m256 trans5 = _mm256_unpackhi_ps(src4, src5);
710       __m128 lo0 = _mm256_castps256_ps128(trans0);
711       __m128 lo1 = _mm256_castps256_ps128(trans1);
712       __m128 lo2 = _mm256_castps256_ps128(trans2);
713       __m128 lo3 = _mm256_castps256_ps128(trans3);
714       __m128 lo4 = _mm256_castps256_ps128(trans4);
715       __m128 lo5 = _mm256_castps256_ps128(trans5);
716       __m128 hi0 = _mm256_extractf128_ps(trans0, 1);
717       __m128 hi1 = _mm256_extractf128_ps(trans1, 1);
718       __m128 hi2 = _mm256_extractf128_ps(trans2, 1);
719       __m128 hi3 = _mm256_extractf128_ps(trans3, 1);
720       __m128 hi4 = _mm256_extractf128_ps(trans4, 1);
721       __m128 hi5 = _mm256_extractf128_ps(trans5, 1);
722       __m128 res0 = _mm_shuffle_ps(lo0, lo1, _MM_SHUFFLE(1, 0, 1, 0));
723       __m128 res1 = _mm_shuffle_ps(lo2, lo0, _MM_SHUFFLE(3, 2, 1, 0));
724       __m128 res2 = _mm_shuffle_ps(lo1, lo2, _MM_SHUFFLE(3, 2, 3, 2));
725       __m128 res3 = _mm_shuffle_ps(lo3, lo4, _MM_SHUFFLE(1, 0, 1, 0));
726       __m128 res4 = _mm_shuffle_ps(lo5, lo3, _MM_SHUFFLE(3, 2, 1, 0));
727       __m128 res5 = _mm_shuffle_ps(lo4, lo5, _MM_SHUFFLE(3, 2, 3, 2));
728       __m128 res6 = _mm_shuffle_ps(hi0, hi1, _MM_SHUFFLE(1, 0, 1, 0));
729       __m128 res7 = _mm_shuffle_ps(hi2, hi0, _MM_SHUFFLE(3, 2, 1, 0));
730       __m128 res8 = _mm_shuffle_ps(hi1, hi2, _MM_SHUFFLE(3, 2, 3, 2));
731       __m128 res9 = _mm_shuffle_ps(hi3, hi4, _MM_SHUFFLE(1, 0, 1, 0));
732       __m128 res10 = _mm_shuffle_ps(hi5, hi3, _MM_SHUFFLE(3, 2, 1, 0));
733       __m128 res11 = _mm_shuffle_ps(hi4, hi5, _MM_SHUFFLE(3, 2, 3, 2));
734       _mm_storeu_ps(dst_c, res0);
735       _mm_storeu_ps(dst_c + 4, res1);
736       _mm_storeu_ps(dst_c + 8, res2);
737       _mm_storeu_ps(dst_c + 12, res3);
738       _mm_storeu_ps(dst_c + 16, res4);
739       _mm_storeu_ps(dst_c + 20, res5);
740       _mm_storeu_ps(dst_c + 24, res6);
741       _mm_storeu_ps(dst_c + 28, res7);
742       _mm_storeu_ps(dst_c + 32, res8);
743       _mm_storeu_ps(dst_c + 36, res9);
744       _mm_storeu_ps(dst_c + 40, res10);
745       _mm_storeu_ps(dst_c + 44, res11);
746 #else
747       for (int tr = 0; tr < C6NUM; tr++) {
748         for (int tc = 0; tc < C8NUM; tc++) {
749           dst_c[tc * C6NUM + tr] = src_c[tr * col + tc];
750         }
751       }
752 #endif
753     }
754     for (; ci < col; ci++) {
755       const float *src_c = src_r + ci;
756       float *dst_c = dst_r + ci * C6NUM;
757       for (int i = 0; i < C6NUM; i++) {
758         dst_c[i] = src_c[i * col];
759       }
760     }
761     src_r += C6NUM * col;
762     dst_r += C6NUM * col;
763   }
764 
765   for (; ri < row; ri++) {
766     for (int i = 0; i < col; i++) {
767       dst_r[i * C6NUM] = src_r[i];
768     }
769     src_r += col;
770     dst_r += 1;
771   }
772 
773   for (; ri < totalRow; ri++) {
774     for (int i = 0; i < col; i++) {
775       dst_r[i * C6NUM] = 0;
776     }
777     dst_r += 1;
778   }
779 }
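/*
 * Note (editorial): RowMajor2Col6Major uses the same scheme with 6-row
 * blocks; the AVX branch transposes a 6x8 tile per step into eight runs of
 * 6 contiguous values, matching the scalar fallback
 * dst[c * 6 + r] = src[r * col + c].
 */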
780 
781 void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, int row, int col) {
782   int total_row = UP_ROUND(row, C4NUM);
783   int row4 = row / C4NUM * C4NUM;
784   int col4 = col / C4NUM * C4NUM;
785   const float *src_r = src_ptr;
786   float *dst_r = dst_ptr;
787 
788   int ri = 0;
789   for (; ri < row4; ri += C4NUM) {
790     int ci = 0;
791     for (; ci < col4; ci += C4NUM) {
792       const float *src_c = src_r + ci;
793       float *dst_c = dst_r + ci * C4NUM;
794 
795 #ifdef ENABLE_ARM32
796       int stride = col * 4;
797       asm volatile(
798         "mov r10, %[src_c]\n"
799         "mov r12, %[dst_c]\n"
800 
801         "vld1.32 {q0}, [r10], %[stride]\n"
802         "vld1.32 {q1}, [r10], %[stride]\n"
803         "vld1.32 {q2}, [r10], %[stride]\n"
804         "vld1.32 {q3}, [r10], %[stride]\n"
805 
806         "vtrn.32 d0, d2\n"
807         "vtrn.32 d1, d3\n"
808         "vtrn.32 d4, d6\n"
809         "vtrn.32 d5, d7\n"
810 
811         "vswp d1, d4\n"
812         "vswp d3, d6\n"
813 
814         "vst1.32 {q0}, [r12]!\n"
815         "vst1.32 {q1}, [r12]!\n"
816         "vst1.32 {q2}, [r12]!\n"
817         "vst1.32 {q3}, [r12]!\n"
818 
819         :
820         : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
821         : "r10", "r12", "q0", "q1", "q2", "q3");
822 #elif ENABLE_SSE
823       __m128 src1 = _mm_loadu_ps(src_c);
824       __m128 src2 = _mm_loadu_ps(src_c + col);
825       __m128 src3 = _mm_loadu_ps(src_c + 2 * col);
826       __m128 src4 = _mm_loadu_ps(src_c + 3 * col);
827       src_c += 4 * col;
828       __m128 src12L = _mm_unpacklo_ps(src1, src2);
829       __m128 src12H = _mm_unpackhi_ps(src1, src2);
830       __m128 src34L = _mm_unpacklo_ps(src3, src4);
831       __m128 src34H = _mm_unpackhi_ps(src3, src4);
832 
833       __m128 dst0 = _mm_movelh_ps(src12L, src34L);
834       __m128 dst1 = _mm_movehl_ps(src34L, src12L);
835       __m128 dst2 = _mm_movelh_ps(src12H, src34H);
836       __m128 dst3 = _mm_movehl_ps(src34H, src12H);
837 
838       _mm_storeu_ps(dst_c, dst0);
839       _mm_storeu_ps(dst_c + 4, dst1);
840       _mm_storeu_ps(dst_c + 8, dst2);
841       _mm_storeu_ps(dst_c + 12, dst3);
842 #else
843       for (size_t tr = 0; tr < C4NUM; tr++) {
844         for (size_t tc = 0; tc < C4NUM; tc++) {
845           dst_c[tc * C4NUM + tr] = src_c[tr * col + tc];
846         }
847       }
848 #endif
849     }
850     for (; ci < col; ci++) {
851       const float *src_c = src_r + ci;
852       float *dst_c = dst_r + ci * C4NUM;
853       for (int i = 0; i < C4NUM; i++) {
854         dst_c[i] = src_c[i * col];
855       }
856     }
857     src_r += C4NUM * col;
858     dst_r += C4NUM * col;
859   }
860   for (; ri < row; ri++) {
861     for (int i = 0; i < col; i++) {
862       dst_r[i * C4NUM] = src_r[i];
863     }
864     src_r += col;
865     dst_r += 1;
866   }
867 
868   for (; ri < total_row; ri++) {
869     for (int i = 0; i < col; i++) {
870       dst_r[i * C4NUM] = 0;
871     }
872     dst_r += 1;
873   }
874 }
875 
876 #ifndef ENABLE_ARM
877 void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) {
878   for (int ci = 0; ci < col; ci++) {
879     float value = 0;
880     for (int di = 0; di < depth; di++) {
881       value += a[di] * b[ci * depth + di];
882     }
883     if (bias != NULL) value += bias[ci];
884     if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
885     if (act_type == ActType_Relu || act_type == ActType_Relu6) value = MSMAX(0.0f, value);
886     c[ci] = value;
887   }
888 }
889 #endif
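/*
 * Illustration (editorial, not part of the library): MatVecMulFp32 computes
 * c[ci] = sum_d a[d] * b[ci * depth + d] (+ bias[ci]), i.e. `b` holds each
 * output channel's `depth` weights contiguously.  A hypothetical call with
 * two channels and depth 3:
 *
 *   float a[3] = {1, 2, 3};
 *   float b[6] = {1, 0, 0,    // channel 0 weights
 *                 0, 1, 1};   // channel 1 weights
 *   float c[2];
 *   MatVecMulFp32(a, b, c, NULL, 0, 3, 2);  // act_type 0: no activation
 *   // c == {1, 5}
 */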
890 
891 #ifdef ENABLE_ARM64
892 // 8 output channels per iteration, depth consumed 4 at a time
893 void MatVecMulFp32Neon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col,
894                          int align_col) {
895   int ci = 0;
896   for (; ci < align_col - C8NUM + 1; ci += C8NUM) {
897     float32x4_t acc_0;
898     float32x4_t acc_1;
899     if (bias != NULL) {
900       acc_0 = vld1q_f32(bias + ci);
901       acc_1 = vld1q_f32(bias + ci + C4NUM);
902     } else {
903       acc_0 = vdupq_n_f32(0.0f);
904       acc_1 = vdupq_n_f32(0.0f);
905     }
906     const float *bv_base = b + ci * depth;
907     int di = 0;
908     for (; di < depth - C4NUM + 1; di += C4NUM) {
909       float32x4_t av = vld1q_f32(a + di);
910       float32x4_t bv_00 = vld1q_f32(bv_base);
911       float32x4_t bv_10 = vld1q_f32(bv_base + C4NUM);
912       bv_base += C8NUM;
913       float32x4_t bv_01 = vld1q_f32(bv_base);
914       float32x4_t bv_11 = vld1q_f32(bv_base + C4NUM);
915       bv_base += C8NUM;
916       float32x4_t bv_02 = vld1q_f32(bv_base);
917       float32x4_t bv_12 = vld1q_f32(bv_base + C4NUM);
918       bv_base += C8NUM;
919       float32x4_t bv_03 = vld1q_f32(bv_base);
920       float32x4_t bv_13 = vld1q_f32(bv_base + C4NUM);
921       bv_base += C8NUM;
922       acc_0 = vmlaq_n_f32(acc_0, bv_00, av[0]);
923       acc_1 = vmlaq_n_f32(acc_1, bv_10, av[0]);
924       acc_0 = vmlaq_n_f32(acc_0, bv_01, av[1]);
925       acc_1 = vmlaq_n_f32(acc_1, bv_11, av[1]);
926       acc_0 = vmlaq_n_f32(acc_0, bv_02, av[2]);
927       acc_1 = vmlaq_n_f32(acc_1, bv_12, av[2]);
928       acc_0 = vmlaq_n_f32(acc_0, bv_03, av[3]);
929       acc_1 = vmlaq_n_f32(acc_1, bv_13, av[3]);
930     }
931     if (di < depth) {
932       for (; di < depth; ++di) {
933         float ai = a[di];
934         float32x4_t bv0 = vld1q_f32(bv_base);
935         float32x4_t bv1 = vld1q_f32(bv_base + C4NUM);
936         acc_0 = vmlaq_n_f32(acc_0, bv0, ai);
937         acc_1 = vmlaq_n_f32(acc_1, bv1, ai);
938         bv_base += C8NUM;
939       }
940     }  // below: for the tail block, store only the results for the real `col` channels
941     if (ci + C4NUM - 1 >= col) {
942       int c_remain = col - ci;
943       for (int i = 0; i < c_remain; ++i) {
944         if (act_type == ActType_Relu) {
945           c[i] = MSMAX(acc_0[i], 0.0f);
946         } else if (act_type == ActType_Relu6) {
947           c[i] = MSMIN(MSMAX(acc_0[i], 0.0f), 6.0f);
948         } else {
949           c[i] = acc_0[i];
950         }
951       }
952       return;
953     }
954     if (act_type == ActType_Relu) {
955       acc_0 = vmaxq_f32(acc_0, vdupq_n_f32(0.0f));
956     } else if (act_type == ActType_Relu6) {
957       acc_0 = vminq_f32(vmaxq_f32(acc_0, vdupq_n_f32(0.0f)), vdupq_n_f32(6.0f));
958     }
959     vst1q_f32(c, acc_0);
960     if (ci + C8NUM - 1 >= col) {
961       int c_remain = col - ci - C4NUM;
962       for (int i = 0; i < c_remain; ++i) {
963         if (act_type == ActType_Relu) {
964           c[C4NUM + i] = MSMAX(acc_1[i], 0.0f);
965         } else if (act_type == ActType_Relu6) {
966           c[C4NUM + i] = MSMIN(MSMAX(acc_1[i], 0.0f), 6.0f);
967         } else {
968           c[C4NUM + i] = acc_1[i];
969         }
970       }
971       return;
972     }
973     if (act_type == ActType_Relu) {
974       acc_1 = vmaxq_f32(acc_1, vdupq_n_f32(0.0f));
975     } else if (act_type == ActType_Relu6) {
976       acc_1 = vminq_f32(vmaxq_f32(acc_1, vdupq_n_f32(0.0f)), vdupq_n_f32(6.0f));
977     }
978     vst1q_f32(c + C4NUM, acc_1);
979     c += C8NUM;
980   }
981 }
982 #endif
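/*
 * Note (editorial): MatVecMulFp32Neon64 produces 8 output channels per
 * outer iteration using two float32x4 accumulators; `b` is expected in the
 * 8-channel-packed layout implied by `align_col` (normally
 * UP_ROUND(col, C8NUM)), and for the tail block only the first `col`
 * results are written back.
 */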
983 
984 void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
985                 int col, int stride, int out_type) {
986   if (out_type == OutType_Nhwc) {
987     for (int r = 0; r < row; r++) {
988       for (int c = 0; c < col; c++) {
989         int r12div = r / 12, r12mod = r % 12;
990         int c8div = c / 8, c8mod = c % 8;
991         size_t ci = r * stride + c;
992         float value = 0;
993         for (int d = 0; d < deep; d++) {
994           size_t ai = r12div * deep * 12 + d * 12 + r12mod;
995           size_t bi = c8div * deep * 8 + d * 8 + c8mod;
996           value = value + a[ai] * b[bi];
997         }
998         ADD_BIAS(value, bias, c)
999         DO_RELU(value, act_type)
1000         DO_RELU6(value, act_type)
1001         dst[ci] = value;
1002       }
1003     }
1004   } else if (out_type == OutType_C8) {
1005     int col_8 = UP_ROUND(col, C8NUM);
1006     int row_12 = UP_ROUND(row, C12NUM);
1007     for (int r = 0; r < row_12; r++) {
1008       for (int c = 0; c < col_8; c++) {
1009         int r12div = r / C12NUM, r12mod = r % C12NUM;
1010         int c8div = c / C8NUM, c8mod = c % C8NUM;
1011         size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod);
1012         float value = 0;
1013         for (int d = 0; d < deep; d++) {
1014           size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
1015           size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
1016           value = value + a[ai] * b[bi];
1017         }
1018         ADD_BIAS(value, bias, c)
1019         DO_RELU(value, act_type)
1020         DO_RELU6(value, act_type)
1021         dst[ci] = value;
1022       }
1023     }
1024   } else {
1025     for (int i = 0; i < row; ++i) {
1026       int src_r_offset = i;
1027       int dst_r_offset = i * col * stride;
1028       for (int j = 0; j < col; ++j) {
1029         int c8div = j / 8, c8mod = j % 8;
1030         size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
1031         float value = 0;
1032         for (int d = 0; d < deep; ++d) {
1033           size_t ai = src_r_offset + d * C12NUM;
1034           size_t bi = c8div * deep * 8 + d * 8 + c8mod;
1035           value = value + a[ai] * b[bi];
1036         }
1037         ADD_BIAS(value, bias, j)
1038         DO_RELU(value, act_type)
1039         DO_RELU6(value, act_type)
1040         dst[ci] = value;
1041       }
1042     }
1043   }
1044 }
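/*
 * Illustration (editorial, not part of the library): in MatMul12x8 the
 * operands are pre-packed.  ai = r12div * deep * 12 + d * 12 + r12mod walks
 * `a` as produced by RowMajor2Col12Major (12-row blocks, one group of 12
 * values per depth step), and bi = c8div * deep * 8 + d * 8 + c8mod walks
 * `b` as produced by RowMajor2Row8Major on a (deep x col) matrix (8-column
 * blocks).  For example, r = 13, c = 9, d = 2, deep = 4 gives
 * ai = 1 * 4 * 12 + 2 * 12 + 1 = 73 and bi = 1 * 4 * 8 + 2 * 8 + 1 = 49.
 */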
1045 
1046 void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
1047                int col, size_t stride, int out_type) {
1048 #ifdef ENABLE_ARM64
1049   if (out_type == OutType_C8) {
1050     MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
1051   } else {
1052     MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
1053   }
1054 #elif ENABLE_ARM32
1055   if (out_type == OutType_C8) {
1056     MatmulFloatNeon32(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
1057   } else if (out_type == OutType_Nhwc) {
1058     MatmulFloatNeon32Opt12x4(a, b, c, bias, (int)act_type, deep, row, col, stride, 1);
1059   } else {
1060     MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
1061   }
1062 #elif ENABLE_AVX
1063   MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type));
1064 #elif ENABLE_SSE
1065   MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
1066 #else
1067   MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
1068 #endif
1069 }
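/*
 * Usage sketch (editorial; assumes the plain C build, where MatMulOpt falls
 * through to MatMul12x8 -- the SIMD builds may expect different packing):
 * multiply a 2x3 A by a 3x2 B into a 2x2 NHWC output.
 *
 *   float A[6] = {...};                 // 2x3, row-major
 *   float B[6] = {...};                 // 3x2, row-major (deep x col)
 *   float a_pack[UP_ROUND(2, C12NUM) * 3];
 *   float b_pack[3 * UP_ROUND(2, C8NUM)];
 *   float C[2 * 2];
 *   RowMajor2Col12Major(A, a_pack, 2, 3);
 *   RowMajor2Row8Major(B, b_pack, 3, 2);
 *   MatMulOpt(a_pack, b_pack, C, NULL, ActType_No, 3, 2, 2, 2, OutType_Nhwc);
 */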
1070 
1071 #ifdef ENABLE_AVX
1072 void MatVecMulAvxFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col,
1073                       int col_align) {
1074   // process up to 32 output channels at a time
1075   int col_block = C32NUM;
1076   int act_flag = 0;
1077   if (act_type == ActType_Relu6) {
1078     act_flag += 1;
1079   }
1080   if (act_type == ActType_Relu || act_type == ActType_Relu6) {
1081     act_flag += 2;
1082   }
1083   MatVecMulKernel kernel[4] = {MatVecMul1x8Kernel, MatVecMul1x16Kernel, MatVecMul1x24Kernel, MatVecMul1x32Kernel};
1084   const float *bias_data = bias;
1085   for (int col_index = 0; col_index < cur_col; col_index += col_block) {
1086     col_block = cur_col - col_index < col_block ? cur_col - col_index : col_block;
1087     kernel[(col_block >> 3) - 1](c + col_index, a, b + col_index * depth, bias_data, act_flag, 1, col_block >> 3,
1088                                  col_align, depth);
1089     if (bias_data != NULL) {
1090       bias_data += col_block;
1091     }
1092   }
1093 }
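/*
 * Note (editorial): MatVecMulAvxFp32 walks the output channels in chunks of
 * at most 32 and dispatches the widest kernel that fits:
 * kernel[(col_block >> 3) - 1] selects 1x8, 1x16, 1x24 or 1x32.  cur_col is
 * expected to be a multiple of C8NUM, so e.g. 40 channels become one 1x32
 * call plus one 1x8 call.  act_flag encodes the activation for the assembly
 * kernels: bit 1 set for ReLU/ReLU6, bit 0 additionally set for ReLU6.
 */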
1094 
1095 void MatVecMul1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
1096                          size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
1097   asm volatile(
1098     "cmpq $0, %2\n"
1099     "je 0f\n"
1100     "vmovups (%2), %%ymm0\n"
1101     "vmovups 0x20(%2), %%ymm1\n"
1102     "vmovups 0x40(%2), %%ymm2\n"
1103     "vmovups 0x60(%2), %%ymm3\n"
1104     "jmp 1f\n"
1105     "0:\n"
1106     "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1107     "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1108     "vxorps %%ymm2, %%ymm2, %%ymm2\n"
1109     "vxorps %%ymm3, %%ymm3, %%ymm3\n"
1110     "1:\n"  // deep_c8
1111     "movq %3, %%rcx\n"
1112     "shr $3, %%ecx\n"
1113     "je 3f\n"
1114     "2:\n"
1115     "vbroadcastss (%0), %%ymm4\n"
1116     "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1117     "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
1118     "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n"
1119     "vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n"
1120 
1121     "vbroadcastss 4(%0), %%ymm4\n"
1122     "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n"
1123     "vfmadd231ps 160(%1), %%ymm4, %%ymm1\n"
1124     "vfmadd231ps 192(%1), %%ymm4, %%ymm2\n"
1125     "vfmadd231ps 224(%1), %%ymm4, %%ymm3\n"
1126 
1127     "vbroadcastss 8(%0), %%ymm4\n"
1128     "vfmadd231ps 256(%1), %%ymm4, %%ymm0\n"
1129     "vfmadd231ps 288(%1), %%ymm4, %%ymm1\n"
1130     "vfmadd231ps 320(%1), %%ymm4, %%ymm2\n"
1131     "vfmadd231ps 352(%1), %%ymm4, %%ymm3\n"
1132 
1133     "vbroadcastss 12(%0), %%ymm4\n"
1134     "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n"
1135     "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n"
1136     "vfmadd231ps 448(%1), %%ymm4, %%ymm2\n"
1137     "vfmadd231ps 480(%1), %%ymm4, %%ymm3\n"
1138 
1139     "vbroadcastss 16(%0), %%ymm4\n"
1140     "vfmadd231ps 512(%1), %%ymm4, %%ymm0\n"
1141     "vfmadd231ps 544(%1), %%ymm4, %%ymm1\n"
1142     "vfmadd231ps 576(%1), %%ymm4, %%ymm2\n"
1143     "vfmadd231ps 608(%1), %%ymm4, %%ymm3\n"
1144 
1145     "vbroadcastss 20(%0), %%ymm4\n"
1146     "vfmadd231ps 640(%1), %%ymm4, %%ymm0\n"
1147     "vfmadd231ps 672(%1), %%ymm4, %%ymm1\n"
1148     "vfmadd231ps 704(%1), %%ymm4, %%ymm2\n"
1149     "vfmadd231ps 736(%1), %%ymm4, %%ymm3\n"
1150 
1151     "vbroadcastss 24(%0), %%ymm4\n"
1152     "vfmadd231ps 768(%1), %%ymm4, %%ymm0\n"
1153     "vfmadd231ps 800(%1), %%ymm4, %%ymm1\n"
1154     "vfmadd231ps 832(%1), %%ymm4, %%ymm2\n"
1155     "vfmadd231ps 864(%1), %%ymm4, %%ymm3\n"
1156 
1157     "vbroadcastss 28(%0), %%ymm4\n"
1158     "vfmadd231ps 896(%1), %%ymm4, %%ymm0\n"
1159     "vfmadd231ps 928(%1), %%ymm4, %%ymm1\n"
1160     "vfmadd231ps 960(%1), %%ymm4, %%ymm2\n"
1161     "vfmadd231ps 992(%1), %%ymm4, %%ymm3\n"
1162     "addq $1024, %1\n"
1163     "addq $32, %0\n"
1164     "dec %%ecx\n"
1165     "jg 2b\n"
1166 
1167     "3:\n"
1168     "and $7, %3\n"  // deep_remainder
1169     "je 5f\n"
1170     "4:\n"
1171     "vbroadcastss (%0), %%ymm4\n"
1172     "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1173     "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
1174     "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n"
1175     "vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n"
1176     "addq $128, %1\n"
1177     "addq $4, %0\n"
1178     "dec %3\n"
1179     "jg 4b\n"
1180 
1181     "5:\n"
1182     "and $0x3, %%eax\n"  // act_type
1183     "je 6f\n"
1184     // Relu
1185     "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1186     "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1187     "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1188     "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
1189     "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
1190     "and $0x1, %%eax\n"
1191     "je 6f\n"
1192     // relu6
1193     "mov $0x40C00000, %%ecx\n"
1194     "vmovd %%ecx, %%xmm14\n"
1195     "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1196     "vminps %%ymm14, %%ymm0, %%ymm0\n"
1197     "vminps %%ymm14, %%ymm1, %%ymm1\n"
1198     "vminps %%ymm14, %%ymm2, %%ymm2\n"
1199     "vminps %%ymm14, %%ymm3, %%ymm3\n"
1200     "6:\n"
1201     "vmovups %%ymm0, (%5)\n"  // dst_0
1202     "vmovups %%ymm1, 0x20(%5)\n"
1203     "vmovups %%ymm2, 0x40(%5)\n"
1204     "vmovups %%ymm3, 0x60(%5)\n"
1205     :
1206     : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst)  // 5
1207     : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm4", "%ymm14");
1208 }
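/*
 * Note (editorial, applies to the 1x32/1x24/1x16/1x8 kernels): 0x40C00000
 * is the IEEE-754 bit pattern of 6.0f.  It is moved into lane 0 of xmm14
 * and broadcast across ymm14 with vpermps (the index register ymm12 is all
 * zeros at that point), giving the ReLU6 clamp value without a memory
 * constant.
 */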
1209 
1210 void MatVecMul1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
1211                          size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
1212   asm volatile(
1213     "cmpq $0, %2\n"
1214     "je 0f\n"
1215     "vmovups (%2), %%ymm0\n"
1216     "vmovups 0x20(%2), %%ymm1\n"
1217     "vmovups 0x40(%2), %%ymm2\n"
1218     "jmp 1f\n"
1219     "0:\n"
1220     "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1221     "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1222     "vxorps %%ymm2, %%ymm2, %%ymm2\n"
1223 
1224     "1:\n"  // deep
1225     "movq %3, %%rcx\n"
1226     "shr $3, %%ecx\n"
1227     "je 3f\n"
1228     "2:\n"
1229     "vbroadcastss (%0), %%ymm4\n"
1230     "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1231     "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
1232     "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n"
1233 
1234     "vbroadcastss 4(%0), %%ymm4\n"
1235     "vfmadd231ps 96(%1), %%ymm4, %%ymm0\n"
1236     "vfmadd231ps 128(%1), %%ymm4, %%ymm1\n"
1237     "vfmadd231ps 160(%1), %%ymm4, %%ymm2\n"
1238 
1239     "vbroadcastss 8(%0), %%ymm4\n"
1240     "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n"
1241     "vfmadd231ps 224(%1), %%ymm4, %%ymm1\n"
1242     "vfmadd231ps 256(%1), %%ymm4, %%ymm2\n"
1243 
1244     "vbroadcastss 12(%0), %%ymm4\n"
1245     "vfmadd231ps 288(%1), %%ymm4, %%ymm0\n"
1246     "vfmadd231ps 320(%1), %%ymm4, %%ymm1\n"
1247     "vfmadd231ps 352(%1), %%ymm4, %%ymm2\n"
1248 
1249     "vbroadcastss 16(%0), %%ymm4\n"
1250     "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n"
1251     "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n"
1252     "vfmadd231ps 448(%1), %%ymm4, %%ymm2\n"
1253 
1254     "vbroadcastss 20(%0), %%ymm4\n"
1255     "vfmadd231ps 480(%1), %%ymm4, %%ymm0\n"
1256     "vfmadd231ps 512(%1), %%ymm4, %%ymm1\n"
1257     "vfmadd231ps 544(%1), %%ymm4, %%ymm2\n"
1258 
1259     "vbroadcastss 24(%0), %%ymm4\n"
1260     "vfmadd231ps 576(%1), %%ymm4, %%ymm0\n"
1261     "vfmadd231ps 608(%1), %%ymm4, %%ymm1\n"
1262     "vfmadd231ps 640(%1), %%ymm4, %%ymm2\n"
1263 
1264     "vbroadcastss 28(%0), %%ymm4\n"
1265     "vfmadd231ps 672(%1), %%ymm4, %%ymm0\n"
1266     "vfmadd231ps 704(%1), %%ymm4, %%ymm1\n"
1267     "vfmadd231ps 736(%1), %%ymm4, %%ymm2\n"
1268     "addq $768, %1\n"
1269     "addq $32, %0\n"
1270     "dec %%ecx\n"
1271     "jg 2b\n"
1272 
1273     "3:\n"
1274     "and $7, %3\n"  // deep_remainder
1275     "je 5f\n"
1276     "4:\n"
1277     "vbroadcastss (%0), %%ymm4\n"
1278     "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1279     "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
1280     "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n"
1281     "addq $96, %1\n"
1282     "addq $4, %0\n"
1283     "dec %3\n"
1284     "jg 4b\n"
1285 
1286     "5:\n"
1287     "and $0x3, %%eax\n"  // act_type
1288     "je 6f\n"
1289     // Relu
1290     "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1291     "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1292     "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1293     "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
1294 
1295     "and $0x1, %%eax\n"
1296     "je 6f\n"
1297     // relu6
1298     "mov $0x40C00000, %%ecx\n"
1299     "vmovd %%ecx, %%xmm14\n"
1300     "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1301     "vminps %%ymm14, %%ymm0, %%ymm0\n"
1302     "vminps %%ymm14, %%ymm1, %%ymm1\n"
1303     "vminps %%ymm14, %%ymm2, %%ymm2\n"
1304 
1305     "6:\n"
1306     "vmovups %%ymm0, (%5)\n"  // dst_0
1307     "vmovups %%ymm1, 0x20(%5)\n"
1308     "vmovups %%ymm2, 0x40(%5)\n"
1309 
1310     :
1311     : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst)  // 5
1312     : "%rcx", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm4", "%ymm14");
1313 }
1314 
1315 void MatVecMul1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
1316                          size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
1317   asm volatile(
1318     "cmpq $0, %2\n"
1319     "je 0f\n"
1320     "vmovups (%2), %%ymm0\n"
1321     "vmovups 0x20(%2), %%ymm1\n"
1322     "jmp 1f\n"
1323     "0:\n"
1324     "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1325     "vxorps %%ymm1, %%ymm1, %%ymm1\n"
1326     "1:\n"
1327     "movq %3, %%rcx\n"
1328     "shr $3, %%ecx\n"
1329     "je 3f\n"
1330     "2:\n"  // deep_c8
1331     "vbroadcastss (%0), %%ymm4\n"
1332     "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1333     "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
1334 
1335     "vbroadcastss 4(%0), %%ymm4\n"
1336     "vfmadd231ps 64(%1), %%ymm4, %%ymm0\n"
1337     "vfmadd231ps 96(%1), %%ymm4, %%ymm1\n"
1338 
1339     "vbroadcastss 8(%0), %%ymm4\n"
1340     "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n"
1341     "vfmadd231ps 160(%1), %%ymm4, %%ymm1\n"
1342 
1343     "vbroadcastss 12(%0), %%ymm4\n"
1344     "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n"
1345     "vfmadd231ps 224(%1), %%ymm4, %%ymm1\n"
1346 
1347     "vbroadcastss 16(%0), %%ymm4\n"
1348     "vfmadd231ps 256(%1), %%ymm4, %%ymm0\n"
1349     "vfmadd231ps 288(%1), %%ymm4, %%ymm1\n"
1350 
1351     "vbroadcastss 20(%0), %%ymm4\n"
1352     "vfmadd231ps 320(%1), %%ymm4, %%ymm0\n"
1353     "vfmadd231ps 352(%1), %%ymm4, %%ymm1\n"
1354 
1355     "vbroadcastss 24(%0), %%ymm4\n"
1356     "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n"
1357     "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n"
1358 
1359     "vbroadcastss 28(%0), %%ymm4\n"
1360     "vfmadd231ps 448(%1), %%ymm4, %%ymm0\n"
1361     "vfmadd231ps 480(%1), %%ymm4, %%ymm1\n"
1362     "addq $512, %1\n"
1363     "addq $32, %0\n"
1364     "dec %%ecx\n"
1365     "jg 2b\n"
1366 
1367     "3:\n"
1368     "and $7, %3\n"
1369     "je 5f\n"
1370     "4:\n"
1371     "vbroadcastss (%0), %%ymm4\n"
1372     "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1373     "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
1374     "addq $64, %1\n"
1375     "addq $4, %0\n"
1376     "dec %3\n"
1377     "jg 4b\n"
1378 
1379     "5:\n"
1380     "and $0x3, %%eax\n"  // act_type
1381     "je 6f\n"
1382     // Relu
1383     "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1384     "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1385     "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
1386 
1387     "and $0x1, %%eax\n"
1388     "je 6f\n"
1389     // relu6
1390     "mov $0x40C00000, %%ecx\n"
1391     "vmovd %%ecx, %%xmm14\n"
1392     "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1393     "vminps %%ymm14, %%ymm0, %%ymm0\n"
1394     "vminps %%ymm14, %%ymm1, %%ymm1\n"
1395 
1396     "6:\n"
1397     "vmovups %%ymm0, (%5)\n"  // dst_0
1398     "vmovups %%ymm1, 0x20(%5)\n"
1399 
1400     :
1401     : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst)  // 5
1402     : "%rcx", "%ymm0", "%ymm1", "%ymm12", "%ymm4", "%ymm14");
1403 }
1404 
1405 void MatVecMul1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
1406                         size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
1407   asm volatile(
1408     "cmpq $0, %2\n"
1409     "je 0f\n"
1410     "vmovups (%2), %%ymm0\n"
1411     "jmp 1f\n"
1412     "0:\n"
1413     "vxorps %%ymm0, %%ymm0, %%ymm0\n"
1414     "1:\n"
1415     "movq %3, %%rcx\n"
1416     "shr $3, %%ecx\n"
1417     "je 3f\n"
1418     "2:\n"  // deep_c8
1419     "vbroadcastss (%0), %%ymm4\n"
1420     "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1421     "vbroadcastss 4(%0), %%ymm4\n"
1422     "vfmadd231ps 32(%1), %%ymm4, %%ymm0\n"
1423     "vbroadcastss 8(%0), %%ymm4\n"
1424     "vfmadd231ps 64(%1), %%ymm4, %%ymm0\n"
1425     "vbroadcastss 12(%0), %%ymm4\n"
1426     "vfmadd231ps 96(%1), %%ymm4, %%ymm0\n"
1427     "vbroadcastss 16(%0), %%ymm4\n"
1428     "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n"
1429     "vbroadcastss 20(%0), %%ymm4\n"
1430     "vfmadd231ps 160(%1), %%ymm4, %%ymm0\n"
1431     "vbroadcastss 24(%0), %%ymm4\n"
1432     "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n"
1433     "vbroadcastss 28(%0), %%ymm4\n"
1434     "vfmadd231ps 224(%1), %%ymm4, %%ymm0\n"
1435     "addq $256, %1\n"
1436     "addq $32, %0\n"
1437     "dec %%ecx\n"
1438     "jg 2b\n"
1439 
1440     "3:\n"
1441     "and $7, %3\n"
1442     "je 5f\n"
1443     "4:\n"
1444     "vbroadcastss (%0), %%ymm4\n"
1445     "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
1446     "addq $32, %1\n"
1447     "addq $4, %0\n"
1448     "dec %3\n"
1449     "jg 4b\n"
1450 
1451     "5:\n"
1452     "and $0x3, %%eax\n"  // act_type
1453     "je 6f\n"
1454     // Relu
1455     "vxorps %%ymm12, %%ymm12, %%ymm12\n"
1456     "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
1457 
1458     "and $0x1, %%eax\n"
1459     "je 6f\n"
1460     // relu6
1461     "mov $0x40C00000, %%ecx\n"
1462     "vmovd %%ecx, %%xmm14\n"
1463     "vpermps %%ymm14, %%ymm12, %%ymm14\n"
1464     "vminps %%ymm14, %%ymm0, %%ymm0\n"
1465 
1466     "6:\n"
1467     "vmovups %%ymm0, (%5)\n"  // dst_0
1468 
1469     :
1470     : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst)  // 5
1471     : "%rcx", "%ymm0", "%ymm1", "%ymm12", "%ymm4", "%ymm14");
1472 }
1473 
1474 #ifdef ENABLE_DEBUG
1475 void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
1476                             size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
1477   __m256 dst_data[12];
1478   const float *src_sw[12];
1479   __m256 weight_data[4];
1480   for (int i = 0; i < 4; ++i) {
1481     weight_data[i] = _mm256_set1_ps(0.0f);
1482   }
1483   for (int i = 0; i < row_block; ++i) {
1484     if (bias != NULL) {
1485       for (int j = 0; j < col_block; ++j) {
1486         dst_data[i * col_block + j] = _mm256_loadu_ps(bias + j * 8);
1487       }
1488     } else {
1489       for (int j = 0; j < col_block; ++j) {
1490         dst_data[i * col_block + j] = _mm256_set1_ps(0.0f);
1491       }
1492     }
1493     src_sw[i] = src + i * deep;
1494   }
1495   const float *weight_kernel = weight;
1496   for (int ic = 0; ic < deep; ++ic) {
1497     for (int j = 0; j < col_block; ++j) {
1498       weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM);
1499     }
1500     for (int i = 0; i < row_block; ++i) {
1501       for (int j = 0; j < col_block; ++j) {
1502         dst_data[i * col_block + j] =
1503           _mm256_fmadd_ps(_mm256_set1_ps(src_sw[i][ic]), weight_data[j], dst_data[i * col_block + j]);
1504       }
1505     }
1506     weight_kernel += C8NUM * col_block;
1507   }  // ic loop
1508   // add bias and relu
1509   for (int i = 0; i < row_block; ++i) {
1510     for (int j = 0; j < col_block; ++j) {
1511       if (0x1 & act_flag) {  // relu6
1512         dst_data[i * col_block + j] = _mm256_min_ps(dst_data[i * col_block + j], _mm256_set1_ps(6.0f));
1513       }
1514       if (0x2 & act_flag) {  // relu
1515         dst_data[i * col_block + j] = _mm256_max_ps(dst_data[i * col_block + j], _mm256_set1_ps(0.0f));
1516       }
1517       _mm256_storeu_ps(dst + i * col_algin + j * C8NUM, dst_data[i * col_block + j]);
1518     }
1519   }
1520 }
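/*
 * Note (editorial): MatVecMulRowxColKernel appears to serve as a plain
 * AVX-intrinsics reference for the hand-written kernels above; it is only
 * compiled under ENABLE_DEBUG and interprets act_flag the same way
 * (bit 0: clamp at 6.0f, bit 1: clamp at 0.0f).
 */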
1521 #endif
1522 #endif
1523