1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifdef ENABLE_ARM32
18 #include "nnacl/kernel/matmul_arm32.h"
19 #include "nnacl/kernel/matmul_base.h"
20 #include "nnacl/fp32/pack_fp32.h"
21 #include "nnacl/fp32/matmul_fp32.h"
22
/* Configure ARM32-specific matmul globals: select the 12x4 tile shape and
 * the matching fp32 packing routines for matrices A and B. */
void MatmulARM32InitGlobalVariable(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;

  /* ARM32 GEMM kernels consume 12-row by 4-column tiles. */
  compute->row_tile_ = C12NUM;
  compute->col_tile_ = C4NUM;
  compute->col_min_unit_ = C4NUM;

  /* Both operands must be repacked into tiled layout; pick the pack
   * function according to each operand's transpose flag. */
  matmul->matrix_a_.need_pack_ = true;
  matmul->matrix_b_.need_pack_ = true;
  matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row12MajorParallel : RowMajor2Col12MajorParallel;
  matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col4MajorParallel : RowMajor2Row4MajorParallel;
}
33
/* Batch-parallel matmul worker for ARM32: thread task_id processes the
 * contiguous batch range [task_id * batch_stride_, +batch_stride_), clamped
 * to matmul->batch_.
 * Returns NNACL_OK on success, NNACL_ERR for an out-of-range task_id. */
int MatmulARM32ParallelRunByBatch(MatmulStruct *matmul, int task_id) {
  /* Validate the worker id, consistent with MatmulARM32ParallelRunByOC. */
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;
  ActType act = param->act_type_;

  int start_batch = task_id * compute->batch_stride_;
  int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_);

  /* Kernel selector: 0 -> general GEMM; when row_ == 1, 1 -> GEMV on packed
   * B, 2 -> GEMV on unpacked B (non-const B narrower than C128NUM columns). */
  int func_flag = 0;
  if (compute->row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }

  /* Each worker covers the full column range, so the whole bias vector is
   * used; a NULL pack_ptr_ simply means "no bias". (The original
   * `ptr == NULL ? NULL : ptr` ternary was a no-op.) */
  const float *bias = matmul->matrix_c_.pack_ptr_;

  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_;
    float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_;

    if (func_flag == 0) {
      MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute->col_step_, compute->col_, OutType_Nhwc);
    } else if (func_flag == C1NUM) {
      MatVecMulFp32Block4(a, b, c, bias, act, compute->deep_, compute->col_step_);
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_);
    }
  }
  return NNACL_OK;
}
62
/* Output-channel-parallel matmul worker for ARM32: worker task_id computes
 * the column slice [split_points_[task_id], split_points_[task_id + 1]) of
 * every batch (the last worker runs up to col_step_). */
int MatmulARM32ParallelRunByOC(MatmulStruct *matmul, int task_id) {
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;
  ActType act = param->act_type_;

  /* Column range owned by this worker. */
  int oc_begin = matmul->split_points_[task_id];
  int oc_end = (task_id + 1 < matmul->base_.thread_nr_) ? matmul->split_points_[task_id + 1] : compute->col_step_;
  int oc_count = oc_end - oc_begin;
  if (oc_count <= 0) {
    /* Nothing was assigned to this worker. */
    return NNACL_OK;
  }

  /* Kernel selector: 0 -> general GEMM; when row_ == 1, 1 -> GEMV on packed
   * B, 2 -> GEMV on unpacked B. */
  int func_flag = 0;
  if (compute->row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }
  /* Unpacked B (func_flag == C2NUM) is addressed per column, so the offset
   * is not scaled by deep_; packed B is. */
  int b_stride = (func_flag == C2NUM) ? oc_begin : oc_begin * compute->deep_;

  for (int batch = 0; batch < matmul->batch_; ++batch) {
    float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[batch] * compute->row_align_ * compute->deep_;
    float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[batch] * compute->deep_ * compute->col_align_ + b_stride;
    float *c = matmul->output_data_ + batch * compute->row_ * compute->col_step_ + oc_begin;
    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + oc_begin;

    if (func_flag == 0) {
      MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, oc_count, compute->col_, OutType_Nhwc);
    } else if (func_flag == C1NUM) {
      MatVecMulFp32Block4(a, b, c, bias, act, compute->deep_, oc_count);
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, oc_count, compute->col_step_);
    }
  }
  return NNACL_OK;
}
100
CreateMatmulARM32()101 KernelBase *CreateMatmulARM32() {
102 MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase();
103 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul);
104 matmul->matmul_type_ = kNotImplemented;
105 matmul->init_global_varibale_ = MatmulARM32InitGlobalVariable;
106 matmul->parallel_run_by_batch_ = MatmulARM32ParallelRunByBatch;
107 matmul->parallel_run_by_oc_ = MatmulARM32ParallelRunByOC;
108 return (KernelBase *)matmul;
109 }
110 #endif
111