1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/fp32/matmul_fp32.h"
18 #include "nnacl/fp32/pack_fp32.h"
19 #include "nnacl/fp32/matmul_avx512_fp32.h"
20 #include "nnacl/matmul_fp32_simd.h"
21
22 #ifndef ENABLE_ARM
MatVecMulFp32(const float * a,const float * b,float * c,const float * bias,int act_type,int depth,int col)23 void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col) {
24 for (int ci = 0; ci < col; ci++) {
25 float value = 0;
26 for (int di = 0; di < depth; di++) {
27 value += a[di] * b[ci * depth + di];
28 }
29 if (bias != NULL) value += bias[ci];
30 if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
31 if (act_type == ActType_Relu || act_type == ActType_Relu6) value = MSMAX(0.0f, value);
32 c[ci] = value;
33 }
34 }
35 #endif
36
// Matrix-vector product c = act(a * B + bias) with weights packed in blocks
// of C8NUM output channels: for each block, `b` stores `depth` groups of 8
// consecutive channel values.  `b` is consumed strictly sequentially, so after
// the full blocks the pointer already addresses the tail block (the b += C8NUM
// stride implies the tail block is still 8 floats wide in the packed layout).
void MatVecMulFp32Block8(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
                         int col) {
  int col8 = col / C8NUM * C8NUM;  // outputs covered by full 8-wide blocks
  int ci = 0;
  for (; ci < col8; ci += C8NUM, c += C8NUM) {
#ifdef ENABLE_NEON
    // Two 4-lane accumulators together hold the 8 channels of this block.
    float32x4_t value0 = vdupq_n_f32(0.0f);
    float32x4_t value1 = vdupq_n_f32(0.0f);
    for (int di = 0; di < depth; ++di, b += C8NUM) {
      value0 += vdupq_n_f32(a[di]) * vld1q_f32(b);
      value1 += vdupq_n_f32(a[di]) * vld1q_f32(b + C4NUM);
    }
    if (bias != NULL) {
      value0 += vld1q_f32(bias + ci);
      value1 += vld1q_f32(bias + ci + C4NUM);
    }
    if (act_type == ActType_Relu || act_type == ActType_Relu6) {
      value0 = vmaxq_f32(value0, vdupq_n_f32(0.0f));
      value1 = vmaxq_f32(value1, vdupq_n_f32(0.0f));
    }
    if (act_type == ActType_Relu6) {
      value0 = vminq_f32(value0, vdupq_n_f32(6.0f));
      value1 = vminq_f32(value1, vdupq_n_f32(6.0f));
    }
    vst1q_f32(c, value0);
    vst1q_f32(c + 4, value1);
#else
    // Scalar fallback: accumulate the 8 channels of the block in a local
    // array, then apply bias and activation per channel.
    float value[C8NUM] = {0};
    for (int di = 0; di < depth; ++di, b += C8NUM) {
      for (int j = 0; j < C8NUM; ++j) {
        value[j] += a[di] * b[j];
      }
    }
    for (int j = 0; j < C8NUM; ++j) {
      ADD_BIAS(value[j], bias, ci + j);
      DO_RELU(value[j], act_type);
      DO_RELU6(value[j], act_type);
    }
    memcpy(c, value, C8NUM * sizeof(float));
#endif
  }
  // Tail: fewer than C8NUM valid outputs remain; only `res` results are
  // computed and stored, but `b` still advances by C8NUM per depth step.
  int res = col - col8;
  float value[C8NUM] = {0};
  for (int di = 0; di < depth; ++di, b += C8NUM) {
    for (int j = 0; j < res; ++j) {
      value[j] += a[di] * b[j];
    }
  }
  for (int j = 0; j < res; ++j) {
    ADD_BIAS(value[j], bias, ci + j);
    DO_RELU(value[j], act_type);
    DO_RELU6(value[j], act_type);
  }
  memcpy(c, value, res * sizeof(float));
}
92
93 #ifdef ENABLE_ARM32
// ARM32 variant of the blocked matrix-vector product with weights packed in
// blocks of C4NUM output channels: c = act(a * B + bias).  `b` is consumed
// sequentially (b += C4NUM per depth step), including through the tail block.
void MatVecMulFp32Block4(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
                         int col) {
  int col4 = col / C4NUM * C4NUM;  // outputs covered by full 4-wide blocks
  int ci = 0;
  for (; ci < col4; ci += C4NUM, c += C4NUM) {
#ifdef ENABLE_NEON
    // One 4-lane accumulator covers the 4 channels of this block.
    float32x4_t value = vdupq_n_f32(0.0f);
    for (int di = 0; di < depth; ++di, b += C4NUM) {
      value += vdupq_n_f32(a[di]) * vld1q_f32(b);
    }
    if (bias != NULL) {
      value += vld1q_f32(&(bias[ci]));
    }
    if (act_type == ActType_Relu || act_type == ActType_Relu6) {
      value = vmaxq_f32(value, vdupq_n_f32(0.0f));
    }
    if (act_type == ActType_Relu6) {
      value = vminq_f32(value, vdupq_n_f32(6.0f));
    }
    vst1q_f32(c, value);
#else
    // Scalar fallback mirroring the NEON path channel by channel.
    float value[C4NUM] = {0};
    for (int di = 0; di < depth; ++di, b += C4NUM) {
      for (int j = 0; j < C4NUM; ++j) {
        value[j] += a[di] * b[j];
      }
    }
    for (int j = 0; j < C4NUM; ++j) {
      ADD_BIAS(value[j], bias, ci + j);
      DO_RELU(value[j], act_type);
      DO_RELU6(value[j], act_type);
    }
    memcpy(c, value, C4NUM * sizeof(float));
#endif
  }
  // Tail: only `res` valid outputs remain; b keeps its C4NUM stride.
  int res = col - col4;
  float value[C4NUM] = {0};
  for (int di = 0; di < depth; ++di, b += C4NUM) {
    for (int j = 0; j < res; ++j) {
      value[j] += a[di] * b[j];
    }
  }
  for (int j = 0; j < res; ++j) {
    ADD_BIAS(value[j], bias, ci + j);
    DO_RELU(value[j], act_type);
    DO_RELU6(value[j], act_type);
  }
  memcpy(c, value, res * sizeof(float));
}
143 #endif
144
145 #ifdef ENABLE_ARM64
146 // 4x8
// ARM64 NEON matrix-vector product over C8NUM-wide packed weight blocks:
// c = act(a * B + bias).  Iterates over `align_col` (the C8NUM-aligned column
// count) but stores at most `col` valid results, returning early from inside
// the loop once the last valid column has been written.
void MatVecMulFp32Neon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col,
                         int align_col) {
  int ci = 0;
  for (; ci < align_col - C8NUM + 1; ci += C8NUM) {
    // acc_0 holds channels [ci, ci+4), acc_1 holds [ci+4, ci+8);
    // both start from the bias (or zero) so no separate bias add is needed.
    float32x4_t acc_0;
    float32x4_t acc_1;
    if (bias != NULL) {
      acc_0 = vld1q_f32(bias + ci);
      acc_1 = vld1q_f32(bias + ci + C4NUM);
    } else {
      acc_0 = vdupq_n_f32(0.0f);
      acc_1 = vdupq_n_f32(0.0f);
    }
    const float *bv_base = b + ci * depth;  // start of this column block
    int di = 0;
    // Main loop: 4 depth steps at a time, fully unrolled.
    for (; di < depth - C4NUM + 1; di += C4NUM) {
      float32x4_t av = vld1q_f32(a + di);
      float32x4_t bv_00 = vld1q_f32(bv_base);
      float32x4_t bv_10 = vld1q_f32(bv_base + C4NUM);
      bv_base += C8NUM;
      float32x4_t bv_01 = vld1q_f32(bv_base);
      float32x4_t bv_11 = vld1q_f32(bv_base + C4NUM);
      bv_base += C8NUM;
      float32x4_t bv_02 = vld1q_f32(bv_base);
      float32x4_t bv_12 = vld1q_f32(bv_base + C4NUM);
      bv_base += C8NUM;
      float32x4_t bv_03 = vld1q_f32(bv_base);
      float32x4_t bv_13 = vld1q_f32(bv_base + C4NUM);
      bv_base += C8NUM;
      acc_0 = vmlaq_n_f32(acc_0, bv_00, av[0]);
      acc_1 = vmlaq_n_f32(acc_1, bv_10, av[0]);
      acc_0 = vmlaq_n_f32(acc_0, bv_01, av[1]);
      acc_1 = vmlaq_n_f32(acc_1, bv_11, av[1]);
      acc_0 = vmlaq_n_f32(acc_0, bv_02, av[2]);
      acc_1 = vmlaq_n_f32(acc_1, bv_12, av[2]);
      acc_0 = vmlaq_n_f32(acc_0, bv_03, av[3]);
      acc_1 = vmlaq_n_f32(acc_1, bv_13, av[3]);
    }
    // Remaining depth steps, one at a time.
    if (di < depth) {
      for (; di < depth; ++di) {
        float ai = a[di];
        float32x4_t bv0 = vld1q_f32(bv_base);
        float32x4_t bv1 = vld1q_f32(bv_base + C4NUM);
        acc_0 = vmlaq_n_f32(acc_0, bv0, ai);
        acc_1 = vmlaq_n_f32(acc_1, bv1, ai);
        bv_base += C8NUM;
      }
    }  // only save actual col num data
    // Fewer than 4 valid columns left: store them scalar-wise and finish.
    if (ci + C4NUM - 1 >= col) {
      int c_remain = col - ci;
      for (int i = 0; i < c_remain; ++i) {
        if (act_type == ActType_Relu) {
          c[i] = MSMAX(acc_0[i], 0.0f);
        } else if (act_type == ActType_Relu6) {
          c[i] = MSMIN(MSMAX(acc_0[i], 0.0f), 6.0f);
        } else {
          c[i] = acc_0[i];
        }
      }
      return;
    }
    if (act_type == ActType_Relu) {
      acc_0 = vmaxq_f32(acc_0, vdupq_n_f32(0.0f));
    } else if (act_type == ActType_Relu6) {
      acc_0 = vminq_f32(vmaxq_f32(acc_0, vdupq_n_f32(0.0f)), vdupq_n_f32(6.0f));
    }
    vst1q_f32(c, acc_0);
    // Fewer than 8 valid columns: store the remainder of acc_1 scalar-wise.
    if (ci + C8NUM - 1 >= col) {
      int c_remain = col - ci - C4NUM;
      for (int i = 0; i < c_remain; ++i) {
        if (act_type == ActType_Relu) {
          c[C4NUM + i] = MSMAX(acc_1[i], 0.0f);
        } else if (act_type == ActType_Relu6) {
          c[C4NUM + i] = MSMIN(MSMAX(acc_1[i], 0.0f), 6.0f);
        } else {
          c[C4NUM + i] = acc_1[i];
        }
      }
      return;
    }
    if (act_type == ActType_Relu) {
      acc_1 = vmaxq_f32(acc_1, vdupq_n_f32(0.0f));
    } else if (act_type == ActType_Relu6) {
      acc_1 = vminq_f32(vmaxq_f32(acc_1, vdupq_n_f32(0.0f)), vdupq_n_f32(6.0f));
    }
    vst1q_f32(c + C4NUM, acc_1);
    c += C8NUM;
  }
}
236 #endif
237
// Scalar reference for the tiled matmul: `a` is packed row-major in C12NUM-row
// tiles, `b` in C8NUM-column tiles.  `out_type` selects the destination
// layout: plain NHWC (row-major with `stride`), C8-tiled, or per-row tiled C8.
// Fix: the Nhwc and TileC8 branches hard-coded 12/8 literals while the C8
// branch used C12NUM/C8NUM — normalized to the macros (same values).
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
                int col, int stride, int out_type) {
  if (out_type == OutType_Nhwc) {
    for (int r = 0; r < row; r++) {
      for (int c = 0; c < col; c++) {
        int r12div = r / C12NUM, r12mod = r % C12NUM;
        int c8div = c / C8NUM, c8mod = c % C8NUM;
        size_t ci = r * stride + c;
        float value = 0;
        for (int d = 0; d < deep; d++) {
          size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
          size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
          value = value + a[ai] * b[bi];
        }
        ADD_BIAS(value, bias, c)
        DO_RELU(value, act_type)
        DO_RELU6(value, act_type)
        dst[ci] = value;
      }
    }
  } else if (out_type == OutType_C8) {
    // Output keeps the C8-tiled layout; rows/cols are rounded up so the
    // padding region is also written.
    int col_8 = UP_ROUND(col, C8NUM);
    int row_12 = UP_ROUND(row, C12NUM);
    for (int r = 0; r < row_12; r++) {
      for (int c = 0; c < col_8; c++) {
        int r12div = r / C12NUM, r12mod = r % C12NUM;
        int c8div = c / C8NUM, c8mod = c % C8NUM;
        size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod);
        float value = 0;
        for (int d = 0; d < deep; d++) {
          size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
          size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
          value = value + a[ai] * b[bi];
        }
        ADD_BIAS(value, bias, c)
        DO_RELU(value, act_type)
        DO_RELU6(value, act_type)
        dst[ci] = value;
      }
    }
  } else if (out_type == OutType_TileC8) {
    // Per-row destination tiles of C8NUM columns, each tile `stride` apart.
    for (int i = 0; i < row; ++i) {
      int src_r_offset = i;
      int dst_r_offset = i * col * stride;
      for (int j = 0; j < col; ++j) {
        int c8div = j / C8NUM, c8mod = j % C8NUM;
        size_t ci = dst_r_offset + c8div * C8NUM * stride + c8mod;
        float value = 0;
        for (int d = 0; d < deep; ++d) {
          size_t ai = src_r_offset + d * C12NUM;
          size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
          value = value + a[ai] * b[bi];
        }
        ADD_BIAS(value, bias, j)
        DO_RELU(value, act_type)
        DO_RELU6(value, act_type)
        dst[ci] = value;
      }
    }
  }
}
299
// Dispatches the packed matmul c = act(a * b + bias) to the best backend
// compiled into this build (ARM64/ARM32 assembly kernels, AVX, SSE), falling
// back to the scalar MatMul12x8 reference implementation otherwise.
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
               int col, size_t stride, int out_type) {
#ifdef ENABLE_ARM64
  if (out_type == OutType_C8) {
    MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
  } else if (out_type == OutType_Nhwc && deep > C512NUM) {
    // Large reduction depth takes a dedicated NHWC kernel.
    BigMatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride);
  } else {
    MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
  }
