1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifdef ENABLE_AVX512
18 #include "nnacl/kernel/matmul_avx512.h"
19 #include "nnacl/kernel/matmul_base.h"
20 #include "nnacl/fp32/pack_fp32.h"
21 #include "nnacl/fp32/matmul_fp32.h"
22 #include "nnacl/fp32/matmul_avx512_fp32.h"
23 #include "nnacl/fp32/matmul_avx512_mask_fp32.h"
24
25 #define MIN_CALC_COST 24576 /* 1 x 6 x 64x 64 */
26
// Split work first by whole batches (batch_stride_ batches per thread), then
// distribute the rows of the remaining batches over threads via split points.
void MatmulAVX512BatchRowThreadCut(MatmulStruct *matmul) {
  // BatchCut: number of whole batches every thread processes.
  matmul->compute_.batch_stride_ = DOWN_DIV(matmul->batch_, matmul->base_.thread_nr_);

  // RowCut: base rows per thread, never below the minimal row unit.
  int row_step = MSMAX(matmul->compute_.row_ / matmul->base_.thread_nr_, matmul->compute_.row_min_unit_);
  int row_remaining = matmul->compute_.row_ - row_step * matmul->base_.thread_nr_;

  matmul->row_split_points_size_ = 0;
  int row_split_point = 0;
  while (row_split_point < matmul->compute_.row_) {
    matmul->row_split_points_[matmul->row_split_points_size_++] = row_split_point;
    row_split_point += row_step;
    // Hand out leftover rows one at a time to the earliest threads.
    if (row_remaining > 0) {
      ++row_split_point;
      --row_remaining;
    }
  }
  // Sentinel entry: marks the end of the final row band.
  matmul->row_split_points_[matmul->row_split_points_size_] = matmul->compute_.row_;
  if (matmul->compute_.batch_stride_ == 0) {
    // No whole batch per thread: only as many threads as row bands are useful.
    matmul->base_.thread_nr_ = matmul->row_split_points_size_;
  }
}
50
// Split work first by whole batches, then divide the aligned output columns
// into per-thread blocks at col_min_unit_ granularity.
void MatmulAVX512BatchColThreadCut(MatmulStruct *matmul) {
  MatmulComputeParam *compute = &matmul->compute_;

  // BatchCut: whole batches handled by every thread.
  compute->batch_stride_ = DOWN_DIV(matmul->batch_, matmul->base_.thread_nr_);

  // ColCut: number of col_min_unit_-sized units covering the aligned columns.
  int unit_total = UP_DIV(compute->col_align_, compute->col_min_unit_);
  int worker_num = NNACL_MIN(matmul->base_.thread_nr_, unit_total);
  int units_per_block = UP_DIV(unit_total, worker_num);

  matmul->col_split_points_size_ = 0;
  for (int unit_pos = 0; unit_pos < unit_total; unit_pos += units_per_block) {
    matmul->col_split_points_[matmul->col_split_points_size_++] = unit_pos * compute->col_min_unit_;
  }
  if (compute->batch_stride_ == 0) {
    // No whole batch per thread: cap the thread count at the column-slice count.
    matmul->base_.thread_nr_ = matmul->col_split_points_size_;
  }
}
69
// Three-stage cut: (1) whole batches per thread, (2) whole column blocks per
// thread spanning all rows, (3) (row, col) slices covering the leftover
// column range, recorded into matmul_slice_set_ per thread.
void MatmulAVX512BatchColRowSliceThreadCut(MatmulStruct *matmul) {
  // BatchCut: whole batches per thread.
  matmul->compute_.batch_stride_ = DOWN_DIV(matmul->batch_, matmul->base_.thread_nr_);

  int row_s = 0;
  int row_e = matmul->compute_.row_;
  int col_s = 0;
  int col_e = matmul->compute_.col_;

  // ColCut: give each thread block_col_unit_ whole column units (full row range).
  int total_col_unit = UP_DIV(matmul->compute_.col_align_, matmul->compute_.col_min_unit_);
  matmul->compute_.block_col_unit_ = DOWN_DIV(total_col_unit, matmul->base_.thread_nr_);
  matmul->col_split_points_size_ = 0;
  matmul->col_split_points_[matmul->col_split_points_size_++] = 0;
  if (matmul->compute_.block_col_unit_ > 0) {
    int col_split_point = 0;
    for (int i = 0; i < matmul->base_.thread_nr_; i++) {
      MatmulSlice matmul_slice;
      matmul_slice.row_s_ = row_s;
      matmul_slice.row_e_ = row_e;
      matmul_slice.col_s_ = col_split_point * matmul->compute_.col_min_unit_;
      col_split_point += matmul->compute_.block_col_unit_;
      // col_s advances past the columns consumed by this ColCut stage.
      col_s = NNACL_MIN(col_split_point * matmul->compute_.col_min_unit_, matmul->compute_.col_step_);
      matmul_slice.col_e_ = col_s;
      matmul->matmul_slice_set_[i][matmul->matmul_slice_count_[i]++] = matmul_slice;
    }
  }
  if (col_e - col_s <= 0) {
    // All columns already covered by the ColCut stage.
    return;
  }

  // RowColCut: distribute the remaining columns [col_s, col_e).
  int row_thread = 0;  // number of threads that receive a slice below
  int less_col_align = UP_ROUND(col_e - col_s, C16NUM);
  // Combined col+row tiling only when the remainder is a multiple of 64 cols.
  bool use_colrowcut_flag = ((less_col_align / C64NUM) * C64NUM) == less_col_align;
  bool use_rowcut_flag = matmul->compute_.row_ >= C6NUM * matmul->base_.thread_nr_ || col_e - col_s <= C64NUM;
  if (use_rowcut_flag && !use_colrowcut_flag) {
    // Pure RowCut: each thread gets the full remaining column range, one row band.
    int row_step = MSMAX(matmul->compute_.row_ / matmul->base_.thread_nr_, matmul->compute_.row_min_unit_);
    int row_remaining = matmul->compute_.row_ - row_step * matmul->base_.thread_nr_;
    int row_split_point = 0;

    for (row_thread = 0; row_thread < matmul->base_.thread_nr_ && row_split_point < matmul->compute_.row_;
         row_thread++) {
      MatmulSlice matmul_slice;
      matmul_slice.row_s_ = row_split_point;

      row_split_point += row_step;
      // Leftover rows go one at a time to the earliest threads.
      if (row_remaining > 0) {
        ++row_split_point;
        --row_remaining;
      }

      matmul_slice.row_e_ = row_split_point;
      matmul_slice.col_s_ = col_s;
      matmul_slice.col_e_ = col_e;
      matmul->matmul_slice_set_[row_thread][matmul->matmul_slice_count_[row_thread]++] = matmul_slice;
    }
  } else {
    // ColRowCut: tile the remainder into 64-wide column strips x row bands.
    int col_num = UP_DIV(col_e - col_s, C64NUM);
    int row_num = NNACL_MIN(UP_DIV(matmul->base_.thread_nr_, col_num), (row_e - row_s));
    // Tiles in excess of the thread count; trailing strips use one fewer band.
    int tile_remaining = MSMAX(col_num * row_num - matmul->base_.thread_nr_, 0);

    NNACL_CHECK_ZERO_RETURN(row_num);
    int row_step = (row_e - row_s) / row_num;
    int row_remaining_tmp = (row_e - row_s) - row_step * row_num;

    // "cut2" variants: row split used when a strip carries row_num - 1 bands.
    int row_step_cut2 = (row_num == 1) ? row_step : (row_e - row_s) / (row_num - 1);
    int row_remaining_cut2_tmp = (row_e - row_s) - row_step_cut2 * (row_num - 1);

    MatmulSlice matmul_slice;
    for (int c = 0; c < col_num; c++) {
      matmul_slice.col_s_ = col_s + c * C64NUM;
      matmul_slice.col_e_ = NNACL_MIN(col_s + (c + 1) * C64NUM, matmul->compute_.col_);
      int row_split_point = 0;
      int row_remaining = row_remaining_tmp;
      int row_remaining_cut2 = row_remaining_cut2_tmp;
      if (c < col_num - tile_remaining) {
        // Full strip: row_num row bands.
        for (int r = 0; r < row_num; r++) {
          matmul_slice.row_s_ = row_split_point;
          row_split_point += row_step;
          if (row_remaining > 0) {
            ++row_split_point;
            --row_remaining;
          }
          matmul_slice.row_e_ = NNACL_MIN(row_split_point, matmul->compute_.row_);
          matmul->matmul_slice_set_[row_thread][matmul->matmul_slice_count_[row_thread]++] = matmul_slice;
          row_thread++;
        }
      } else {
        // Short strip: row_num - 1 row bands (uses the cut2 step).
        for (int r = 0; r < row_num - 1; r++) {
          matmul_slice.row_s_ = row_split_point;
          row_split_point += row_step_cut2;
          if (row_remaining_cut2 > 0) {
            ++row_split_point;
            --row_remaining_cut2;
          }
          matmul_slice.row_e_ = NNACL_MIN(row_split_point, matmul->compute_.row_);
          matmul->matmul_slice_set_[row_thread][matmul->matmul_slice_count_[row_thread]++] = matmul_slice;
          row_thread++;
        }
      }
    }
  }
  if ((matmul->compute_.batch_stride_ == 0) && (matmul->compute_.block_col_unit_ == 0)) {
    // Neither batch nor whole-column work exists: shrink to the slice count.
    matmul->base_.thread_nr_ = row_thread;
  }
}
177
// Choose the parallelization strategy for the AVX512 matmul from the problem
// shape, and cap the usable thread count by the total amount of work.
// Deep problems (deep_ >= 128) get shape-specialized cuts; shallow problems
// delegate to the generic base policy.
void MatmulAVX512GetThreadCuttingPolicy(MatmulStruct *matmul) {
  // Total multiply-accumulate count of the whole matmul.
  size_t total_cost = (size_t)(matmul->batch_) * (size_t)(matmul->compute_.row_) * (size_t)(matmul->compute_.col_) *
                      (size_t)(matmul->compute_.deep_);

  // Thread Update: never use more threads than MIN_CALC_COST work units allow.
  matmul->base_.thread_nr_ = MSMAX(NNACL_MIN((int)(total_cost / MIN_CALC_COST), matmul->base_.thread_nr_), C1NUM);

  if (matmul->compute_.deep_ < C128NUM) {
    return MatmulBaseGetThreadCuttingPolicy(matmul);
  }

  // Reset per-thread slice lists before a cut function refills them.
  for (int i = 0; i < SPLIT_COUNT; i++) {
    matmul->matmul_slice_count_[i] = 0;
  }
  if (matmul->compute_.col_ == 1 && !matmul->a_const_) {
    // Gemv-like (single output column): cut by batch and rows.
    MatmulAVX512BatchRowThreadCut(matmul);
    if (matmul->compute_.deep_ == 1) {
      matmul->gemm_not_pack_fun_ = GemmIsNotPack;
    } else {
      matmul->gemm_not_pack_fun_ = GemmIsNotPackOptimize;
    }
    matmul->parallel_run_ = matmul->parallel_run_by_gepdot_;
  } else if (matmul->compute_.row_ == 1 && !matmul->b_const_ && matmul->compute_.col_ <= C128NUM) {
    // Gepm-like (single row, narrow output): cut by batch and columns.
    MatmulAVX512BatchColThreadCut(matmul);
    if (matmul->compute_.deep_ == 1) {
      matmul->parallel_run_ = matmul->parallel_run_by_row1_deep1_gepdot_;
      if (matmul->matrix_c_.pack_ptr_ != NULL) {
        matmul->gemm_not_pack_fun_ = Row1Deep1GemmIsNotPack;
      } else {
        matmul->gemm_not_pack_fun_ = Row1Deep1NoBiasGemmIsNotPack;
      }
      return;
    }
    matmul->parallel_run_ = matmul->parallel_run_by_gepm_;
  } else {
    // General case: combined batch + column + row slice cut.
    MatmulAVX512BatchColRowSliceThreadCut(matmul);
    matmul->parallel_run_ = matmul->parallel_run_by_batch_col_row_gemm_;
  }
  return;
}
218
MatmulAVX512CheckThreadCuttingByRow(MatmulStruct * matmul)219 bool MatmulAVX512CheckThreadCuttingByRow(MatmulStruct *matmul) {
220 if (matmul->b_batch_ != C1NUM) {
221 return false;
222 }
223 if (matmul->compute_.row_num_ < matmul->base_.thread_nr_) {
224 return false;
225 }
226 if (matmul->compute_.col_ == 1) {
227 matmul->compute_.row_min_unit_ = C8NUM;
228 return true;
229 }
230 if (matmul->compute_.row_ == 1 && !matmul->b_const_ && matmul->compute_.col_ <= C128NUM) {
231 return false;
232 }
233 matmul->compute_.row_min_unit_ = C6NUM;
234 if (matmul->compute_.col_step_ < C48NUM) {
235 matmul->compute_.row_min_unit_ = C12NUM;
236 } else if (matmul->compute_.col_step_ < C64NUM) {
237 matmul->compute_.row_min_unit_ = C8NUM;
238 }
239 return NNACL_MIN(matmul->compute_.row_num_ / matmul->compute_.row_min_unit_, matmul->base_.thread_nr_) >
240 NNACL_MIN(matmul->compute_.col_step_ / matmul->compute_.col_min_unit_, matmul->base_.thread_nr_);
241 }
// Install AVX512 pack functions and default tile sizes, and pre-set the
// aligned-output flag for the shapes that benefit from it.
void MatmulAVX512InitGlobalVariable(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;

  if (param->a_transpose_) {
    matmul->matrix_a_pack_fun_ = RowMajor2ColMajorParallel;
  } else {
    matmul->matrix_a_pack_fun_ = RowMajor2RowMajorParallel;
  }
  if (param->b_transpose_) {
    matmul->matrix_b_pack_fun_ = RowMajor2Col64MajorParallel;
  } else {
    matmul->matrix_b_pack_fun_ = RowMajor2Row64MajorParallel;
  }
  matmul->matrix_a_.need_pack_ = param->a_transpose_;
  matmul->matrix_b_.need_pack_ = true;

  compute->row_tile_ = C1NUM;
  compute->col_tile_ = C16NUM;
  compute->col_min_unit_ = C64NUM;

  // Aligned output for gemv/gepm-like shapes; for row_ == 1 with a constant B
  // (or wide columns) the flag is deliberately left untouched.
  if (compute->row_ == 1) {
    if (!matmul->b_const_ && compute->col_ <= C128NUM) {
      matmul->out_need_aligned_ = true;
    }
  } else {
    matmul->out_need_aligned_ = (compute->col_ == 1);
  }
  // Deep problems never use aligned output.
  if (compute->deep_ >= C128NUM) {
    matmul->out_need_aligned_ = false;
  }
}
// Initialize AVX512-specific pack/tile parameters and validate pack sizes.
// Shallow problems (deep_ < 128) fall back to the base parameter setup.
// Returns NNACL_OK on success, NNACL_ERR on overflow or pack-size mismatch.
int MatmulAVX512InitParameter(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;

  if (compute->deep_ < C128NUM) {
    return MatmulBaseInitParameter(matmul);
  }

  matmul->init_global_varibale_(matmul);
  if (compute->col_ == 1 && !matmul->a_const_) {
    // Gemv-like path: no tiling, B used unpacked, A packed only if transposed.
    matmul->out_need_aligned_ = false;
    compute->row_tile_ = 1;
    compute->col_tile_ = 1;
    matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
    matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
    matmul->matrix_a_.need_pack_ = param->a_transpose_ && compute->row_ != 1;
    matmul->matrix_b_.need_pack_ = false;
    matmul->pack_opt_ = false;
  } else if (compute->row_ == 1 && !matmul->b_const_ && compute->col_ <= C128NUM) {
    // Gepm-like path: no tiling, A used unpacked, B packed only if transposed.
    matmul->out_need_aligned_ = false;
    compute->row_tile_ = 1;
    compute->col_tile_ = 1;
    matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
    matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
    matmul->matrix_a_.need_pack_ = false;
    matmul->matrix_b_.need_pack_ = param->b_transpose_;
    matmul->pack_opt_ = false;
  }
  compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_);
  compute->col_align_ = UP_ROUND(compute->col_, compute->col_tile_);
  // Guard the pack-buffer size computations against int overflow.
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->row_align_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->row_align_, compute->deep_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->col_align_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->col_align_, compute->deep_, NNACL_ERR);
  int a_pack_size = matmul->a_batch_ * compute->row_align_ * compute->deep_;
  int b_pack_size = matmul->b_batch_ * compute->col_align_ * compute->deep_;
  // Matrices packed earlier must match the newly computed sizes.
  if ((matmul->matrix_a_.has_packed_ && matmul->matrix_a_.pack_size_ != a_pack_size) ||
      (matmul->matrix_b_.has_packed_ && matmul->matrix_b_.pack_size_ != b_pack_size)) {
    return NNACL_ERR;
  }
  matmul->matrix_a_.pack_size_ = a_pack_size;
  matmul->matrix_b_.pack_size_ = b_pack_size;
  compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_);
  // Aligned output only pays off when col_ is not already a tile multiple.
  matmul->out_need_aligned_ = (matmul->out_need_aligned_ && ((compute->col_ % compute->col_tile_) != 0));
  compute->col_step_ = matmul->out_need_aligned_ ? compute->col_align_ : compute->col_;
  NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(matmul->a_batch_, compute->row_), NNACL_ERR);
  compute->row_num_ = matmul->a_batch_ * compute->row_;
  return NNACL_OK;
}
315
// Execute the row band assigned to task_id: rows
// [split_points_[task_id], split_points_[task_id + 1]) of the combined
// batch-row space (row_num_). Returns NNACL_OK or NNACL_ERR on bad task_id.
int MatmulAVX512ParallelRunByRow(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);

  int start_row = matmul->split_points_[task_id];
  // The last thread runs to the end of the row space.
  int end_row = matmul->compute_.row_num_;
  if (task_id < (matmul->base_.thread_nr_ - 1)) {
    end_row = matmul->split_points_[task_id + 1];
  }
  int row_num = end_row - start_row;
  if (row_num <= 0) {
    return NNACL_OK;
  }
  const float *input = matmul->matrix_a_.pack_ptr_ + start_row * matmul->compute_.deep_;
  float *output = matmul->output_data_ + start_row * matmul->compute_.col_step_;
  if (matmul->compute_.col_ == 1) {
    // Single output column: no-pack gemv kernel with a scalar bias.
    float bias = 0;
    if (matmul->matrix_c_.pack_ptr_ != NULL) {
      bias = matmul->matrix_c_.pack_ptr_[0];
    }
    matmul->gemm_not_pack_fun_(input, matmul->matrix_b_.pack_ptr_, output, &bias, row_num, matmul->compute_.deep_,
                               param->act_type_);
  } else {
    if (matmul->out_need_aligned_) {
      MatMulAvx512Fp32(input, matmul->matrix_b_.pack_ptr_, output, matmul->matrix_c_.pack_ptr_, param->act_type_,
                       matmul->compute_.deep_, matmul->compute_.col_align_, matmul->compute_.col_align_, row_num);
    } else {
      // Masked kernel handles the column tail without alignment padding.
      MatMulMaskAvx512Fp32(input, matmul->matrix_b_.pack_ptr_, output, matmul->matrix_c_.pack_ptr_, param->act_type_,
                           matmul->compute_.deep_, matmul->compute_.col_, matmul->compute_.col_, row_num);
    }
  }
  return NNACL_OK;
}
349
// Execute the output-channel (column) range assigned to task_id across all
// batches: columns [split_points_[task_id], split_points_[task_id + 1]).
int MatmulAVX512ParallelRunByOC(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatmulComputeParam *compute = &matmul->compute_;
  ActType act = param->act_type_;

  int start_oc = matmul->split_points_[task_id];
  // The last thread runs to the end of the column step range.
  int end_oc = compute->col_step_;
  if (task_id < (matmul->base_.thread_nr_ - 1)) {
    end_oc = matmul->split_points_[task_id + 1];
  }
  int compute_oc = end_oc - start_oc;
  if (compute_oc <= 0) {
    return NNACL_OK;
  }
  // func_flag: 0 = full GEMM, 1 = packed gemv (row_ == 1), 2 = no-pack gemv.
  int func_flag = 0;
  if (matmul->compute_.row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }
  // No-pack B is addressed by column only; packed B by column * deep_.
  int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_;
  for (int i = 0; i < matmul->batch_; ++i) {
    float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_;
    float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride;
    float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc;
    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;

    if (func_flag == 0) {
      if (matmul->out_need_aligned_) {
        MatMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_align_, compute->row_);
      } else {
        MatMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_, compute->row_);
      }
    } else if (func_flag == C1NUM) {
      if (matmul->out_need_aligned_) {
        MatVecMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_align_);
      } else {
        MatVecMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_);
      }
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_);
    }
  }

  return NNACL_OK;
}
395
// Execute the whole batches assigned to task_id:
// [task_id * batch_stride_, min(batch_, (task_id + 1) * batch_stride_)).
// Returns NNACL_OK, or NNACL_ERR for an out-of-range task_id.
int MatmulAVX512ParallelRunByBatch(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  // Fix: validate task_id like every other AVX512 parallel-run entry point;
  // an out-of-range id would otherwise index past the intended batch range.
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatmulComputeParam *compute = &matmul->compute_;
  ActType act = param->act_type_;

  int start_batch = task_id * compute->batch_stride_;
  int end_batch = NNACL_MIN(matmul->batch_, start_batch + compute->batch_stride_);
  // func_flag: 0 = full GEMM, 1 = packed gemv (row_ == 1), 2 = no-pack gemv.
  int func_flag = 0;
  if (compute->row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }

  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_;
    float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_;
    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_;

    if (func_flag == 0) {
      if (matmul->out_need_aligned_) {
        MatMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_, compute->row_);
      } else {
        MatMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_, compute->row_);
      }
    } else if (func_flag == C1NUM) {
      if (matmul->out_need_aligned_) {
        MatVecMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_);
      } else {
        MatVecMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_);
      }
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_);
    }
  }
  return NNACL_OK;
}
432
// GEPM run: thread task_id first handles its whole batches, then its column
// slice of the batches left over by the batch cut (no-pack gemv kernels).
int MatmulAVX512ParallelRunByGEPM(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);

  // Per-batch plane sizes for A, B and the output C.
  int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_;
  int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_;
  int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_;
  int matrix_col = matmul->compute_.col_step_;
  int matrix_deep = matmul->compute_.deep_;

  // by BatchCut
  int start_batch = task_id * matmul->compute_.batch_stride_;
  int end_batch = NNACL_MIN(matmul->batch_, start_batch + matmul->compute_.batch_stride_);

  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size;
    float *c = matmul->output_data_ + index * c_plane_size;

    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_;
    MatVecMulNoPackFp32(a, b, c, bias, param->act_type_, matrix_deep, matrix_col, matrix_col);
  }

  // by ColCut: only threads with a column slice participate.
  int col_split_points_size = matmul->col_split_points_size_;
  if (task_id < col_split_points_size) {
    int start_oc = matmul->col_split_points_[task_id];
    // The last slice runs to the end of the column range.
    int end_oc = matrix_col;
    if (task_id < (col_split_points_size - 1)) {
      end_oc = matmul->col_split_points_[task_id + 1];
    }
    int compute_oc = end_oc - start_oc;
    if (compute_oc <= 0) {
      return NNACL_OK;
    }

    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;
    // Batches not covered by the batch cut start at thread_nr_ * batch_stride_.
    for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) {
      float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size;
      float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc;
      float *c = matmul->output_data_ + i * c_plane_size + start_oc;
      MatVecMulNoPackFp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col);
    }
  }
  return NNACL_OK;
}
// GEMM run: thread task_id handles (1) its whole batches, (2) its column
// slice, and (3) its row band of the trailing column remainder, all over the
// batches left uncovered by the batch cut.
int MatmulAVX512ParallelRunByGEMM(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);

  // Per-batch plane sizes for A, B and the output C.
  int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_;
  int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_;
  int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_;
  int matrix_row = matmul->compute_.row_;
  int matrix_col = matmul->compute_.col_step_;
  int matrix_deep = matmul->compute_.deep_;

  // by BatchCut
  int start_batch = task_id * matmul->compute_.batch_stride_;
  int end_batch = start_batch + matmul->compute_.batch_stride_;
  float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_;
  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size;
    float *c = matmul->output_data_ + index * c_plane_size;
    MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, matrix_col, matrix_col, matrix_row);
  }

  // by ColCut: full-height column slice for this thread.
  int col_split_points_size = matmul->col_split_points_size_;
  if (task_id < col_split_points_size) {
    int start_oc = matmul->col_split_points_[task_id];
    int end_oc = matmul->col_split_points_[task_id + 1];
    int compute_oc = end_oc - start_oc;

    bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;
    if (compute_oc > 0) {
      // Batches not covered by the batch cut start at thread_nr_ * batch_stride_.
      for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) {
        float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size;
        float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc * matrix_deep;
        float *c = matmul->output_data_ + i * c_plane_size + start_oc;
        MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col, matrix_row);
      }
    }
  }

  // by RowCut: row bands over the columns past the last column split point.
  int start_oc = matmul->col_split_points_[col_split_points_size];
  int end_oc = matrix_col;
  int compute_oc = end_oc - start_oc;
  if (compute_oc <= 0) {
    return NNACL_OK;
  }

  int row_split_points_size = matmul->row_split_points_size_;
  if (task_id >= row_split_points_size) {
    return NNACL_OK;
  }
  int start_row = matmul->row_split_points_[task_id];
  int end_row = matmul->row_split_points_[task_id + 1];
  int row_num = end_row - start_row;
  if (row_num <= 0) {
    return NNACL_OK;
  }

  bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;
  for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) {
    float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size + start_row * matrix_deep;
    float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc * matrix_deep;
    float *c = matmul->output_data_ + i * c_plane_size + start_row * matrix_col + start_oc;
    MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col, row_num);
  }

  return NNACL_OK;
}
548
// GEPDOT run (col_ == 1 path): thread task_id handles its whole batches via
// the no-pack gemv function, then its row band of the leftover batches.
int MatmulAVX512ParallelRunByGEPDOT(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatmulComputeParam *compute = &matmul->compute_;

  // by BatchCut
  int start_batch = task_id * compute->batch_stride_;
  int end_batch = start_batch + compute->batch_stride_;
  // Single output column: the bias is a single scalar.
  float bias = 0;
  if (matmul->matrix_c_.pack_ptr_ != NULL) {
    bias = matmul->matrix_c_.pack_ptr_[0];
  }
  int a_stride = compute->row_ * compute->deep_;
  int b_stride = compute->deep_ * compute->col_;

  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_stride;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_stride;
    float *c = matmul->output_data_ + index * compute->row_ * compute->col_;
    matmul->gemm_not_pack_fun_(a, b, c, &bias, compute->row_, compute->deep_, param->act_type_);
  }

  // by RowCut: only threads with a row band participate.
  int split_points_size = matmul->row_split_points_size_;
  if (task_id >= split_points_size) {
    return NNACL_OK;
  }
  // Batches not covered by the batch cut start at thread_nr_ * batch_stride_.
  for (int index = matmul->base_.thread_nr_ * compute->batch_stride_; index < matmul->batch_; ++index) {
    int start_row = matmul->row_split_points_[task_id];
    int end_row = matmul->row_split_points_[task_id + 1];
    int row_num = end_row - start_row;
    if (row_num <= 0) {
      continue;
    }
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_stride + start_row * compute->deep_;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_stride;
    float *c = matmul->output_data_ + index * compute->row_ * compute->col_ + start_row * compute->col_step_;
    matmul->gemm_not_pack_fun_(a, b, c, &bias, row_num, compute->deep_, param->act_type_);
  }

  return NNACL_OK;
}
591
// GEPDOT run for the row_ == 1, deep_ == 1 special case: thread task_id
// handles its whole batches, then its column slice of the leftover batches,
// all through gemm_not_pack_fun_.
int MatmulAVX512ParallelRunByRow1Deep1GEPDOT(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);

  // Per-batch plane sizes for A, B and the output C.
  int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_;
  int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_;
  int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_;
  int matrix_col = matmul->compute_.col_step_;
  int matrix_deep = matmul->compute_.deep_;

  // by BatchCut
  int start_batch = task_id * matmul->compute_.batch_stride_;
  int end_batch = NNACL_MIN(matmul->batch_, start_batch + matmul->compute_.batch_stride_);

  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size;
    float *c = matmul->output_data_ + index * c_plane_size;
    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_;
    matmul->gemm_not_pack_fun_(a, b, c, bias, matrix_col, matrix_deep, param->act_type_);
  }

  // by ColCut: only threads with a column slice participate.
  int col_split_points_size = matmul->col_split_points_size_;
  if (task_id < col_split_points_size) {
    int start_oc = matmul->col_split_points_[task_id];
    // The last slice runs to the end of the column range.
    int end_oc = matrix_col;
    if (task_id < (col_split_points_size - 1)) {
      end_oc = matmul->col_split_points_[task_id + 1];
    }
    int compute_oc = end_oc - start_oc;
    if (compute_oc <= 0) {
      return NNACL_OK;
    }

    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;
    // Batches not covered by the batch cut start at thread_nr_ * batch_stride_.
    for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) {
      float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size;
      float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc;
      float *c = matmul->output_data_ + i * c_plane_size + start_oc;
      matmul->gemm_not_pack_fun_(a, b, c, bias, compute_oc, matrix_deep, param->act_type_);
    }
  }
  return NNACL_OK;
}
637
// Batch + col + row slice run: thread task_id handles its whole batches, then
// each (row, col) slice assigned to it in matmul_slice_set_ over the batches
// left uncovered by the batch cut.
int MatmulAVX512ParallelRunByBatchColRowGEMM(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);

  // Per-batch plane sizes for A, B and the output C.
  int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_;
  int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_;
  int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_;
  int matrix_row = matmul->compute_.row_;
  int matrix_col = matmul->compute_.col_step_;
  int matrix_deep = matmul->compute_.deep_;

  // by BatchCut
  int start_batch = task_id * matmul->compute_.batch_stride_;
  int end_batch = start_batch + matmul->compute_.batch_stride_;
  float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_;
  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size;
    float *c = matmul->output_data_ + index * c_plane_size;
    MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, matrix_col, matrix_col, matrix_row);
  }

  // Process the (row, col) slices assigned to this thread.
  MatmulSlice *matmul_slices = matmul->matmul_slice_set_[task_id];
  int slice_count = matmul->matmul_slice_count_[task_id];
  for (int s = 0; s < slice_count; s++) {
    MatmulSlice matmul_slice = matmul_slices[s];

    int start_oc = matmul_slice.col_s_;
    int end_oc = matmul_slice.col_e_;
    int compute_oc = end_oc - start_oc;
    // NOTE(review): an empty slice returns instead of continuing to the next
    // one — assumes all following slices are empty too; confirm against the
    // slice-cut functions above.
    if (compute_oc <= 0) {
      return NNACL_OK;
    }

    int start_row = matmul_slice.row_s_;
    int end_row = matmul_slice.row_e_;
    int row_num = end_row - start_row;
    if (row_num <= 0) {
      return NNACL_OK;
    }

    bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;
    // Batches not covered by the batch cut start at thread_nr_ * batch_stride_.
    for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) {
      float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size + start_row * matrix_deep;
      float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc * matrix_deep;
      float *c = matmul->output_data_ + i * c_plane_size + start_row * matrix_col + start_oc;
      MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col, row_num);
    }
  }
  return NNACL_OK;
}
689
CreateMatmulAVX512()690 KernelBase *CreateMatmulAVX512() {
691 MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase();
692 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul);
693 matmul->matmul_type_ = kNotImplemented;
694 matmul->check_thread_cutting_by_row_ = MatmulAVX512CheckThreadCuttingByRow;
695 matmul->get_thread_cutting_policy_ = MatmulAVX512GetThreadCuttingPolicy;
696 matmul->init_parameter_ = MatmulAVX512InitParameter;
697 matmul->init_global_varibale_ = MatmulAVX512InitGlobalVariable;
698 matmul->parallel_run_by_oc_ = MatmulAVX512ParallelRunByOC;
699 matmul->parallel_run_by_row_ = MatmulAVX512ParallelRunByRow;
700 matmul->parallel_run_by_batch_ = MatmulAVX512ParallelRunByBatch;
701 matmul->parallel_run_by_gemm_ = MatmulAVX512ParallelRunByGEMM;
702 matmul->parallel_run_by_gepm_ = MatmulAVX512ParallelRunByGEPM;
703 matmul->parallel_run_by_gepdot_ = MatmulAVX512ParallelRunByGEPDOT;
704 matmul->parallel_run_by_batch_col_row_gemm_ = MatmulAVX512ParallelRunByBatchColRowGEMM;
705 matmul->parallel_run_by_row1_deep1_gepdot_ = MatmulAVX512ParallelRunByRow1Deep1GEPDOT;
706 return (KernelBase *)matmul;
707 }
708 #endif
709