/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifdef ENABLE_ARM64
#include "nnacl/kernel/matmul_arm64.h"
#include "nnacl/kernel/matmul_base.h"
#include "nnacl/fp32/matmul_fp32.h"
#include "nnacl/fp32/pack_fp32.h"
#include "nnacl/fp32/pack_fp32_opt.h"
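
/* Shared context for the parallel matrix-A packing job. points_[i] holds the
 * first pack unit assigned to task i; a task's range ends at the next task's
 * start (or at unit_num_ for the last task). */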
typedef struct MatrixAPack {
  int64_t points_[MAX_THREAD_NUM];  // per-task start offsets, in pack units
  int64_t unit_num_;                // total number of pack units
  int thread_;                      // number of tasks actually used
  int deep_;
  int row_;
  int col_;
  MatrixInfo *matrix_a_;  // destination matrix info (pack_ptr_ is written)
  float *src_ptr_;        // source data in row-major layout
  bool a_transpose_;
} MatrixAPack;
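
/* Thread-pool callback that packs one contiguous unit range of matrix A.
 * The trailing float pair exists to match the ParallelLaunch callback
 * signature and is unused here. */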
int MatmulARM64PackMatrixAImplOptPack(void *cdata, int task_id, float l, float r) {
  MatrixAPack *pack = (MatrixAPack *)cdata;
  int64_t start = pack->points_[task_id];
  int64_t end = pack->unit_num_;
  if (task_id < pack->thread_ - 1) {
    end = pack->points_[task_id + 1];
  }

  if (pack->a_transpose_) {
    RowMajor2Row12MajorOpt(pack->src_ptr_, pack->matrix_a_->pack_ptr_, pack->deep_, pack->row_, start, end);
  } else {
    RowMajor2Col12MajorOpt(pack->src_ptr_, pack->matrix_a_->pack_ptr_, pack->row_, pack->deep_, start, end);
  }
  return NNACL_OK;
}
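
/* Packs matrix A for the 12x8 tiled kernel, splitting work across threads.
 * unit_num_ counts 12-row tiles times depth; the thread count is capped so
 * each task gets on the order of kPackAMinUnitNum units, and small packs
 * run single-threaded on the calling thread. */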
int MatmulARM64PackMatrixAImplOpt(MatmulStruct *matmul) {
  int64_t kPackAMinUnitNum = 1 << 13;
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  float *src_ptr = matmul->matrix_a_.origin_ptr_ != NULL ? matmul->matrix_a_.origin_ptr_
                                                         : (float *)(matmul->base_.in_[FIRST_INPUT]->data_);
  NNACL_CHECK_TRUE_RET(src_ptr != NULL, NNACL_ERR);
  NNACL_CHECK_TRUE_RET(matmul->matrix_a_.pack_ptr_ != NULL, NNACL_ERR);

  MatrixAPack pack;
  pack.src_ptr_ = src_ptr;
  pack.matrix_a_ = &matmul->matrix_a_;
  pack.deep_ = matmul->compute_.deep_;
  pack.col_ = matmul->compute_.col_;
  pack.row_ = matmul->compute_.row_;
  pack.a_transpose_ = param->a_transpose_;
  pack.unit_num_ = matmul->a_batch_ * UP_DIV(matmul->compute_.row_, C12NUM) * matmul->compute_.deep_;
  pack.thread_ = MSMIN(matmul->base_.thread_nr_, UP_DIV(pack.unit_num_, kPackAMinUnitNum));
  if (pack.thread_ < 1) {
    pack.thread_ = 1;
  }

  // Split unit_num_ into near-equal chunks; the first remain_size chunks take one extra unit.
  int64_t block_size = pack.unit_num_ / pack.thread_;
  int64_t remain_size = pack.unit_num_ - block_size * pack.thread_;
  int64_t start = 0;
  size_t count = 0;
  while (start < pack.unit_num_) {
    pack.points_[count++] = start;
    start += block_size;
    if (remain_size > 0) {
      ++start;
      --remain_size;
    }
  }
  pack.thread_ = count;

  if (pack.thread_ == 1) {
    return MatmulARM64PackMatrixAImplOptPack(&pack, 0, 0, 1);
  }
  return matmul->base_.env_->ParallelLaunch(matmul->base_.env_->thread_pool_, MatmulARM64PackMatrixAImplOptPack, &pack,
                                            pack.thread_);
}
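
/* Row cutting is considered only for a single-batch B. The heuristic enables
 * it when batches already cover all threads or the output has one column,
 * with rows then split in blocks of at least C4NUM. */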
bool MatmulARM64CheckThreadCuttingByRow(MatmulStruct *matmul) {
  if (matmul->b_batch_ != C1NUM) {
    return false;
  }
  if (matmul->batch_ >= matmul->base_.thread_nr_ || matmul->compute_.col_ == 1) {
    matmul->compute_.row_min_unit_ = C4NUM;
    return true;
  }
  return false;
}
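
/* ARM64 tiling setup: A packs into 12-row tiles and B into 8-column tiles,
 * with the pack variant chosen by the transpose flags; a pre-packed weight
 * skips B packing entirely. */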
void MatmulARM64InitGlobalVariable(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  matmul->pack_opt_ = true;
  matmul->compute_.row_tile_ = C12NUM;
  matmul->compute_.col_tile_ = C8NUM;
  matmul->compute_.col_min_unit_ = C8NUM;
  matmul->matrix_a_.need_pack_ = true;
  matmul->matrix_b_.need_pack_ = !matmul->weight_is_packed_;
  matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row12MajorParallel : RowMajor2Col12MajorParallel;
  matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col8MajorParallel : RowMajor2Row8MajorParallel;
}
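
/* By-batch runner: each task computes batch_stride_ consecutive batches.
 * For single-row inputs a mat-vec kernel replaces the general GEMM; a
 * non-constant B with at most C128NUM columns is consumed unpacked. */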
int MatmulARM64ParallelRunByBatch(MatmulStruct *matmul, int task_id) {
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;
  ActType act = param->act_type_;

  int start_batch = task_id * compute->batch_stride_;
  int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_);
  // func_flag: 0 = general GEMM, 1 = packed mat-vec, 2 = unpacked mat-vec (row_ == 1 cases).
  int func_flag = 0;
  if (compute->row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }

  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_;
    float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_;
    float *bias = matmul->matrix_c_.pack_ptr_;  // NULL when there is no bias

    if (func_flag == 0) {
      MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute->col_step_, compute->col_, OutType_Nhwc);
    } else if (func_flag == C1NUM) {
      MatVecMulPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_);
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_);
    }
  }
  return NNACL_OK;
}
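
/* By-row runner: each task computes the output-row range
 * [split_points_[task_id], next split point); empty ranges return early. */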
int MatmulARM64ParallelRunByRow(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);

  int start_row = matmul->split_points_[task_id];
  int end_row = matmul->compute_.row_num_;
  if (task_id < (matmul->base_.thread_nr_ - 1)) {
    end_row = matmul->split_points_[task_id + 1];
  }
  int row_num = end_row - start_row;
  if (row_num <= 0) {
    return NNACL_OK;
  }
  GemmIsNotPackByRow(matmul->matrix_a_.pack_ptr_, matmul->matrix_b_.pack_ptr_, matmul->output_data_,
                     matmul->matrix_c_.pack_ptr_, start_row, end_row, matmul->compute_.deep_, param->act_type_);
  return NNACL_OK;
}
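
/* By-OC runner: each task computes one slice of output columns across all
 * batches; kernel selection mirrors the by-batch runner. */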
int MatmulARM64ParallelRunByOC(MatmulStruct *matmul, int task_id) {
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;
  ActType act = param->act_type_;

  int start_oc = matmul->split_points_[task_id];
  int end_oc = compute->col_step_;
  if (task_id < (matmul->base_.thread_nr_ - 1)) {
    end_oc = matmul->split_points_[task_id + 1];
  }
  int compute_oc = end_oc - start_oc;
  if (compute_oc <= 0) {
    return NNACL_OK;
  }
  // func_flag: 0 = general GEMM, 1 = packed mat-vec, 2 = unpacked mat-vec (row_ == 1 cases).
  int func_flag = 0;
  if (compute->row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }
  // B offset for this oc slice: start_oc for unpacked row-major B, start_oc * deep_ for packed B.
  int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_;

  for (int i = 0; i < matmul->batch_; ++i) {
    float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_;
    float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride;
    float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc;
    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;

    if (func_flag == 0) {
      MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute_oc, compute->col_, OutType_Nhwc);
    } else if (func_flag == C1NUM) {
      MatVecMulPackFp32(a, b, c, bias, act, compute->deep_, compute_oc);
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_);
    }
  }
  return NNACL_OK;
}
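
/* Creates the base matmul kernel and overrides its hooks with the ARM64
 * implementations above. */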
KernelBase *CreateMatmulARM64() {
  MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase();
  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul);
  matmul->matmul_type_ = kMatmulFp32Arm64Cpu;
  matmul->check_thread_cutting_by_row_ = MatmulARM64CheckThreadCuttingByRow;
  matmul->init_global_varibale_ = MatmulARM64InitGlobalVariable;
  matmul->parallel_run_by_oc_ = MatmulARM64ParallelRunByOC;
  matmul->parallel_run_by_row_ = MatmulARM64ParallelRunByRow;
  matmul->parallel_run_by_batch_ = MatmulARM64ParallelRunByBatch;
  matmul->pack_matrix_a_impl_opt_ = MatmulARM64PackMatrixAImplOpt;
  return (KernelBase *)matmul;
}
#endif