#elif ENABLE_ARM32
  if (out_type == OutType_C8) {
    MatmulFloatNeon32(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
  } else if (out_type == OutType_Nhwc) {
    MatmulFloatNeon32Opt12x4(a, b, c, bias, (int)act_type, deep, row, col, stride, 1);
  } else {
    MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
  }
#elif ENABLE_AVX
  MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type));
#elif ENABLE_SSE
  MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
#else
  MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
#endif
}
326
// Applies the activation in place on a variable named `dst` in the enclosing
// scope, reading `act_type` from that scope as well: any non-zero act_type
// clamps below at down_threshold (ReLU); act_type == 3 additionally clamps
// above at up_threshold (ReLU6).  bit_num selects the MS_MAX/MS_MIN width.
#define ActCompute(bit_num, down_threshold, up_threshold) \
  if (act_type != 0) {                                    \
    dst = MS_MAX##bit_num##_F32(dst, down_threshold);     \
    if (act_type == 3) {                                  \
      dst = MS_MIN##bit_num##_F32(dst, up_threshold);     \
    }                                                     \
  }
334
// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
GemmIsNotPack(const float * a,const float * b,float * c,const float * bias,int row,int deep,int act_type)336 void GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int row, int deep, int act_type) {
337 int index = 0;
338
339 SIMD_RUN_NO_SCALAR(GemmIsNotPack, index, a, b, c, bias, row, deep, act_type);
340
341 for (; index < row; ++index) {
342 float dst = a[index] * b[0] + bias[0];
343 ActCompute(32, 0, C6NUM);
344 c[index] = dst;
345 }
346 }
347
// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
Row1Deep1GemmIsNotPack(const float * a,const float * b,float * c,const float * bias,int col,int deep,int act_type)349 void Row1Deep1GemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int col, int deep,
350 int act_type) {
351 int index = 0;
352
353 SIMD_RUN_NO_SCALAR(Row1Deep1GemmIsNotPack, index, a, b, c, bias, col, act_type);
354 for (; index < col; ++index) {
355 float dst = a[0] * b[index] + bias[index];
356 ActCompute(32, 0, C6NUM);
357 c[index] = dst;
358 }
359 }
360
// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
Row1Deep1NoBiasGemmIsNotPack(const float * a,const float * b,float * c,const float * bias,int col,int deep,int act_type)362 void Row1Deep1NoBiasGemmIsNotPack(const float *a, const float *b, float *c, const float *bias, int col, int deep,
363 int act_type) {
364 int index = 0;
365
366 SIMD_RUN_NO_SCALAR(Row1Deep1NoBiasGemmIsNotPack, index, a, b, c, bias, col, act_type);
367 for (; index < col; ++index) {
368 float dst = a[0] * b[index];
369 ActCompute(32, 0, C6NUM);
370 c[index] = dst;
371 }
372 }
373
// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
// Gemm dot product c = act(a * b + bias): [m, k] * [k, 1] ==> [m, 1], with an
// unpacked row-major `a`.  Tries AVX512 over rows first, then an AVX 4-row
// block, then a per-row scalar path with SIMD inner loops over k.
void GemmIsNotPackOptimize(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type) {
  // gemm dot is [m, k] * [k, 1] ==>> [m, 1]
  int m_index = 0;

  SIMD_RUN_AVX512(GemmIsNotPackOptimize, m_index, a, b, c, bias, m, k, act_type);

#ifdef ENABLE_AVX
  // block 4: four rows at once.  `dst` lanes 0..3 hold the four row results;
  // dst_1..dst_4 are 256-bit partial accumulators declared by
  // MS_SET_ZERO256X4_F32 and reduced into the lanes after the k loop.
  MS_FLOAT32X4 down_threshold128 = MS_MOVQ_F32(0);
  MS_FLOAT32X4 up_threshold128 = MS_MOVQ_F32(C6NUM);
  for (; m_index <= m - C4NUM; m_index += C4NUM) {
    int k_index = 0;
    MS_FLOAT32X4 dst = MS_MOV128_F32(bias[0]);
    MS_SET_ZERO256X4_F32(dst_)
    for (; k_index <= k - C8NUM; k_index += C8NUM) {
      MS_FLOAT32X8 weight = MS_LD256_F32(b + k_index);
      MS_LOAD256X4_F32(src, a + m_index * k + k_index, k);
      MS_FMADD256X4_F32(src, weight, dst_);
    }
    MS_F32X8_GETI(dst, 0) += MS_REDUCE_ADD256_F32(dst_1);
    MS_F32X8_GETI(dst, 1) += MS_REDUCE_ADD256_F32(dst_2);
    MS_F32X8_GETI(dst, C2NUM) += MS_REDUCE_ADD256_F32(dst_3);
    MS_F32X8_GETI(dst, C3NUM) += MS_REDUCE_ADD256_F32(dst_4);
    // Scalar tail of the k dimension for each of the four rows.
    for (; k_index < k; ++k_index) {
      MS_F32X8_GETI(dst, 0) += b[k_index] * a[m_index * k + k_index];
      MS_F32X8_GETI(dst, 1) += b[k_index] * a[m_index * k + k_index + k];
      MS_F32X8_GETI(dst, C2NUM) += b[k_index] * a[m_index * k + k_index + C2NUM * k];
      MS_F32X8_GETI(dst, C3NUM) += b[k_index] * a[m_index * k + k_index + C3NUM * k];
    }
    ActCompute(128, down_threshold128, up_threshold128);
    MS_ST128_F32(c + m_index, dst);
  }
#endif

  // block 1: one row at a time; the SIMD cores accumulate into *(&dst).
  for (; m_index < m; m_index++) {
    float dst = bias[0];
    int k_index = 0;

    SIMD_RUN_AVX512(GemmIsNotPackOptimizeCore, k_index, a + m_index * k, b, k, &dst);
    SIMD_RUN_AVX(GemmIsNotPackOptimizeCore, k_index, a + m_index * k, b, k, &dst);

    for (; k_index < k; k_index++) {
      dst += b[k_index] * a[m_index * k + k_index];
    }
    ActCompute(32, 0, C6NUM);
    c[m_index] = dst;
  }
}
424
MatVecMulNoPackFp32(const float * a,const float * b,float * c,const float * bias,int act_type,int64_t depth,int64_t cur_col,int64_t col)425 void MatVecMulNoPackFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int64_t depth,
426 int64_t cur_col, int64_t col) {
427 int inc_flag = 0;
428 int64_t k = 0;
429 for (; k <= depth - C1500NUM; k += C1500NUM) {
430 inc_flag = (k == 0) + (k + C1500NUM == depth ? C2NUM : 0);
431 int64_t oc_index = 0;
432 SIMD_RUN_NO_SCALAR(MatVecMulNoPackCore, oc_index, a, b, c, bias, act_type, C1500NUM, cur_col, col, inc_flag);
433 for (; oc_index < cur_col; ++oc_index) {
434 float dst = (inc_flag & 1) == 0 ? c[oc_index] : (bias == NULL ? 0 : bias[oc_index]);
435 for (int64_t k_index = 0; k_index < k; ++k_index) {
436 dst += a[k_index] * b[oc_index + k_index * col];
437 }
438 if ((inc_flag & 0x2) != 0) {
439 ActCompute(32, 0, C6NUM);
440 }
441 c[oc_index] = dst;
442 }
443 a += C1500NUM;
444 b += C1500NUM * col;
445 }
446 if (k == depth) {
447 return;
448 }
449 inc_flag = (k == 0) + C2NUM;
450 int64_t oc_index = 0;
451 SIMD_RUN_NO_SCALAR(MatVecMulNoPackCore, oc_index, a, b, c, bias, act_type, depth - k, cur_col, col, inc_flag);
452 for (; oc_index < cur_col; ++oc_index) {
453 float dst = (inc_flag & 1) == 0 ? c[oc_index] : (bias == NULL ? 0 : bias[oc_index]);
454 for (int64_t k_index = 0; k_index < depth; ++k_index) {
455 dst += a[k_index] * b[oc_index + k_index * col];
456 }
457 ActCompute(32, 0, C6NUM);
458 c[oc_index] = dst;
459 }
460 }
461
462 #ifdef ENABLE_ARM64
// act_type must be 0, 1, or 3. 0: no_act, 1: relu, 3: relu6.
MatMul4x1Kernel(const float * input,const float * weight,float * output,const float * bias,size_t deep,size_t act_type)464 void MatMul4x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep,
465 size_t act_type) {
466 // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
467 // 9: WriteBack
468 asm volatile(
469 "mov x8, %[input]\n"
470 "mov x9, %[weight]\n"
471 "mov x10, %[deep]\n"
472 "add x5, %[input], %[deep], LSL #2\n"
473 "add x6, %[input], %[deep], LSL #3\n"
474 "add x7, x5, %[deep], LSL #3\n"
475 "dup v0.2d, xzr\n"
476 "dup v1.2d, xzr\n"
477 "dup v2.2d, xzr\n"
478 "dup v3.2d, xzr\n"
479 "subs x10, x10, #16\n"
480 "blt 2f\n"
481 "1:\n" // LoopD16
482 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
483 "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n"
484 "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x6], #64\n"
485 "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x7], #64\n"
486 "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
487 "fmla v0.4s, v4.4s, v28.4s\n"
488 "fmla v1.4s, v16.4s, v28.4s\n"
489 "fmla v2.4s, v20.4s, v28.4s\n"
490 "fmla v3.4s, v24.4s, v28.4s\n"
491 "fmla v0.4s, v5.4s, v29.4s\n"
492 "fmla v1.4s, v17.4s, v29.4s\n"
493 "fmla v2.4s, v21.4s, v29.4s\n"
494 "fmla v3.4s, v25.4s, v29.4s\n"
495 "fmla v0.4s, v6.4s, v30.4s\n"
496 "fmla v1.4s, v18.4s, v30.4s\n"
497 "fmla v2.4s, v22.4s, v30.4s\n"
498 "fmla v3.4s, v26.4s, v30.4s\n"
499 "fmla v0.4s, v7.4s, v31.4s\n"
500 "fmla v1.4s, v19.4s, v31.4s\n"
501 "fmla v2.4s, v23.4s, v31.4s\n"
502 "fmla v3.4s, v27.4s, v31.4s\n"
503 "subs x10, x10, #16\n"
504 "bge 1b\n"
505 "2:\n" // LoopD12
506 "adds x10, x10, #16\n"
507 "cbz x10, 6f\n"
508 "cmp x10, #12\n"
509 "blt 3f\n"
510 "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n"
511 "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n"
512 "ld1 {v20.4s, v21.4s, v22.4s}, [x6], #48\n"
513 "ld1 {v24.4s, v25.4s, v26.4s}, [x7], #48\n"
514 "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n"
515 "fmla v0.4s, v4.4s, v28.4s\n"
516 "fmla v1.4s, v16.4s, v28.4s\n"
517 "fmla v2.4s, v20.4s, v28.4s\n"
518 "fmla v3.4s, v24.4s, v28.4s\n"
519 "fmla v0.4s, v5.4s, v29.4s\n"
520 "fmla v1.4s, v17.4s, v29.4s\n"
521 "fmla v2.4s, v21.4s, v29.4s\n"
522 "fmla v3.4s, v25.4s, v29.4s\n"
523 "fmla v0.4s, v6.4s, v30.4s\n"
524 "fmla v1.4s, v18.4s, v30.4s\n"
525 "fmla v2.4s, v22.4s, v30.4s\n"
526 "fmla v3.4s, v26.4s, v30.4s\n"
527 "sub x10, x10, #12\n"
528 "b 7f\n"
529 "3:\n" // LoopD8
530 "cmp x10, #8\n"
531 "blt 4f\n"
532 "ld1 {v4.4s, v5.4s}, [x8], #32\n"
533 "ld1 {v16.4s, v17.4s}, [x5], #32\n"
534 "ld1 {v20.4s, v21.4s}, [x6], #32\n"
535 "ld1 {v24.4s, v25.4s}, [x7], #32\n"
536 "ld1 {v28.4s, v29.4s}, [x9], #32\n"
537 "fmla v0.4s, v4.4s, v28.4s\n"
538 "fmla v1.4s, v16.4s, v28.4s\n"
539 "fmla v2.4s, v20.4s, v28.4s\n"
540 "fmla v3.4s, v24.4s, v28.4s\n"
541 "fmla v0.4s, v5.4s, v29.4s\n"
542 "fmla v1.4s, v17.4s, v29.4s\n"
543 "fmla v2.4s, v21.4s, v29.4s\n"
544 "fmla v3.4s, v25.4s, v29.4s\n"
545 "sub x10, x10, #8\n"
546 "b 7f\n"
547 "4:\n" // LoopD4
548 "cmp x10, #4\n"
549 "blt 7f\n"
550 "ld1 {v4.4s}, [x8], #16\n"
551 "ld1 {v16.4s}, [x5], #16\n"
552 "ld1 {v20.4s}, [x6], #16\n"
553 "ld1 {v24.4s}, [x7], #16\n"
554 "ld1 {v28.4s}, [x9], #16\n"
555 "fmla v0.4s, v4.4s, v28.4s\n"
556 "fmla v1.4s, v16.4s, v28.4s\n"
557 "fmla v2.4s, v20.4s, v28.4s\n"
558 "fmla v3.4s, v24.4s, v28.4s\n"
559 "sub x10, x10, #4\n"
560 "7:\n"
561 "cbz x10, 6f\n"
562 "dup v4.2d, xzr\n"
563 "dup v16.2d, xzr\n"
564 "dup v20.2d, xzr\n"
565 "dup v24.2d, xzr\n"
566 "dup v28.2d, xzr\n"
567 "subs x10, x10, #2\n"
568 "blt 5f\n"
569 "ld1 {v4.d}[0], [x8], #8\n" // LoopD2
570 "ld1 {v16.d}[0], [x5], #8\n"
571 "ld1 {v20.d}[0], [x6], #8\n"
572 "ld1 {v24.d}[0], [x7], #8\n"
573 "ld1 {v28.d}[0], [x9], #8\n"
574 "cbz x10, 8f\n"
575 "5:\n" // LoopD1
576 "ld1 {v4.s}[2], [x8]\n"
577 "ld1 {v16.s}[2], [x5]\n"
578 "ld1 {v20.s}[2], [x6]\n"
579 "ld1 {v24.s}[2], [x7]\n"
580 "ld1 {v28.s}[2], [x9]\n"
581 "8:\n"
582 "fmla v0.4s, v4.4s, v28.4s\n"
583 "fmla v1.4s, v16.4s, v28.4s\n"
584 "fmla v2.4s, v20.4s, v28.4s\n"
585 "fmla v3.4s, v24.4s, v28.4s\n"
586 "6:\n"
587 "faddp v4.4s, v0.4s, v1.4s\n"
588 "faddp v5.4s, v2.4s, v3.4s\n"
589 "faddp v0.4s, v4.4s, v5.4s\n"
590 "cbz %[bias], 9f\n"
591 "ld1r {v1.4s}, [%[bias]]\n"
592 "fadd v0.4s, v0.4s, v1.4s\n"
593 "9:\n"
594 "cbz %[act], 10f\n"
595 "dup v1.2d, xzr\n"
596 "fmax v0.4s, v0.4s, v1.4s\n"
597 "cmp %[act], #3\n"
598 "bne 10f\n"
599 "movi v1.4s, #6\n"
600 "scvtf v1.4s, v1.4s\n"
601 "fmin v0.4s, v0.4s, v1.4s\n"
602 "10:\n"
603 "st1 {v0.4s}, [%[output]]\n"
604
605 :
606 : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep),
607 [ act ] "r"(act_type)
608 : "cc", "x5", "x6", "x7", "x8", "x9", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18",
609 "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31");
610 }
611
// ARM64 kernel computing two rows of c = act(a * b + bias) against a single
// unpacked weight column: the rows are `input` and `input + deep`; bias (if
// non-NULL) supplies one scalar added to both results.  act_type: 0 none,
// 1 relu, 3 relu6.
void MatMul2x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep,
                     size_t act_type) {
  // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
  // 9: WriteBack
  asm volatile(
    "mov x8, %[input]\n"
    "mov x9, %[weight]\n"
    "mov x10, %[deep]\n"
    "add x5, %[input], %[deep], LSL #2\n"
    "dup v0.2d, xzr\n"
    "dup v1.2d, xzr\n"
    "subs x10, x10, #16\n"
    "blt 2f\n"
    "1:\n"  // LoopD16
    "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
    "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64\n"
    "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "fmla v0.4s, v6.4s, v30.4s\n"
    "fmla v1.4s, v18.4s, v30.4s\n"
    "fmla v0.4s, v7.4s, v31.4s\n"
    "fmla v1.4s, v19.4s, v31.4s\n"
    "subs x10, x10, #16\n"
    "bge 1b\n"
    "2:\n"  // LoopD12
    "adds x10, x10, #16\n"
    "cbz x10, 6f\n"
    "cmp x10, #12\n"
    "blt 3f\n"
    "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n"
    "ld1 {v16.4s, v17.4s, v18.4s}, [x5], #48\n"
    "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "fmla v0.4s, v6.4s, v30.4s\n"
    "fmla v1.4s, v18.4s, v30.4s\n"
    "sub x10, x10, #12\n"
    "b 7f\n"
    "3:\n"  // LoopD8
    "cmp x10, #8\n"
    "blt 4f\n"
    "ld1 {v4.4s, v5.4s}, [x8], #32\n"
    "ld1 {v16.4s, v17.4s}, [x5], #32\n"
    "ld1 {v28.4s, v29.4s}, [x9], #32\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "fmla v0.4s, v5.4s, v29.4s\n"
    "fmla v1.4s, v17.4s, v29.4s\n"
    "sub x10, x10, #8\n"
    "b 7f\n"
    "4:\n"  // LoopD4
    "cmp x10, #4\n"
    "blt 7f\n"
    "ld1 {v4.4s}, [x8], #16\n"
    "ld1 {v16.4s}, [x5], #16\n"
    "ld1 {v28.4s}, [x9], #16\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "sub x10, x10, #4\n"
    "7:\n"
    "cbz x10, 6f\n"
    "dup v4.2d, xzr\n"
    "dup v16.2d, xzr\n"
    "subs x10, x10, #2\n"
    "blt 5f\n"
    "ld1 {v4.d}[0], [x8], #8\n"  // LoopD2
    "ld1 {v16.d}[0], [x5], #8\n"
    "ld1 {v28.d}[0], [x9], #8\n"
    "cbz x10, 8f\n"
    "5:\n"  // LoopD1
    "ld1 {v4.s}[2], [x8]\n"
    "ld1 {v16.s}[2], [x5]\n"
    "ld1 {v28.s}[2], [x9]\n"
    "8:\n"
    "fmla v0.4s, v4.4s, v28.4s\n"
    "fmla v1.4s, v16.4s, v28.4s\n"
    "6:\n"
    "faddp v4.4s, v0.4s, v1.4s\n"
    "faddp v0.4s, v4.4s, v4.4s\n"
    "cbz %[bias], 9f\n"
    "ld1r {v1.4s}, [%[bias]]\n"
    "fadd v0.2s, v0.2s, v1.2s\n"
    "9:\n"
    "cbz %[act], 10f\n"
    "fmov d1, xzr\n"
    "fmax v0.2s, v0.2s, v1.2s\n"
    "cmp %[act], #3\n"
    "bne 10f\n"
    "movi v1.2s, #6\n"
    "scvtf v1.2s, v1.2s\n"
    "fmin v0.2s, v0.2s, v1.2s\n"
    "10:\n"
    "st1 {v0.2s}, [%[output]]\n"

    :
    : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep),
      [ act ] "r"(act_type)
    : "cc", "x5", "x8", "x9", "x10", "v0", "v1", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29",
      "v30", "v31", "memory");
}
717
MatMul1x1Kernel(const float * input,const float * weight,float * output,const float * bias,size_t deep,size_t act_type)718 void MatMul1x1Kernel(const float *input, const float *weight, float *output, const float *bias, size_t deep,
719 size_t act_type) {
720 // 1: LoopD16, 2: LoopD12, 3: LoopD8, 4: LoopD4, 5: LoopD1, 6: LoopDEnd, 7: LoopDTail, 8: LoopDTailCompute
721 // 9: WriteBack
722 asm volatile(
723 "mov x8, %[input]\n"
724 "mov x9, %[weight]\n"
725 "mov x10, %[deep]\n"
726 "dup v0.2d, xzr\n"
727 "subs x10, x10, #16\n"
728 "blt 2f\n"
729 "1:\n" // LoopD16
730 "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64\n"
731 "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x9], #64\n"
732 "fmla v0.4s, v4.4s, v28.4s\n"
733 "fmla v0.4s, v5.4s, v29.4s\n"
734 "fmla v0.4s, v6.4s, v30.4s\n"
735 "fmla v0.4s, v7.4s, v31.4s\n"
736 "subs x10, x10, #16\n"
737 "bge 1b\n"
738 "2:\n" // LoopD12
739 "adds x10, x10, #16\n"
740 "cbz x10, 6f\n"
741 "cmp x10, #12\n"
742 "blt 3f\n"
743 "ld1 {v4.4s, v5.4s, v6.4s}, [x8], #48\n"
744 "ld1 {v28.4s, v29.4s, v30.4s}, [x9], #48\n"
745 "fmla v0.4s, v4.4s, v28.4s\n"
746 "fmla v0.4s, v5.4s, v29.4s\n"
747 "fmla v0.4s, v6.4s, v30.4s\n"
748 "sub x10, x10, #12\n"
749 "b 7f\n"
750 "3:\n" // LoopD8
751 "cmp x10, #8\n"
752 "blt 4f\n"
753 "ld1 {v4.4s, v5.4s}, [x8], #32\n"
754 "ld1 {v28.4s, v29.4s}, [x9], #32\n"
755 "fmla v0.4s, v4.4s, v28.4s\n"
756 "fmla v0.4s, v5.4s, v29.4s\n"
757 "sub x10, x10, #8\n"
758 "b 7f\n"
759 "4:\n" // LoopD4
760 "cmp x10, #4\n"
761 "blt 7f\n"
762 "ld1 {v4.4s}, [x8], #16\n"
763 "ld1 {v28.4s}, [x9], #16\n"
764 "fmla v0.4s, v4.4s, v28.4s\n"
765 "sub x10, x10, #4\n"
766 "7:\n"
767 "cbz x10, 6f\n"
768 "dup v4.2d, xzr\n"
769 "subs x10, x10, #2\n"
770 "blt 5f\n"
771 "ld1 {v4.d}[0], [x8], #8\n" // LoopD2
772 "ld1 {v28.d}[0], [x9], #8\n"
773 "cbz x10, 8f\n"
774 "5:\n" // LoopD1
775 "ld1 {v4.s}[3], [x8]\n"
776 "ld1 {v28.s}[3], [x9]\n"
777 "8:\n"
778 "fmla v0.4s, v4.4s, v28.4s\n"
779 "6:\n"
780 "faddp v4.4s, v0.4s, v0.4s\n"
781 "faddp v0.4s, v4.4s, v4.4s\n"
782 "cbz %[bias], 9f\n"
783 "ld1 {v1.s}[0], [%[bias]]\n"
784 "fadd s0, s0, s1\n"
785 "9:\n"
786 "cbz %[act], 10f\n"
787 "fmov s1, wzr\n"
788 "fmax s0, s0, s1\n"
789 "cmp %[act], #3\n"
790 "bne 10f\n"
791 "mov x10, #6\n"
792 "scvtf s1, x10\n"
793 "fmin s0, s0, s1\n"
794 "10:\n"
795 "str s0, [%[output]]\n"
796
797 :
798 : [ input ] "r"(input), [ weight ] "r"(weight), [ output ] "r"(output), [ bias ] "r"(bias), [ deep ] "r"(deep),
799 [ act ] "r"(act_type)
800 : "cc", "x8", "x9", "x10", "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "v28", "v29", "v30", "v31");
801 }
802
GemmIsNotPackByRow(const float * a,const float * b,float * c,const float * bias,int start_row,int end_row,int deep,int act_type)803 void GemmIsNotPackByRow(const float *a, const float *b, float *c, const float *bias, int start_row, int end_row,
804 int deep, int act_type) {
805 const float *input = a + start_row * deep;
806 float *output = c + start_row;
807 const int step = C4NUM * deep;
808 for (; start_row <= end_row - C4NUM; start_row += C4NUM) {
809 MatMul4x1Kernel(input, b, output, bias, deep, act_type);
810 input += step;
811 output += C4NUM;
812 }
813 for (; start_row <= end_row - C2NUM; start_row += C2NUM) {
814 MatMul2x1Kernel(input, b, output, bias, deep, act_type);
815 input += C2NUM * deep;
816 output += C2NUM;
817 }
818 if (start_row == end_row - 1) {
819 MatMul1x1Kernel(input, b, output, bias, deep, act_type);
820 }
821 }
822 #endif
823