• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/fp32/matmul_fp32.h"
18 #include "nnacl/fp32/pack_fp32.h"
19 #include "nnacl/fp32/matmul_avx512_fp32.h"
20 #include "nnacl/matmul_fp32_simd.h"
21 
22 #ifndef ENABLE_ARM
MatVecMulFp32(const float * a,const float * b,float * c,const float * bias,int act_type,int depth,int col)23 void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) {
24   for (int ci = 0; ci < col; ci++) {
25     float value = 0;
26     for (int di = 0; di < depth; di++) {
27       value += a[di] * b[ci * depth + di];
28     }
29     if (bias != NULL) value += bias[ci];
30     if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
31     if (act_type == ActType_Relu || act_type == ActType_Relu6) value = MSMAX(0.0f, value);
32     c[ci] = value;
33   }
34 }
35 #endif
36 
// Matrix-vector multiply with weights packed in blocks of C8NUM output columns:
// c[ci] = dot(a, column ci of B) (+ bias[ci], + optional ReLU/ReLU6).
// Layout: for each depth step, `b` stores C8NUM consecutive column values, so
// `b` advances by C8NUM per depth iteration — this holds even for the tail
// block, which is assumed padded to C8NUM in `b` (TODO confirm with packer).
void MatVecMulFp32Block8(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
                         int col) {
  int col8 = col / C8NUM * C8NUM;  // columns covered by whole blocks of 8
  int ci = 0;
  for (; ci < col8; ci += C8NUM, c += C8NUM) {
#ifdef ENABLE_NEON
    // Two 4-lane accumulators cover the 8 columns of this block.
    float32x4_t value0 = vdupq_n_f32(0.0f);
    float32x4_t value1 = vdupq_n_f32(0.0f);
    for (int di = 0; di < depth; ++di, b += C8NUM) {
      value0 += vdupq_n_f32(a[di]) * vld1q_f32(b);
      value1 += vdupq_n_f32(a[di]) * vld1q_f32(b + C4NUM);
    }
    if (bias != NULL) {
      value0 += vld1q_f32(bias + ci);
      value1 += vld1q_f32(bias + ci + C4NUM);
    }
    if (act_type == ActType_Relu || act_type == ActType_Relu6) {
      value0 = vmaxq_f32(value0, vdupq_n_f32(0.0f));
      value1 = vmaxq_f32(value1, vdupq_n_f32(0.0f));
    }
    if (act_type == ActType_Relu6) {
      value0 = vminq_f32(value0, vdupq_n_f32(6.0f));
      value1 = vminq_f32(value1, vdupq_n_f32(6.0f));
    }
    vst1q_f32(c, value0);
    vst1q_f32(c + 4, value1);
#else
    // Scalar path for one full block of C8NUM columns.
    float value[C8NUM] = {0};
    for (int di = 0; di < depth; ++di, b += C8NUM) {
      for (int j = 0; j < C8NUM; ++j) {
        value[j] += a[di] * b[j];
      }
    }
    for (int j = 0; j < C8NUM; ++j) {
      ADD_BIAS(value[j], bias, ci + j);
      DO_RELU(value[j], act_type);
      DO_RELU6(value[j], act_type);
    }
    memcpy(c, value, C8NUM * sizeof(float));
#endif
  }
  // Tail block: only `res` (< C8NUM) columns are valid, but `b` still advances
  // by C8NUM per depth step because the packed layout keeps the block width.
  int res = col - col8;
  float value[C8NUM] = {0};
  for (int di = 0; di < depth; ++di, b += C8NUM) {
    for (int j = 0; j < res; ++j) {
      value[j] += a[di] * b[j];
    }
  }
  for (int j = 0; j < res; ++j) {
    ADD_BIAS(value[j], bias, ci + j);
    DO_RELU(value[j], act_type);
    DO_RELU6(value[j], act_type);
  }
  memcpy(c, value, res * sizeof(float));
}
92 
93 #ifdef ENABLE_ARM32
// ARM32 variant of the block matvec with weights packed in blocks of C4NUM
// columns: c[ci] = dot(a, column ci of B) (+ bias[ci], + optional activation).
// `b` advances by C4NUM per depth step, including for the padded tail block.
void MatVecMulFp32Block4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
                         int col) {
  int col4 = col / C4NUM * C4NUM;  // columns covered by whole blocks of 4
  int ci = 0;
  for (; ci < col4; ci += C4NUM, c += C4NUM) {
#ifdef ENABLE_NEON
    // One 4-lane accumulator covers the block.
    float32x4_t value = vdupq_n_f32(0.0f);
    for (int di = 0; di < depth; ++di, b += C4NUM) {
      value += vdupq_n_f32(a[di]) * vld1q_f32(b);
    }
    if (bias != NULL) {
      value += vld1q_f32(&(bias[ci]));
    }
    if (act_type == ActType_Relu || act_type == ActType_Relu6) {
      value = vmaxq_f32(value, vdupq_n_f32(0.0f));
    }
    if (act_type == ActType_Relu6) {
      value = vminq_f32(value, vdupq_n_f32(6.0f));
    }
    vst1q_f32(c, value);
#else
    // Scalar path for one full block of C4NUM columns.
    float value[C4NUM] = {0};
    for (int di = 0; di < depth; ++di, b += C4NUM) {
      for (int j = 0; j < C4NUM; ++j) {
        value[j] += a[di] * b[j];
      }
    }
    for (int j = 0; j < C4NUM; ++j) {
      ADD_BIAS(value[j], bias, ci + j);
      DO_RELU(value[j], act_type);
      DO_RELU6(value[j], act_type);
    }
    memcpy(c, value, C4NUM * sizeof(float));
#endif
  }
  // Tail block: `res` (< C4NUM) valid columns; `b` still steps by C4NUM.
  int res = col - col4;
  float value[C4NUM] = {0};
  for (int di = 0; di < depth; ++di, b += C4NUM) {
    for (int j = 0; j < res; ++j) {
      value[j] += a[di] * b[j];
    }
  }
  for (int j = 0; j < res; ++j) {
    ADD_BIAS(value[j], bias, ci + j);
    DO_RELU(value[j], act_type);
    DO_RELU6(value[j], act_type);
  }
  memcpy(c, value, res * sizeof(float));
}
143 #endif
144 
145 #ifdef ENABLE_ARM64
146 // 4x8
// ARM64 NEON matvec over weights packed 8 columns per block ("4x8": the depth
// loop is unrolled by 4, columns by 8). `align_col` is the packed column count
// (multiple of C8NUM); only `col` results are actually stored.
void MatVecMulFp32Neon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col,
                         int align_col) {
  int ci = 0;
  for (; ci < align_col - C8NUM + 1; ci += C8NUM) {
    // Accumulators start from the bias so no separate bias add is needed.
    float32x4_t acc_0;
    float32x4_t acc_1;
    if (bias != NULL) {
      acc_0 = vld1q_f32(bias + ci);
      acc_1 = vld1q_f32(bias + ci + C4NUM);
    } else {
      acc_0 = vdupq_n_f32(0.0f);
      acc_1 = vdupq_n_f32(0.0f);
    }
    const float *bv_base = b + ci * depth;  // start of this 8-column block
    int di = 0;
    // Main loop: 4 depth elements per iteration, 8 columns per element.
    for (; di < depth - C4NUM + 1; di += C4NUM) {
      float32x4_t av = vld1q_f32(a + di);
      float32x4_t bv_00 = vld1q_f32(bv_base);
      float32x4_t bv_10 = vld1q_f32(bv_base + C4NUM);
      bv_base += C8NUM;
      float32x4_t bv_01 = vld1q_f32(bv_base);
      float32x4_t bv_11 = vld1q_f32(bv_base + C4NUM);
      bv_base += C8NUM;
      float32x4_t bv_02 = vld1q_f32(bv_base);
      float32x4_t bv_12 = vld1q_f32(bv_base + C4NUM);
      bv_base += C8NUM;
      float32x4_t bv_03 = vld1q_f32(bv_base);
      float32x4_t bv_13 = vld1q_f32(bv_base + C4NUM);
      bv_base += C8NUM;
      acc_0 = vmlaq_n_f32(acc_0, bv_00, av[0]);
      acc_1 = vmlaq_n_f32(acc_1, bv_10, av[0]);
      acc_0 = vmlaq_n_f32(acc_0, bv_01, av[1]);
      acc_1 = vmlaq_n_f32(acc_1, bv_11, av[1]);
      acc_0 = vmlaq_n_f32(acc_0, bv_02, av[2]);
      acc_1 = vmlaq_n_f32(acc_1, bv_12, av[2]);
      acc_0 = vmlaq_n_f32(acc_0, bv_03, av[3]);
      acc_1 = vmlaq_n_f32(acc_1, bv_13, av[3]);
    }
    // Depth remainder: one 8-wide fused-multiply-add per leftover element.
    if (di < depth) {
      for (; di < depth; ++di) {
        float ai = a[di];
        float32x4_t bv0 = vld1q_f32(bv_base);
        float32x4_t bv1 = vld1q_f32(bv_base + C4NUM);
        acc_0 = vmlaq_n_f32(acc_0, bv0, ai);
        acc_1 = vmlaq_n_f32(acc_1, bv1, ai);
        bv_base += C8NUM;
      }
    }  // only save actual col num data
    // Fewer than C4NUM valid columns remain: store element-wise from acc_0
    // (activation applied per element) and finish.
    if (ci + C4NUM - 1 >= col) {
      int c_remain = col - ci;
      for (int i = 0; i < c_remain; ++i) {
        if (act_type == ActType_Relu) {
          c[i] = MSMAX(acc_0[i], 0.0f);
        } else if (act_type == ActType_Relu6) {
          c[i] = MSMIN(MSMAX(acc_0[i], 0.0f), 6.0f);
        } else {
          c[i] = acc_0[i];
        }
      }
      return;
    }
    if (act_type == ActType_Relu) {
      acc_0 = vmaxq_f32(acc_0, vdupq_n_f32(0.0f));
    } else if (act_type == ActType_Relu6) {
      acc_0 = vminq_f32(vmaxq_f32(acc_0, vdupq_n_f32(0.0f)), vdupq_n_f32(6.0f));
    }
    vst1q_f32(c, acc_0);
    // Between C4NUM and C8NUM-1 valid columns: store the tail from acc_1
    // element-wise and finish.
    if (ci + C8NUM - 1 >= col) {
      int c_remain = col - ci - C4NUM;
      for (int i = 0; i < c_remain; ++i) {
        if (act_type == ActType_Relu) {
          c[C4NUM + i] = MSMAX(acc_1[i], 0.0f);
        } else if (act_type == ActType_Relu6) {
          c[C4NUM + i] = MSMIN(MSMAX(acc_1[i], 0.0f), 6.0f);
        } else {
          c[C4NUM + i] = acc_1[i];
        }
      }
      return;
    }
    if (act_type == ActType_Relu) {
      acc_1 = vmaxq_f32(acc_1, vdupq_n_f32(0.0f));
    } else if (act_type == ActType_Relu6) {
      acc_1 = vminq_f32(vmaxq_f32(acc_1, vdupq_n_f32(0.0f)), vdupq_n_f32(6.0f));
    }
    vst1q_f32(c + C4NUM, acc_1);
    c += C8NUM;
  }
}
236 #endif
237 
/* Reference (scalar) matmul over 12x8-tiled operands: A is packed in 12-row
 * tiles (row-major within each depth step), B in 8-column tiles. Supports
 * three output layouts selected by out_type. This version replaces the
 * hard-coded 12/8 magic numbers of the Nhwc/TileC8 branches with C12NUM/C8NUM
 * for consistency with the OutType_C8 branch; values are identical. */
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
                int col, int stride, int out_type) {
  if (out_type == OutType_Nhwc) {
    // Plain row-major output with `stride` floats between consecutive rows.
    for (int r = 0; r < row; r++) {
      for (int c = 0; c < col; c++) {
        int r12div = r / C12NUM, r12mod = r % C12NUM;
        int c8div = c / C8NUM, c8mod = c % C8NUM;
        size_t ci = r * stride + c;
        float value = 0;
        for (int d = 0; d < deep; d++) {
          size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
          size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
          value = value + a[ai] * b[bi];
        }
        ADD_BIAS(value, bias, c)
        DO_RELU(value, act_type)
        DO_RELU6(value, act_type)
        dst[ci] = value;
      }
    }
  } else if (out_type == OutType_C8) {
    // Output stays in the packed col8 layout; both dims rounded up to tiles.
    int col_8 = UP_ROUND(col, C8NUM);
    int row_12 = UP_ROUND(row, C12NUM);
    for (int r = 0; r < row_12; r++) {
      for (int c = 0; c < col_8; c++) {
        int r12div = r / C12NUM, r12mod = r % C12NUM;
        int c8div = c / C8NUM, c8mod = c % C8NUM;
        size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod);
        float value = 0;
        for (int d = 0; d < deep; d++) {
          size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
          size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
          value = value + a[ai] * b[bi];
        }
        ADD_BIAS(value, bias, c)
        DO_RELU(value, act_type)
        DO_RELU6(value, act_type)
        dst[ci] = value;
      }
    }
  } else if (out_type == OutType_TileC8) {
    // Output tiled by C8NUM columns, `stride` rows per tile; A is read with a
    // fixed C12NUM stride per depth step (r12div is always 0 here since the
    // row offset is the raw index i).
    for (int i = 0; i < row; ++i) {
      int src_r_offset = i;
      int dst_r_offset = i * col * stride;
      for (int j = 0; j < col; ++j) {
        int c8div = j / C8NUM, c8mod = j % C8NUM;
        size_t ci = dst_r_offset + c8div * C8NUM * stride + c8mod;
        float value = 0;
        for (int d = 0; d < deep; ++d) {
          size_t ai = src_r_offset + d * C12NUM;
          size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
          value = value + a[ai] * b[bi];
        }
        ADD_BIAS(value, bias, j)
        DO_RELU(value, act_type)
        DO_RELU6(value, act_type)
        dst[ci] = value;
      }
    }
  }
}
299 
// Dispatches the matmul to the best kernel for the build target; falls back to
// the scalar reference MatMul12x8 when no SIMD backend is enabled.
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
               int col, size_t stride, int out_type) {
#ifdef ENABLE_ARM64
  if (out_type == OutType_C8) {
    MatMulFloatNeon64 keeps the packed-output path.
    MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
  } else if (out_type == OutType_Nhwc && deep > C512NUM) {
    // Deep reductions (> C512NUM) use a dedicated kernel.
    BigMatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride);
  } else {
    MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
  }
