/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/kernel/matmul_base.h"
#include "nnacl/fp32/pack_fp32.h"
#include "nnacl/fp32/matmul_fp32.h"
#include "nnacl/tensor_c_utils.h"
#include "nnacl/op_base.h"

#define kNumDeepThreshold 512

int MatmulFp32Run(void *cdata, int task_id, float l, float r) {
  NNACL_CHECK_NULL_RETURN_ERR(cdata);
  MatmulStruct *matmul = (MatmulStruct *)cdata;
  return matmul->parallel_run_(matmul, task_id);
}

void MatmulBaseFreeBatchOffset(MatmulStruct *matmul) {
  if (matmul->a_offset_ != NULL) {
    free(matmul->a_offset_);
    matmul->a_offset_ = NULL;
  }
  if (matmul->b_offset_ != NULL) {
    free(matmul->b_offset_);
    matmul->b_offset_ = NULL;
  }
}

int MatmulBaseMallocBatchOffset(MatmulStruct *matmul) {
  matmul->a_offset_ = malloc(matmul->batch_ * sizeof(int));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->a_offset_);
  memset(matmul->a_offset_, 0, matmul->batch_ * sizeof(int));

  matmul->b_offset_ = malloc(matmul->batch_ * sizeof(int));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->b_offset_);
  memset(matmul->b_offset_, 0, matmul->batch_ * sizeof(int));
  return NNACL_OK;
}

int MatmulBasePackMatrixBParallelRunByBatch(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;

  int start = task_id * compute->pack_b_stride_;
  if (param->b_transpose_) {
    int end = NNACL_MIN(matmul->compute_.col_, start + compute->pack_b_stride_);
    matmul->matrix_b_pack_fun_(matmul->pack_b_src_, matmul->pack_b_dst_, compute->col_, compute->deep_, start, end);
  } else {
    int end = NNACL_MIN(matmul->compute_.deep_, start + compute->pack_b_stride_);
    matmul->matrix_b_pack_fun_(matmul->pack_b_src_, matmul->pack_b_dst_, compute->deep_, compute->col_, start, end);
  }
  return NNACL_OK;
}

int MatmulFp32PackMatrixBRun(void *cdata, int task_id, float l, float r) {
  NNACL_CHECK_NULL_RETURN_ERR(cdata);
  MatmulStruct *matmul = (MatmulStruct *)cdata;
  return MatmulBasePackMatrixBParallelRunByBatch(matmul, task_id);
}

bool MatmulBaseCheckRowOptimalConditions(MatmulStruct *matmul) {
  return matmul->compute_.row_ == 1 &&
         !(matmul->support_mul_batch_cut_by_row_ && (matmul->a_batch_ > 1 && matmul->b_batch_ == 1));
}

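/* Choose pack functions and tiling for the current shapes. The special cases below fall back to
 * tile size 1 (no blocking) when the GEMM degenerates to a vector product, so part of the packing
 * work can be skipped. */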
int MatmulBaseInitParameter(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;

  matmul->init_global_varibale_(matmul);
  if (MatmulBaseCheckRowOptimalConditions(matmul)) {
    compute->row_tile_ = 1;
    matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
    matmul->matrix_a_.need_pack_ = false;
    matmul->pack_opt_ = false;
    if (!matmul->b_const_ && compute->col_ <= C128NUM) {
      compute->col_tile_ = 1;
      matmul->out_need_aligned_ = false;
      matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
      matmul->matrix_b_.need_pack_ = param->b_transpose_;
    }
  }
  if (compute->col_ == 1 && !matmul->a_const_) {
    matmul->out_need_aligned_ = false;
    compute->row_tile_ = 1;
    compute->col_tile_ = 1;
    matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
    matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
    matmul->matrix_a_.need_pack_ = param->a_transpose_ && compute->row_ != 1;
    matmul->matrix_b_.need_pack_ = false;
    matmul->pack_opt_ = false;
  }
  compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_);
  compute->col_align_ = UP_ROUND(compute->col_, compute->col_tile_);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->row_align_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->row_align_, compute->deep_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->col_align_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->col_align_, compute->deep_, NNACL_ERR);
  int a_pack_size = matmul->a_batch_ * compute->row_align_ * compute->deep_;
  int b_pack_size = matmul->b_batch_ * compute->col_align_ * compute->deep_;
  if ((matmul->matrix_a_.has_packed_ && matmul->matrix_a_.pack_size_ != a_pack_size) ||
      (matmul->matrix_b_.has_packed_ && matmul->matrix_b_.pack_size_ != b_pack_size)) {
    return NNACL_ERR;
  }
  matmul->matrix_a_.pack_size_ = a_pack_size;
  matmul->matrix_b_.pack_size_ = b_pack_size;
  compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_);
  matmul->out_need_aligned_ = (matmul->out_need_aligned_ && ((compute->col_ % compute->col_tile_) != 0));
  compute->col_step_ = matmul->out_need_aligned_ ? compute->col_align_ : compute->col_;
  NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(matmul->a_batch_, compute->row_), NNACL_ERR);
  compute->row_num_ = matmul->a_batch_ * compute->row_;
  return NNACL_OK;
}

int MatmulBasePackMatrixAImplOpt(MatmulStruct *matmul) { return NNACL_ERR; }

int MatmulBasePackMatrixAImpl(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  float *src_ptr = (matmul->matrix_a_.origin_ptr_ != NULL) ? (matmul->matrix_a_.origin_ptr_)
                                                           : (float *)(matmul->base_.in_[FIRST_INPUT]->data_);
  NNACL_CHECK_TRUE_RET(src_ptr != NULL, NNACL_ERR);
  NNACL_CHECK_TRUE_RET(matmul->matrix_a_.pack_ptr_ != NULL, NNACL_ERR);
  NNACL_CHECK_TRUE_RET(matmul->matrix_a_pack_fun_ != NULL, NNACL_ERR);
  for (int i = 0; i < matmul->a_batch_; i++) {
    const float *src = src_ptr + i * matmul->compute_.deep_ * matmul->compute_.row_;
    float *dst = matmul->matrix_a_.pack_ptr_ + i * matmul->compute_.deep_ * matmul->compute_.row_align_;
    if (param->a_transpose_) {
      matmul->matrix_a_pack_fun_(src, dst, matmul->compute_.deep_, matmul->compute_.row_, 0, matmul->compute_.deep_);
    } else {
      matmul->matrix_a_pack_fun_(src, dst, matmul->compute_.row_, matmul->compute_.deep_, 0, matmul->compute_.row_);
    }
  }
  return NNACL_OK;
}

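/* Pack matrix B batch by batch; each batch is split across threads via MatmulFp32PackMatrixBRun,
 * striding over columns when B is transposed and over the deep dimension otherwise. */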
int MatmulBasePackMatrixBImpl(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);

  float *src_ptr = matmul->matrix_b_.origin_ptr_ != NULL ? matmul->matrix_b_.origin_ptr_
                                                         : (float *)matmul->base_.in_[SECOND_INPUT]->data_;
  NNACL_CHECK_TRUE_RET(src_ptr != NULL, NNACL_ERR);
  NNACL_CHECK_TRUE_RET(matmul->matrix_b_.pack_ptr_ != NULL, NNACL_ERR);
  NNACL_CHECK_TRUE_RET(matmul->matrix_b_pack_fun_ != NULL, NNACL_ERR);

  for (int i = 0; i < matmul->b_batch_; i++) {
    if (param->b_transpose_) {
      matmul->compute_.pack_b_stride_ = UP_DIV(matmul->compute_.col_, matmul->base_.thread_nr_);
    } else {
      matmul->compute_.pack_b_stride_ = UP_DIV(matmul->compute_.deep_, matmul->base_.thread_nr_);
    }
    matmul->pack_b_src_ = src_ptr + i * matmul->compute_.deep_ * matmul->compute_.col_;
    matmul->pack_b_dst_ = matmul->matrix_b_.pack_ptr_ + i * matmul->compute_.deep_ * matmul->compute_.col_align_;
    int ret = matmul->base_.env_->ParallelLaunch(matmul->base_.env_->thread_pool_, MatmulFp32PackMatrixBRun, matmul,
                                                 matmul->base_.thread_nr_);
    NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
  }
  return NNACL_OK;
}

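/* Decide where the packed copy of matrix A lives: variable inputs reuse the tensor data or a freshly
 * allocated buffer (the workspace in train sessions), while constant inputs may come from the
 * weight-sharing manager, in which case an already-packed buffer is used as-is. */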
int MatmulBasePackMatrixA(MatmulStruct *matmul) {
  if (!matmul->a_const_) {
    if (!matmul->matrix_a_.need_pack_) {
      matmul->matrix_a_.pack_ptr_ = (float *)matmul->base_.in_[FIRST_INPUT]->data_;
      return NNACL_OK;
    }
    if (matmul->base_.train_session_) {
      matmul->matrix_a_.pack_ptr_ = (float *)(matmul->base_.workspace_);
    } else {
      matmul->matrix_a_.pack_ptr_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_,
                                                                        matmul->matrix_a_.pack_size_ * sizeof(float)));
    }
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->matrix_a_.pack_ptr_);
  } else {
    bool is_packed = false;
    void *data = NULL;
    size_t data_size = (size_t)(matmul->matrix_a_.pack_size_) * sizeof(float);
    if (matmul->is_sharing_pack_) {
      TensorC *a_matrix = matmul->base_.in_[FIRST_INPUT];
      data = matmul->get_sharing_weight_(matmul->shaing_manager_, a_matrix->data_, data_size, &is_packed);
    } else {
      data = malloc(data_size);
    }
    matmul->matrix_a_.pack_ptr_ = (float *)data;
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->matrix_a_.pack_ptr_);
    if (is_packed) {
      return NNACL_OK;
    }
  }
  if (matmul->pack_opt_) {
    /* valid in arm64 */
    return matmul->pack_matrix_a_impl_opt_(matmul);
  }
  return matmul->pack_matrix_a_impl_(matmul);
}

int MatmulBasePackMatrixB(MatmulStruct *matmul) {
  if (!matmul->b_const_) {
    if (!matmul->matrix_b_.need_pack_) {
      matmul->matrix_b_.pack_ptr_ = (float *)matmul->base_.in_[SECOND_INPUT]->data_;
      return NNACL_OK;
    }
    if (matmul->base_.train_session_) {
      matmul->matrix_b_.pack_ptr_ = (float *)(matmul->base_.workspace_) + matmul->matrix_a_.pack_size_;
    } else {
      matmul->matrix_b_.pack_ptr_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_,
                                                                        matmul->matrix_b_.pack_size_ * sizeof(float)));
    }
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->matrix_b_.pack_ptr_);
  } else {
    if (!matmul->matrix_b_.need_pack_ && matmul->weight_is_packed_) {
      matmul->matrix_b_.pack_ptr_ = (float *)matmul->base_.in_[SECOND_INPUT]->data_;
      return NNACL_OK;
    }
    bool is_packed = false;
    void *data = NULL;
    size_t data_size = (size_t)(matmul->matrix_b_.pack_size_) * sizeof(float);
    if (matmul->is_sharing_pack_) {
      TensorC *b_matrix = matmul->base_.in_[SECOND_INPUT];
      data = matmul->get_sharing_weight_(matmul->shaing_manager_, b_matrix->data_, data_size, &is_packed);
    } else {
      data = malloc(data_size);
    }
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(data);
    matmul->matrix_b_.pack_ptr_ = (float *)data;
    if (is_packed) {
      return NNACL_OK;
    }
  }
  return matmul->pack_matrix_b_impl_(matmul);
}

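/* Copy a constant input tensor (currently the bias) into an allocator-owned buffer so it remains
 * available for later repacking; origin_need_free_ marks the copy for release once it has been
 * consumed. */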
int MatmulBaseBackupConstMatrix(MatmulStruct *matmul, MatrixInfo *matrix_info, int index) {
  NNACL_CHECK_TRUE_RET(index < (int)matmul->base_.in_size_, NNACL_ERR);
  size_t backup_size = (size_t)GetElementNum(matmul->base_.in_[index]) * sizeof(float);
  NNACL_CHECK_TRUE_RET(backup_size > 0, NNACL_ERR);
  matrix_info->origin_ptr_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_, backup_size));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matrix_info->origin_ptr_);
  void *src_ptr = matmul->base_.in_[index]->data_;
  NNACL_CHECK_NULL_RETURN_ERR(src_ptr);
  (void)memcpy(matrix_info->origin_ptr_, src_ptr, backup_size);
  matrix_info->origin_need_free_ = true;
  return NNACL_OK;
}

int MatmulBaseParallelRunByRow(MatmulStruct *matmul, int task_id) { return NNACL_ERR; }

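/* Each task handles a contiguous slice of batches. func_flag selects the kernel: 0 = tiled GEMM,
 * C1NUM = packed vector-matrix product (row_ == 1), C2NUM = unpacked vector-matrix product
 * (row_ == 1 with a variable B and col_ <= C128NUM). */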
int MatmulBaseParallelRunByBatch(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;

  int start_batch = task_id * compute->batch_stride_;
  int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_);
  int func_flag = 0;
  if (compute->row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }

  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_;
    float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_;

    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_;
    if (func_flag == 0) {
      MatMulOpt(a, b, c, bias, param->act_type_, compute->deep_, compute->row_, compute->col_step_, compute->col_,
                OutType_Nhwc);
    } else if (func_flag == C1NUM) {
      MatVecMulFp32Block8(a, b, c, bias, param->act_type_, compute->deep_, compute->col_step_);
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, param->act_type_, compute->deep_, compute->col_step_, compute->col_step_);
    }
  }
  return NNACL_OK;
}

int MatmulBaseParallelRunIsNotPackByBatch(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  int start_batch = task_id * matmul->compute_.batch_stride_;
  int end_batch = MSMIN(matmul->batch_, start_batch + matmul->compute_.batch_stride_);
  float bias = 0;
  if (matmul->matrix_c_.pack_ptr_ != NULL) {
    bias = matmul->matrix_c_.pack_ptr_[0];
  }
  for (int index = start_batch; index < end_batch; ++index) {
    float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * matmul->compute_.row_ * matmul->compute_.deep_;
    float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * matmul->compute_.deep_ * matmul->compute_.col_;
    float *c = matmul->output_data_ + index * matmul->compute_.row_ * matmul->compute_.col_;
    matmul->gemm_not_pack_fun_(a, b, c, &bias, matmul->compute_.row_, matmul->compute_.deep_, param->act_type_);
  }
  return NNACL_OK;
}

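/* Fill split_points_ with the starting row of each thread's slice. Rows are divided as evenly as
 * possible (never below row_min_unit_ per slice), and thread_nr_ is trimmed to the number of
 * slices actually produced. */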
void MatmulBaseGetThreadCuttingInfoByRow(MatmulStruct *matmul) {
  int row_step = NNACL_MAX(matmul->compute_.row_num_ / matmul->base_.thread_nr_, matmul->compute_.row_min_unit_);
  int row_remaining = matmul->compute_.row_num_ - row_step * matmul->base_.thread_nr_;

  int split_point = 0;
  int count = 0;
  while (split_point < matmul->compute_.row_num_) {
    matmul->split_points_[count++] = split_point;
    split_point += row_step;
    if (row_remaining > 0) {
      ++split_point;
      --row_remaining;
    }
  }
  matmul->base_.thread_nr_ = count;
}

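/* Each task handles the output-channel range [split_points_[task_id], split_points_[task_id + 1])
 * (the last task runs up to col_step_) across every batch. */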
int MatmulBaseParallelRunByOC(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatmulComputeParam *compute = &matmul->compute_;
  ActType act = param->act_type_;

  int start_oc = matmul->split_points_[task_id];
  int end_oc = compute->col_step_;
  if (task_id < (matmul->base_.thread_nr_ - 1)) {
    end_oc = matmul->split_points_[task_id + 1];
  }
  int compute_oc = end_oc - start_oc;
  if (compute_oc <= 0) {
    return NNACL_OK;
  }

  int func_flag = 0;
  if (compute->row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }
  int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_;

  for (int i = 0; i < matmul->batch_; ++i) {
    float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_;
    float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride;
    float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc;
    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;

    if (func_flag == 0) {
      MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute_oc, compute->col_, OutType_Nhwc);
    } else if (func_flag == C1NUM) {
      MatVecMulFp32Block8(a, b, c, bias, act, compute->deep_, compute_oc);
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_);
    }
  }
  return NNACL_OK;
}

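/* Pick the parallelization strategy: split by batch when there are enough batches (or col_ == 1,
 * optionally falling back to the unpacked single-column kernels), split by row when the row-cut
 * check passes, and otherwise split the output channels across threads. */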
void MatmulBaseGetThreadCuttingPolicy(MatmulStruct *matmul) {
  if (matmul->compute_.deep_ < kNumDeepThreshold) {
    if (matmul->model_thread_nr_ != -1) {
      matmul->base_.thread_nr_ = matmul->model_thread_nr_;
    }
  }

  if ((matmul->a_batch_ >= matmul->base_.thread_nr_ &&
       (matmul->b_batch_ == matmul->a_batch_ || !matmul->support_mul_batch_cut_by_row_)) ||
      matmul->compute_.col_ == 1) {
    matmul->compute_.batch_stride_ = UP_DIV(matmul->batch_, matmul->base_.thread_nr_);
    matmul->parallel_run_ = matmul->parallel_run_by_batch_;
    if (matmul->compute_.col_ != 1 || matmul->a_const_) {
      return;
    }

    matmul->parallel_run_ = matmul->parallel_run_not_pack_by_batch_;
    if (matmul->compute_.deep_ == 1) {
      matmul->gemm_not_pack_fun_ = GemmIsNotPack;
    } else {
      matmul->gemm_not_pack_fun_ = GemmIsNotPackOptimize;
      if (matmul->check_thread_cutting_by_row_(matmul)) {
        matmul->parallel_run_ = matmul->parallel_run_by_row_;
        matmul->get_thread_cutting_info_by_row_(matmul);
      }
    }
    return;
  } else if ((matmul->a_batch_ >= matmul->base_.thread_nr_ && matmul->b_batch_ == 1) ||
             matmul->check_thread_cutting_by_row_(matmul)) {
    matmul->parallel_run_ = matmul->parallel_run_by_row_;
    matmul->get_thread_cutting_info_by_row_(matmul);
  } else {
    int total_col_unit = UP_DIV(matmul->compute_.col_align_, matmul->compute_.col_min_unit_);
    matmul->base_.thread_nr_ = MSMIN(matmul->base_.thread_nr_, total_col_unit);
    int block_col_unit = UP_DIV(total_col_unit, matmul->base_.thread_nr_);

    int count = 0;
    int split_point = 0;
    while (split_point < total_col_unit) {
      matmul->split_points_[count++] = (split_point * matmul->compute_.col_min_unit_);
      split_point += block_col_unit;
    }
    matmul->base_.thread_nr_ = count;
    matmul->parallel_run_ = matmul->parallel_run_by_oc_;
  }
  return;
}

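/* Pack the bias into a col_align_-sized buffer: a scalar bias is broadcast to every column, a
 * per-channel bias is copied and zero-padded up to the aligned column count. */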
int MatmulBasePackBiasMatrix(MatmulStruct *matmul) {
  if (matmul->base_.in_size_ != FOURTH_INPUT) {
    return NNACL_OK;
  }
  if (matmul->matrix_c_.has_packed_) {
    NNACL_CHECK_FALSE(matmul->matrix_c_.pack_size_ < matmul->compute_.col_align_, NNACL_ERR);
    return NNACL_OK;
  }
  TensorC *bias_tensor = matmul->base_.in_[THIRD_INPUT];
  float *bias_src = matmul->matrix_c_.origin_ptr_ != NULL ? matmul->matrix_c_.origin_ptr_ : (float *)bias_tensor->data_;
  NNACL_CHECK_NULL_RETURN_ERR(bias_src);

  int bias_num = GetElementNum(bias_tensor);
  NNACL_CHECK_TRUE_RET(bias_num > 0 && matmul->compute_.col_align_ >= bias_num, NNACL_ERR);

  matmul->matrix_c_.pack_size_ = matmul->compute_.col_align_;
  if (matmul->matrix_c_.pack_ptr_ == NULL) {
    matmul->matrix_c_.pack_ptr_ = (float *)(malloc(matmul->matrix_c_.pack_size_ * sizeof(float)));
  }
  NNACL_CHECK_NULL_RETURN_ERR(matmul->matrix_c_.pack_ptr_);

  if (bias_num == 1) {
    for (int i = 0; i < matmul->matrix_c_.pack_size_; ++i) {
      matmul->matrix_c_.pack_ptr_[i] = bias_src[0];
    }
  } else {
    (void)memcpy(matmul->matrix_c_.pack_ptr_, bias_src, bias_num * sizeof(float));
    (void)memset(matmul->matrix_c_.pack_ptr_ + bias_num, 0, (matmul->matrix_c_.pack_size_ - bias_num) * sizeof(float));
  }
  if (matmul->matrix_c_.origin_need_free_) {
    matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->matrix_c_.origin_ptr_);
    matmul->matrix_c_.origin_ptr_ = NULL;
    matmul->matrix_c_.origin_need_free_ = false;
  }
  return NNACL_OK;
}

int MatmulBaseInitTmpOutBuffer(MatmulStruct *matmul) {
  if (matmul->out_need_aligned_) {
    if (matmul->output_data_ != NULL) {
      matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->output_data_);
    }
    // avx need to malloc dst aligned to C8NUM
    // avx512 need to malloc dst aligned to C16NUM
    int out_channel = matmul->compute_.col_;
    NNACL_CHECK_ZERO_RETURN_ERR(matmul->compute_.col_tile_);
    int oc_block_num = UP_DIV(out_channel, matmul->compute_.col_tile_);
    int ele_num = matmul->batch_ * matmul->compute_.row_ * oc_block_num * matmul->compute_.col_tile_;
    int data_size = ele_num * (int)sizeof(float);
    matmul->output_data_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_, data_size));
    NNACL_CHECK_NULL_RETURN_ERR(matmul->output_data_);
  }
  return NNACL_OK;
}

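/* Default tiling for the base kernel: A is packed into C12NUM-row blocks and B into C8NUM-column
 * blocks. Architecture-specific matmul kernels are expected to override this hook with their own
 * tile sizes and pack functions. */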
void MatmulBaseInitGlobalVariable(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  matmul->matrix_a_.need_pack_ = true;
  matmul->matrix_b_.need_pack_ = !matmul->weight_is_packed_;
  matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row12MajorParallel : RowMajor2Col12MajorParallel;
  matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col8MajorParallel : RowMajor2Row8MajorParallel;
  matmul->compute_.row_tile_ = C12NUM;
  matmul->compute_.col_tile_ = C8NUM;
  matmul->compute_.col_min_unit_ = C8NUM;
  return;
}

bool MatmulBaseCheckThreadCuttingByRow() { return false; }

void MatmulBaseFreePackedMatrixA(KernelBase *self) {
  MatmulStruct *matmul = (MatmulStruct *)self;
  if (matmul->matrix_a_.need_pack_ && !matmul->base_.train_session_ && matmul->matrix_a_.pack_ptr_ != NULL) {
    self->env_->Free(self->env_->allocator_, matmul->matrix_a_.pack_ptr_);
  }
  matmul->matrix_a_.pack_ptr_ = NULL;
}

void MatmulBaseFreePackedMatrixB(KernelBase *self) {
  MatmulStruct *matmul = (MatmulStruct *)self;
  if (matmul->matrix_b_.need_pack_ && !matmul->base_.train_session_ && matmul->matrix_b_.pack_ptr_ != NULL) {
    matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->matrix_b_.pack_ptr_);
  }
  matmul->matrix_b_.pack_ptr_ = NULL;
}

int MatmulBaseResize(KernelBase *self) {
  MatmulStruct *matmul = (MatmulStruct *)self;

  int ret = matmul->init_parameter_(matmul);
  NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
  if (self->train_session_) {
    self->work_size_ = (matmul->matrix_a_.pack_size_ + matmul->matrix_b_.pack_size_) * (int)sizeof(float);
  }

  matmul->get_thread_cutting_policy_(matmul);
  if (!matmul->matrix_c_.has_packed_) {
    ret = MatmulBasePackBiasMatrix(matmul);
    NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
    if (!matmul->bias_need_repack_) {
      matmul->matrix_c_.has_packed_ = true;
    }
  }
  ret = MatmulBaseInitTmpOutBuffer(matmul);
  NNACL_CHECK_FALSE(ret != NNACL_OK, ret);

  return NNACL_OK;
}

int MatmulBaseRelease(struct KernelBase *self) {
  MatmulStruct *matmul = (MatmulStruct *)self;
  MatmulBaseFreeBatchOffset(matmul);

  if (matmul->out_need_aligned_ && matmul->output_data_ != NULL) {
    matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->output_data_);
    matmul->output_data_ = NULL;
  }
  if (matmul->matrix_c_.pack_ptr_ != NULL) {
    free(matmul->matrix_c_.pack_ptr_);
    matmul->matrix_c_.pack_ptr_ = NULL;
  }
  if (matmul->a_const_) {
    if (matmul->is_sharing_pack_) {
      matmul->free_sharing_weight_(matmul->shaing_manager_, matmul->matrix_a_.pack_ptr_);
    } else {
      free(matmul->matrix_a_.pack_ptr_);
    }
  }
  if (matmul->b_const_) {
    if (!matmul->matrix_b_.need_pack_ && matmul->weight_is_packed_) {
      return NNACL_OK;
    }
    if (matmul->is_sharing_pack_) {
      matmul->free_sharing_weight_(matmul->shaing_manager_, matmul->matrix_b_.pack_ptr_);
    } else {
      free(matmul->matrix_b_.pack_ptr_);
    }
  }
  return NNACL_OK;
}

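/* Prepare: validate tensor count, data types and activation type, initialize the compute
 * parameters, pack constant A/B up front, and back up a constant bias for later packing. */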
int MatmulBasePrepare(struct KernelBase *self) {
  MatmulStruct *matmul = (MatmulStruct *)self;

  NNACL_CHECK_FALSE(matmul->base_.in_size_ < C2NUM, NNACL_INPUT_TENSOR_ERROR);
  NNACL_CHECK_FALSE(matmul->base_.out_size_ < 1, NNACL_OUTPUT_TENSOR_ERROR);
  NNACL_CHECK_FALSE(matmul->base_.in_[FIRST_INPUT]->data_type_ != kNumberTypeFloat32, NNACL_INPUT_TENSOR_ERROR);
  NNACL_CHECK_FALSE(matmul->base_.in_[SECOND_INPUT]->data_type_ != kNumberTypeFloat32, NNACL_INPUT_TENSOR_ERROR);

  if (matmul->base_.in_size_ == THREE_TENSOR) {
    NNACL_CHECK_TRUE_RET(matmul->base_.in_[THIRD_INPUT]->data_type_ == kNumberTypeFloat32, NNACL_MATMUL_BIAS_INVALID);
  }

  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(
    param->act_type_ != ActType_No && param->act_type_ != ActType_Relu && param->act_type_ != ActType_Relu6,
    NNACL_MATMUL_ACT_TYPE_INVALID);

  int ret = matmul->init_parameter_(matmul);
  NNACL_CHECK_FALSE(ret != NNACL_OK, ret);

  if (matmul->a_const_) {
    ret = MatmulBasePackMatrixA(matmul);
    NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
    matmul->matrix_a_.has_packed_ = true;
  }
  if (matmul->b_const_) {
    ret = MatmulBasePackMatrixB(matmul);
    NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
    matmul->matrix_b_.has_packed_ = true;
  }

  if (matmul->base_.in_size_ == THREE_TENSOR) {
    /* deal with const bias */
    bool bias_const = IsConst(self->in_[THIRD_INPUT]);
    if (!matmul->infer_shape_ && bias_const && !matmul->base_.train_session_ && matmul->matrix_c_.origin_ptr_ == NULL) {
      ret = MatmulBaseBackupConstMatrix(matmul, &matmul->matrix_c_, THIRD_INPUT);
      NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
    }
  }
  return NNACL_OK;
}

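/* Compute: pack the variable-input matrices, run the selected parallel_run_ kernel through the
 * thread pool, unpack the aligned temporary output when one was allocated, and free the per-call
 * packed buffers. */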
int MatmulBaseCompute(struct KernelBase *self) {
  MatmulStruct *matmul = (MatmulStruct *)self;

  float *out_data = (float *)(matmul->base_.out_[FIRST_INPUT]->data_);
  NNACL_CHECK_FALSE(out_data == NULL, NNACL_ERR);
  if (!matmul->out_need_aligned_) {
    matmul->output_data_ = out_data;
  }

  if (!matmul->a_const_) {
    int ret = MatmulBasePackMatrixA(matmul);
    NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
  }
  if (!matmul->b_const_) {
    int ret = MatmulBasePackMatrixB(matmul);
    NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
  }

  NNACL_CHECK_NULL_RETURN_ERR(matmul->matrix_a_.pack_ptr_);
  NNACL_CHECK_NULL_RETURN_ERR(matmul->matrix_b_.pack_ptr_);

  int ret = self->env_->ParallelLaunch(self->env_->thread_pool_, MatmulFp32Run, self, self->thread_nr_);
  NNACL_CHECK_FALSE(ret != NNACL_OK, ret);

  if (matmul->out_need_aligned_) {
    PackNHWCXToNHWCFp32(matmul->output_data_, out_data, matmul->batch_, matmul->compute_.row_, matmul->compute_.col_,
                        matmul->compute_.col_tile_);
  } else {
    matmul->output_data_ = NULL;
  }
  if (!matmul->a_const_) {
    MatmulBaseFreePackedMatrixA(self);
  }

  if (!matmul->b_const_) {
    MatmulBaseFreePackedMatrixB(self);
  }
  return NNACL_OK;
}

void InitMatrixInfo(MatrixInfo *info) {
  info->need_pack_ = false;
  info->has_packed_ = false;
  info->origin_need_free_ = false;
  info->pack_size_ = -1;
  info->origin_ptr_ = NULL;
  info->pack_ptr_ = NULL;
}

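/* Allocate the base fp32 matmul kernel and wire up the default hooks; architecture-specific
 * variants replace individual function pointers after calling this.
 *
 * Rough call sequence, assuming the usual KernelBase contract (inputs, outputs and param are set
 * up by the caller before use):
 *   KernelBase *kernel = CreateMatmulBase();
 *   kernel->Prepare(kernel);   // validate tensors, pack constant matrices
 *   kernel->Resize(kernel);    // pick the thread cutting policy, pack the bias
 *   kernel->Compute(kernel);   // pack variable matrices and run MatmulFp32Run
 *   kernel->Release(kernel);
 */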
KernelBase *CreateMatmulBase() {
  MatmulStruct *matmul = (MatmulStruct *)malloc(sizeof(MatmulStruct));
  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul);
  memset(matmul, 0, sizeof(MatmulStruct));
  matmul->base_.Prepare = MatmulBasePrepare;
  matmul->base_.Resize = MatmulBaseResize;
  matmul->base_.Release = MatmulBaseRelease;
  matmul->base_.Compute = MatmulBaseCompute;
  InitMatrixInfo(&(matmul->matrix_a_));
  InitMatrixInfo(&(matmul->matrix_b_));
  InitMatrixInfo(&(matmul->matrix_c_));
  matmul->is_sharing_pack_ = false;
  matmul->pack_opt_ = false;
  matmul->a_const_ = false;
  matmul->b_const_ = false;
  matmul->bias_need_repack_ = false;
  matmul->out_need_aligned_ = false;
  matmul->a_offset_ = NULL;
  matmul->b_offset_ = NULL;
  matmul->model_thread_nr_ = -1;
  matmul->support_mul_batch_cut_by_row_ = false;
  matmul->matmul_type_ = kMatmulFp32BaseCpu;
  matmul->get_thread_cutting_policy_ = MatmulBaseGetThreadCuttingPolicy;
  matmul->check_thread_cutting_by_row_ = MatmulBaseCheckThreadCuttingByRow;
  matmul->get_thread_cutting_info_by_row_ = MatmulBaseGetThreadCuttingInfoByRow;
  matmul->init_parameter_ = MatmulBaseInitParameter;
  matmul->init_global_varibale_ = MatmulBaseInitGlobalVariable;
  matmul->pack_matrix_a_impl_opt_ = MatmulBasePackMatrixAImplOpt;
  matmul->pack_matrix_a_impl_ = MatmulBasePackMatrixAImpl;
  matmul->pack_matrix_b_impl_ = MatmulBasePackMatrixBImpl;
  matmul->parallel_run_by_batch_ = MatmulBaseParallelRunByBatch;
  matmul->parallel_run_not_pack_by_batch_ = MatmulBaseParallelRunIsNotPackByBatch;
  matmul->parallel_run_by_oc_ = MatmulBaseParallelRunByOC;
  matmul->parallel_run_by_row_ = MatmulBaseParallelRunByRow;
  return (KernelBase *)matmul;
}