#elif ENABLE_ARM32
  if (out_type == OutType_C8) {
    MatmulFloatNeon32(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
  } else if (out_type == OutType_Nhwc) {
    MatmulFloatNeon32Opt12x4(a, b, c, bias, (int)act_type, deep, row, col, stride, 1);
  } else {
    MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
  }
#elif ENABLE_AVX
  MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type));
#elif ENABLE_SSE
  MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
#else
  MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
#endif
}
326 
// Applies activation in place to the local variable `dst` (scalar float or a
// MS_FLOAT32X<bit_num> vector). Relies on `dst` and `act_type` being in scope
// at the expansion site: act_type != 0 clamps below at down_threshold (ReLU);
// act_type == 3 additionally clamps above at up_threshold (ReLU6).
#define ActCompute(bit_num, down_threshold, up_threshold) \
  if (act_type != 0) {                                    \
    dst = MS_MAX##bit_num##_F32(dst, down_threshold);     \
    if (act_type == 3) {                                  \
      dst = MS_MIN##bit_num##_F32(dst, up_threshold);     \
    }                                                     \
  }
334 
// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
GemmIsNotPack(const float * a,const float * b,float * c,const float * bias,int row,int deep,int act_type)336 void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type) {
337   int index = 0;
338 
339   SIMD_RUN_NO_SCALAR(GemmIsNotPack, index, a, b, c, bias, row, deep, act_type);
340 
341   for (; index < row; ++index) {
342     float dst = a[index] * b[0] + bias[0];
343     ActCompute(32, 0, C6NUM);
344     c[index] = dst;
345   }
346 }
347 
// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
Row1Deep1GemmIsNotPack(const float * a,const float * b,float * c,const float * bias,int col,int deep,int act_type)349 void Row1Deep1GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int col, int deep,
350                             int act_type) {
351   int index = 0;
352 
353   SIMD_RUN_NO_SCALAR(Row1Deep1GemmIsNotPack, index, a, b, c, bias, col, act_type);
354   for (; index < col; ++index) {
355     float dst = a[0] * b[index] + bias[index];
356     ActCompute(32, 0, C6NUM);
357     c[index] = dst;
358   }
359 }
360 
// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
Row1Deep1NoBiasGemmIsNotPack(const float * a,const float * b,float * c,const float * bias,int col,int deep,int act_type)362 void Row1Deep1NoBiasGemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int col, int deep,
363                                   int act_type) {
364   int index = 0;
365 
366   SIMD_RUN_NO_SCALAR(Row1Deep1NoBiasGemmIsNotPack, index, a, b, c, bias, col, act_type);
367   for (; index < col; ++index) {
368     float dst = a[0] * b[index];
369     ActCompute(32, 0, C6NUM);
370     c[index] = dst;
371   }
372 }
373 
// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
// GEMV over unpacked operands: [m, k] * [k, 1] -> [m, 1], with bias[0] added
// to every output and optional activation.
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type) {
  // gemm dot is [m, k] * [k, 1] ==>> [m, 1]
  int m_index = 0;

  SIMD_RUN_AVX512(GemmIsNotPackOptimize, m_index, a, b, c, bias, m, k, act_type);

#ifdef ENABLE_AVX
  // block 4: process 4 rows at a time, 8 depth elements per inner step.
  MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0);
  MS_FLOAT32X4 up_threshold128 = MS_MOVQ_F32(C6NUM);
  for (; m_index <= m - C4NUM; m_index += C4NUM) {
    int k_index = 0;
    // NOTE(review): `dst` is declared as a 128-bit vector but is accessed via
    // the MS_F32X8_GETI lane macro below — presumably that macro is
    // width-agnostic lane access; confirm against the intrinsics header.
    MS_FLOAT32X4 dst = MS_MOV128_F32(bias[0]);
    MS_SET_ZERO256X4_F32(dst_)  // dst_1..dst_4: one 256-bit accumulator per row
    for (; k_index <= k - C8NUM; k_index += C8NUM) {
      MS_FLOAT32X8 weight = MS_LD256_F32(b + k_index);
      MS_LOAD256X4_F32(src, a + m_index * k + k_index, k);  // 4 rows, stride k
      MS_FMADD256X4_F32(src, weight, dst_);
    }
    // Horizontal-reduce each row's accumulator into its lane of dst.
    MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD256_F32(dst_1);
    MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD256_F32(dst_2);
    MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD256_F32(dst_3);
    MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD256_F32(dst_4);
    // Depth remainder, scalar, one lane per row.
    for (; k_index < k; ++k_index) {
      MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
      MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
      MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k];
      MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k];
    }
    ActCompute(128, down_threshold128, up_threshold128);
    MS_ST128_F32(c + m_index, dst);
  }
#endif

  // block 1: remaining rows one at a time; the inner dot product still uses
  // SIMD where available via the *Core kernels.
  for (; m_index < m; m_index++) {
    float dst = bias[0];
    int k_index = 0;

    SIMD_RUN_AVX512(GemmIsNotPackOptimizeCore, k_index, a + m_index * k, b, k, &dst);
    SIMD_RUN_AVX(GemmIsNotPackOptimizeCore, k_index, a + m_index * k, b, k, &dst);

    for (; k_index < k; k_index++) {
      dst += b[k_index] * a[m_index * k + k_index];
    }
    ActCompute(32, 0, C6NUM);
    c[m_index] = dst;
  }
}
424 
MatVecMulNoPackFp32(const float * a,const float * b,float * c,const float * bias,int act_type,int64_t depth,int64_t cur_col,int64_t col)425 void MatVecMulNoPackFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int64_t depth,
426                          int64_t cur_col, int64_t col) {
427   int inc_flag = 0;
428   int64_t k = 0;
429   for (; k <= depth - C1500NUM; k += C1500NUM) {
430     inc_flag = (k == 0) + (k + C1500NUM == depth ? C2NUM : 0);
431     int64_t oc_index = 0;
432     SIMD_RUN_NO_SCALAR(MatVecMulNoPackCore, oc_index, a, b, c, bias, act_type, C1500NUM, cur_col, col, inc_flag);
433     for (; oc_index < cur_col; ++oc_index) {
434       float dst = (inc_flag & 1) == 0 ? c[oc_index] : (bias == NULL ? 0 : bias[oc_index]);
435       for (int64_t k_index = 0; k_index < k; ++k_index) {
436         dst += a[k_index] * b[oc_index + k_index * col];
437       }
438       if ((inc_flag & 0x2) != 0) {
439         ActCompute(32, 0, C6NUM);
440       }
441       c[oc_index] = dst;
442     }
443     a += C1500NUM;
444     b += C1500NUM * col;
445   }
446   if (k == depth) {
447     return;
448   }
449   inc_flag = (k == 0) + C2NUM;
450   int64_t oc_index = 0;
451   SIMD_RUN_NO_SCALAR(MatVecMulNoPackCore, oc_index, a, b, c, bias, act_type, depth - k, cur_col, col, inc_flag);
452   for (; oc_index < cur_col; ++oc_index) {
453     float dst = (inc_flag & 1) == 0 ? c[oc_index] : (bias == NULL ? 0 : bias[oc_index]);
454     for (int64_t k_index = 0; k_index < depth; ++k_index) {
455       dst += a[k_index] * b[oc_index + k_index * col];
456     }
457     ActCompute(32, 0, C6NUM);
458     c[oc_index] = dst;
459   }
460 }
461 
462 #ifdef ENABLE_ARM64
// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
MatMul4x1Kernel(const float * input,const float * weight,float * output,const float * bias,size_t deep,size_t act_type)464 void MatMul4x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep,
465                      size_t act_type) {
466   // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
467   // 9: WriteBack
468   asm volatile(
469     "mov x8, %[input]\n"
470     "mov x9, %[weight]\n"
471     "mov x10, %[deep]\n"
472     "add x5, %[input], %[deep], LSL #2\n"
473     "add x6, %[input], %[deep], LSL #3\n"
474     "add x7, x5, %[deep], LSL #3\n"
475     "dup v0.2d, xzr\n"
476     "dup v1.2d, xzr\n"
477     "dup v2.2d, xzr\n"
478     "dup v3.2d, xzr\n"
479     "subs x10, x10, #16\n"
480     "blt 2f\n"
481     "1:\n"  // LoopD16
482     "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
483     "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n"
484     "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x6], #64\n"
485     "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x7], #64\n"
486     "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
487     "fmla v0.4s, v4.4s, v28.4s\n"
488     "fmla v1.4s, v16.4s, v28.4s\n"
489     "fmla v2.4s, v20.4s, v28.4s\n"
490     "fmla v3.4s, v24.4s, v28.4s\n"
491     "fmla v0.4s, v5.4s, v29.4s\n"
492     "fmla v1.4s, v17.4s, v29.4s\n"
493     "fmla v2.4s, v21.4s, v29.4s\n"
494     "fmla v3.4s, v25.4s, v29.4s\n"
495     "fmla v0.4s, v6.4s, v30.4s\n"
496     "fmla v1.4s, v18.4s, v30.4s\n"
497     "fmla v2.4s, v22.4s, v30.4s\n"
498     "fmla v3.4s, v26.4s, v30.4s\n"
499     "fmla v0.4s, v7.4s, v31.4s\n"
500     "fmla v1.4s, v19.4s, v31.4s\n"
501     "fmla v2.4s, v23.4s, v31.4s\n"
502     "fmla v3.4s, v27.4s, v31.4s\n"
503     "subs x10, x10, #16\n"
504     "bge 1b\n"
505     "2:\n"  // LoopD12
506     "adds x10, x10, #16\n"
507     "cbz x10, 6f\n"
508     "cmp x10, #12\n"
509     "blt 3f\n"
510     "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n"
511     "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n"
512     "ld1 {v20.4s, v21.4s, v22.4s}, [x6], #48\n"
513     "ld1 {v24.4s, v25.4s, v26.4s}, [x7], #48\n"
514     "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n"
515     "fmla v0.4s, v4.4s, v28.4s\n"
516     "fmla v1.4s, v16.4s, v28.4s\n"
517     "fmla v2.4s, v20.4s, v28.4s\n"
518     "fmla v3.4s, v24.4s, v28.4s\n"
519     "fmla v0.4s, v5.4s, v29.4s\n"
520     "fmla v1.4s, v17.4s, v29.4s\n"
521     "fmla v2.4s, v21.4s, v29.4s\n"
522     "fmla v3.4s, v25.4s, v29.4s\n"
523     "fmla v0.4s, v6.4s, v30.4s\n"
524     "fmla v1.4s, v18.4s, v30.4s\n"
525     "fmla v2.4s, v22.4s, v30.4s\n"
526     "fmla v3.4s, v26.4s, v30.4s\n"
527     "sub x10, x10, #12\n"
528     "b 7f\n"
529     "3:\n"  // LoopD8
530     "cmp x10, #8\n"
531     "blt 4f\n"
532     "ld1 {v4.4s, v5.4s}, [x8], #32\n"
533     "ld1 {v16.4s, v17.4s}, [x5], #32\n"
534     "ld1 {v20.4s, v21.4s}, [x6], #32\n"
535     "ld1 {v24.4s, v25.4s}, [x7], #32\n"
536     "ld1 {v28.4s, v29.4s}, [x9], #32\n"
537     "fmla v0.4s, v4.4s, v28.4s\n"
538     "fmla v1.4s, v16.4s, v28.4s\n"
539     "fmla v2.4s, v20.4s, v28.4s\n"
540     "fmla v3.4s, v24.4s, v28.4s\n"
541     "fmla v0.4s, v5.4s, v29.4s\n"
542     "fmla v1.4s, v17.4s, v29.4s\n"
543     "fmla v2.4s, v21.4s, v29.4s\n"
544     "fmla v3.4s, v25.4s, v29.4s\n"
545     "sub x10, x10, #8\n"
546     "b 7f\n"
547     "4:\n"  // LoopD4
548     "cmp x10, #4\n"
549     "blt 7f\n"
550     "ld1 {v4.4s}, [x8], #16\n"
551     "ld1 {v16.4s}, [x5], #16\n"
552     "ld1 {v20.4s}, [x6], #16\n"
553     "ld1 {v24.4s}, [x7], #16\n"
554     "ld1 {v28.4s}, [x9], #16\n"
555     "fmla v0.4s, v4.4s, v28.4s\n"
556     "fmla v1.4s, v16.4s, v28.4s\n"
557     "fmla v2.4s, v20.4s, v28.4s\n"
558     "fmla v3.4s, v24.4s, v28.4s\n"
559     "sub x10, x10, #4\n"
560     "7:\n"
561     "cbz x10, 6f\n"
562     "dup v4.2d, xzr\n"
563     "dup v16.2d, xzr\n"
564     "dup v20.2d, xzr\n"
565     "dup v24.2d, xzr\n"
566     "dup v28.2d, xzr\n"
567     "subs x10, x10, #2\n"
568     "blt 5f\n"
569     "ld1 {v4.d}[0], [x8], #8\n"  // LoopD2
570     "ld1 {v16.d}[0], [x5], #8\n"
571     "ld1 {v20.d}[0], [x6], #8\n"
572     "ld1 {v24.d}[0], [x7], #8\n"
573     "ld1 {v28.d}[0], [x9], #8\n"
574     "cbz x10, 8f\n"
575     "5:\n"  // LoopD1
576     "ld1 {v4.s}[2], [x8]\n"
577     "ld1 {v16.s}[2], [x5]\n"
578     "ld1 {v20.s}[2], [x6]\n"
579     "ld1 {v24.s}[2], [x7]\n"
580     "ld1 {v28.s}[2], [x9]\n"
581     "8:\n"
582     "fmla v0.4s, v4.4s, v28.4s\n"
583     "fmla v1.4s, v16.4s, v28.4s\n"
584     "fmla v2.4s, v20.4s, v28.4s\n"
585     "fmla v3.4s, v24.4s, v28.4s\n"
586     "6:\n"
587     "faddp v4.4s, v0.4s, v1.4s\n"
588     "faddp v5.4s, v2.4s, v3.4s\n"
589     "faddp v0.4s, v4.4s, v5.4s\n"
590     "cbz %[bias], 9f\n"
591     "ld1r {v1.4s}, [%[bias]]\n"
592     "fadd v0.4s, v0.4s, v1.4s\n"
593     "9:\n"
594     "cbz %[act], 10f\n"
595     "dup v1.2d, xzr\n"
596     "fmax v0.4s, v0.4s, v1.4s\n"
597     "cmp %[act], #3\n"
598     "bne 10f\n"
599     "movi v1.4s, #6\n"
600     "scvtf v1.4s, v1.4s\n"
601     "fmin v0.4s, v0.4s, v1.4s\n"
602     "10:\n"
603     "st1 {v0.4s}, [%[output]]\n"
604 
605     :
606     : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep),
607       [ act ] "r"(act_type)
608     : "cc", "x5", "x6", "x7", "x8", "x9", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18",
609       "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
610 }
611 
// Computes 2 consecutive dot products of a [2, deep] * [deep, 1] GEMV:
// output[i] = dot(input + i * deep, weight) + bias[0] (broadcast via ld1r),
// i in [0, 2), then optional ReLU (act != 0) / ReLU6 (act == 3).
// NOTE(review): the D2/D1 tail zeroes v4/v16 but not the weight register v28,
// so stale upper weight lanes are multiplied by zeroed input lanes — the
// product is 0 unless the stale value is Inf/NaN; confirm this is acceptable.
void MatMul2x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep,
                     size_t act_type) {
  // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
  // 9: WriteBack
  asm volatile(
    "mov x8, %[input]\n"
    "mov x9, %[weight]\n"
    "mov x10, %[deep]\n"
    "add x5, %[input], %[deep], LSL #2\n"
    "dup v0.2d, xzr\n"
    "dup v1.2d, xzr\n"
    "subs x10, x10, #16\n"
    "blt 2f\n"
    "1:\n"  // LoopD16
    "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
    "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n"
    "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "fmla v0.4s, v6.4s, v30.4s\n"
    "fmla v1.4s, v18.4s, v30.4s\n"
    "fmla v0.4s, v7.4s, v31.4s\n"
    "fmla v1.4s, v19.4s, v31.4s\n"
    "subs x10, x10, #16\n"
    "bge 1b\n"
    "2:\n"  // LoopD12
    "adds x10, x10, #16\n"
    "cbz x10, 6f\n"
    "cmp x10, #12\n"
    "blt 3f\n"
    "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n"
    "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n"
    "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "fmla v0.4s, v6.4s, v30.4s\n"
    "fmla v1.4s, v18.4s, v30.4s\n"
    "sub x10, x10, #12\n"
    "b 7f\n"
    "3:\n"  // LoopD8
    "cmp x10, #8\n"
    "blt 4f\n"
    "ld1 {v4.4s, v5.4s}, [x8], #32\n"
    "ld1 {v16.4s, v17.4s}, [x5], #32\n"
    "ld1 {v28.4s, v29.4s}, [x9], #32\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "sub x10, x10, #8\n"
    "b 7f\n"
    "4:\n"  // LoopD4
    "cmp x10, #4\n"
    "blt 7f\n"
    "ld1 {v4.4s}, [x8], #16\n"
    "ld1 {v16.4s}, [x5], #16\n"
    "ld1 {v28.4s}, [x9], #16\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "sub x10, x10, #4\n"
    "7:\n"
    "cbz x10, 6f\n"
    "dup v4.2d, xzr\n"
    "dup v16.2d, xzr\n"
    "subs x10, x10, #2\n"
    "blt 5f\n"
    "ld1 {v4.d}[0], [x8], #8\n"  // LoopD2
    "ld1 {v16.d}[0], [x5], #8\n"
    "ld1 {v28.d}[0], [x9], #8\n"
    "cbz x10, 8f\n"
    "5:\n"  // LoopD1
    "ld1 {v4.s}[2], [x8]\n"
    "ld1 {v16.s}[2], [x5]\n"
    "ld1 {v28.s}[2], [x9]\n"
    "8:\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "6:\n"
    "faddp v4.4s, v0.4s, v1.4s\n"
    "faddp v0.4s, v4.4s, v4.4s\n"
    "cbz %[bias], 9f\n"
    "ld1r {v1.4s}, [%[bias]]\n"
    "fadd v0.2s, v0.2s, v1.2s\n"
    "9:\n"
    "cbz %[act], 10f\n"
    "fmov d1, xzr\n"
    "fmax v0.2s, v0.2s, v1.2s\n"
    "cmp %[act], #3\n"
    "bne 10f\n"
    "movi v1.2s, #6\n"
    "scvtf v1.2s, v1.2s\n"
    "fmin v0.2s, v0.2s, v1.2s\n"
    "10:\n"
    "st1 {v0.2s}, [%[output]]\n"

    :
    : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep),
      [ act ] "r"(act_type)
    : "cc", "x5", "x8", "x9", "x10", "v0", "v1", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29",
      "v30", "v31", "memory");
}
717 
MatMul1x1Kernel(const float * input,const float * weight,float * output,const float * bias,size_t deep,size_t act_type)718 void MatMul1x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep,
719                      size_t act_type) {
720   // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
721   // 9: WriteBack
722   asm volatile(
723     "mov x8, %[input]\n"
724     "mov x9, %[weight]\n"
725     "mov x10, %[deep]\n"
726     "dup v0.2d, xzr\n"
727     "subs x10, x10, #16\n"
728     "blt 2f\n"
729     "1:\n"  // LoopD16
730     "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
731     "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
732     "fmla v0.4s, v4.4s, v28.4s\n"
733     "fmla v0.4s, v5.4s, v29.4s\n"
734     "fmla v0.4s, v6.4s, v30.4s\n"
735     "fmla v0.4s, v7.4s, v31.4s\n"
736     "subs x10, x10, #16\n"
737     "bge 1b\n"
738     "2:\n"  // LoopD12
739     "adds x10, x10, #16\n"
740     "cbz x10, 6f\n"
741     "cmp x10, #12\n"
742     "blt 3f\n"
743     "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n"
744     "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n"
745     "fmla v0.4s, v4.4s, v28.4s\n"
746     "fmla v0.4s, v5.4s, v29.4s\n"
747     "fmla v0.4s, v6.4s, v30.4s\n"
748     "sub x10, x10, #12\n"
749     "b 7f\n"
750     "3:\n"  // LoopD8
751     "cmp x10, #8\n"
752     "blt 4f\n"
753     "ld1 {v4.4s, v5.4s}, [x8], #32\n"
754     "ld1 {v28.4s, v29.4s}, [x9], #32\n"
755     "fmla v0.4s, v4.4s, v28.4s\n"
756     "fmla v0.4s, v5.4s, v29.4s\n"
757     "sub x10, x10, #8\n"
758     "b 7f\n"
759     "4:\n"  // LoopD4
760     "cmp x10, #4\n"
761     "blt 7f\n"
762     "ld1 {v4.4s}, [x8], #16\n"
763     "ld1 {v28.4s}, [x9], #16\n"
764     "fmla v0.4s, v4.4s, v28.4s\n"
765     "sub x10, x10, #4\n"
766     "7:\n"
767     "cbz x10, 6f\n"
768     "dup v4.2d, xzr\n"
769     "subs x10, x10, #2\n"
770     "blt 5f\n"
771     "ld1 {v4.d}[0], [x8], #8\n"  // LoopD2
772     "ld1 {v28.d}[0], [x9], #8\n"
773     "cbz x10, 8f\n"
774     "5:\n"  // LoopD1
775     "ld1 {v4.s}[3], [x8]\n"
776     "ld1 {v28.s}[3], [x9]\n"
777     "8:\n"
778     "fmla v0.4s, v4.4s, v28.4s\n"
779     "6:\n"
780     "faddp v4.4s, v0.4s, v0.4s\n"
781     "faddp v0.4s, v4.4s, v4.4s\n"
782     "cbz %[bias], 9f\n"
783     "ld1 {v1.s}[0], [%[bias]]\n"
784     "fadd s0, s0, s1\n"
785     "9:\n"
786     "cbz %[act], 10f\n"
787     "fmov s1, wzr\n"
788     "fmax s0, s0, s1\n"
789     "cmp %[act], #3\n"
790     "bne 10f\n"
791     "mov x10, #6\n"
792     "scvtf s1, x10\n"
793     "fmin s0, s0, s1\n"
794     "10:\n"
795     "str s0, [%[output]]\n"
796 
797     :
798     : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep),
799       [ act ] "r"(act_type)
800     : "cc", "x8", "x9", "x10", "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", "v30", "v31");
801 }
802 
GemmIsNotPackByRow(const float * a,const float * b,float * c,const float * bias,int start_row,int end_row,int deep,int act_type)803 void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row,
804                         int deep, int act_type) {
805   const float *input = a + start_row * deep;
806   float *output = c + start_row;
807   const int step = C4NUM * deep;
808   for (; start_row <= end_row - C4NUM; start_row += C4NUM) {
809     MatMul4x1Kernel(input, b, output, bias, deep, act_type);
810     input += step;
811     output += C4NUM;
812   }
813   for (; start_row <= end_row - C2NUM; start_row += C2NUM) {
814     MatMul2x1Kernel(input, b, output, bias, deep, act_type);
815     input += C2NUM * deep;
816     output += C2NUM;
817   }
818   if (start_row == end_row - 1) {
819     MatMul1x1Kernel(input, b, output, bias, deep, act_type);
820   }
821 }
822 #endif
823