• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifdef ENABLE_AVX512
18 #include "nnacl/kernel/matmul_avx512.h"
19 #include "nnacl/kernel/matmul_base.h"
20 #include "nnacl/fp32/pack_fp32.h"
21 #include "nnacl/fp32/matmul_fp32.h"
22 #include "nnacl/fp32/matmul_avx512_fp32.h"
23 #include "nnacl/fp32/matmul_avx512_mask_fp32.h"
24 
25 #define MIN_CALC_COST 24576 /* 1 x 6 x 64x 64 */
26 
/* Thread cut for the GEPDOT path: each thread first owns batch_stride_ whole
 * batches, then the rows of the leftover batches are partitioned into
 * near-equal chunks recorded in row_split_points_. */
void MatmulAVX512BatchRowThreadCut(MatmulStruct *matmul) {
  MatmulComputeParam *compute = &matmul->compute_;
  int thread_num = matmul->base_.thread_nr_;

  /* BatchCut: whole batches handled per thread. */
  compute->batch_stride_ = DOWN_DIV(matmul->batch_, thread_num);

  /* RowCut: base chunk size, with the first `extra` chunks one row larger so
   * the remainder is spread evenly. */
  int step = MSMAX(compute->row_ / thread_num, compute->row_min_unit_);
  int extra = compute->row_ - step * thread_num;

  int count = 0;
  int point = 0;
  while (point < compute->row_) {
    matmul->row_split_points_[count++] = point;
    point += step;
    if (extra > 0) {
      point++;
      extra--;
    }
  }
  matmul->row_split_points_size_ = count;
  matmul->row_split_points_[count] = compute->row_; /* sentinel: end of rows */

  /* No whole-batch work at all: only as many threads as row chunks are useful. */
  if (compute->batch_stride_ == 0) {
    matmul->base_.thread_nr_ = count;
  }
}
50 
/* Thread cut for the GEPM path: whole batches per thread, then the columns of
 * the leftover batches are split into col_min_unit_-sized blocks recorded in
 * col_split_points_. */
void MatmulAVX512BatchColThreadCut(MatmulStruct *matmul) {
  MatmulComputeParam *compute = &matmul->compute_;
  int thread_num = matmul->base_.thread_nr_;

  /* BatchCut: whole batches handled per thread. */
  compute->batch_stride_ = DOWN_DIV(matmul->batch_, thread_num);

  /* ColCut: distribute the aligned column units over at most thread_num
   * threads, unit_per_thread units each. */
  int total_unit = UP_DIV(compute->col_align_, compute->col_min_unit_);
  int used_threads = NNACL_MIN(thread_num, total_unit);
  int unit_per_thread = UP_DIV(total_unit, used_threads);

  int count = 0;
  for (int unit = 0; unit < total_unit; unit += unit_per_thread) {
    matmul->col_split_points_[count++] = unit * compute->col_min_unit_;
  }
  matmul->col_split_points_size_ = count;

  /* No whole-batch work: shrink the thread count to the number of col chunks. */
  if (compute->batch_stride_ == 0) {
    matmul->base_.thread_nr_ = count;
  }
}
69 
/* Combined batch/column/row thread cut for the deep GEMM path.
 * Work is distributed in three stages:
 *   1. BatchCut  - batch_stride_ whole batches per thread.
 *   2. ColCut    - full-height column slices of block_col_unit_ units each,
 *                  one per thread, appended to matmul_slice_set_.
 *   3. RowColCut - the remaining column range [col_s, col_e) is split either
 *                  by rows only, or into a (row chunk) x (64-column tile)
 *                  grid, chosen by shape heuristics. */
void MatmulAVX512BatchColRowSliceThreadCut(MatmulStruct *matmul) {
  // BatchCut
  matmul->compute_.batch_stride_ = DOWN_DIV(matmul->batch_, matmul->base_.thread_nr_);

  int row_s = 0;
  int row_e = matmul->compute_.row_;
  int col_s = 0;
  int col_e = matmul->compute_.col_;

  // ColCut
  int total_col_unit = UP_DIV(matmul->compute_.col_align_, matmul->compute_.col_min_unit_);
  matmul->compute_.block_col_unit_ = DOWN_DIV(total_col_unit, matmul->base_.thread_nr_);
  matmul->col_split_points_size_ = 0;
  matmul->col_split_points_[matmul->col_split_points_size_++] = 0;
  if (matmul->compute_.block_col_unit_ > 0) {
    // Give every thread one full-height slice of block_col_unit_ column units.
    int col_split_point = 0;
    for (int i = 0; i < matmul->base_.thread_nr_; i++) {
      MatmulSlice matmul_slice;
      matmul_slice.row_s_ = row_s;
      matmul_slice.row_e_ = row_e;
      matmul_slice.col_s_ = col_split_point * matmul->compute_.col_min_unit_;
      col_split_point += matmul->compute_.block_col_unit_;
      // col_s ends up at the end of the last distributed column slice.
      col_s = NNACL_MIN(col_split_point * matmul->compute_.col_min_unit_, matmul->compute_.col_step_);
      matmul_slice.col_e_ = col_s;
      matmul->matmul_slice_set_[i][matmul->matmul_slice_count_[i]++] = matmul_slice;
    }
  }
  // ColCut already covered every column: nothing left to slice.
  if (col_e - col_s <= 0) {
    return;
  }

  // RowColCut
  int row_thread = 0;
  int less_col_align = UP_ROUND(col_e - col_s, C16NUM);
  // Grid cut only pays off when the leftover width is a multiple of 64.
  bool use_colrowcut_flag = ((less_col_align / C64NUM) * C64NUM) == less_col_align;
  bool use_rowcut_flag = matmul->compute_.row_ >= C6NUM * matmul->base_.thread_nr_ || col_e - col_s <= C64NUM;
  if (use_rowcut_flag && !use_colrowcut_flag) {
    // Row-only cut: near-equal row chunks; the first chunks take one extra row.
    int row_step = MSMAX(matmul->compute_.row_ / matmul->base_.thread_nr_, matmul->compute_.row_min_unit_);
    int row_remaining = matmul->compute_.row_ - row_step * matmul->base_.thread_nr_;
    int row_split_point = 0;

    for (row_thread = 0; row_thread < matmul->base_.thread_nr_ && row_split_point < matmul->compute_.row_;
         row_thread++) {
      MatmulSlice matmul_slice;
      matmul_slice.row_s_ = row_split_point;

      row_split_point += row_step;
      if (row_remaining > 0) {
        ++row_split_point;
        --row_remaining;
      }

      matmul_slice.row_e_ = row_split_point;
      matmul_slice.col_s_ = col_s;
      matmul_slice.col_e_ = col_e;
      matmul->matmul_slice_set_[row_thread][matmul->matmul_slice_count_[row_thread]++] = matmul_slice;
    }
  } else {
    // Grid cut: 64-wide column tiles x row chunks. The last tile_remaining
    // column tiles use the coarser "cut2" row step (one fewer chunk) so the
    // total number of slices matches thread_nr_.
    int col_num = UP_DIV(col_e - col_s, C64NUM);
    int row_num = NNACL_MIN(UP_DIV(matmul->base_.thread_nr_, col_num), (row_e - row_s));
    int tile_remaining = MSMAX(col_num * row_num - matmul->base_.thread_nr_, 0);

    NNACL_CHECK_ZERO_RETURN(row_num);
    int row_step = (row_e - row_s) / row_num;
    int row_remaining_tmp = (row_e - row_s) - row_step * row_num;

    int row_step_cut2 = (row_num == 1) ? row_step : (row_e - row_s) / (row_num - 1);
    int row_remaining_cut2_tmp = (row_e - row_s) - row_step_cut2 * (row_num - 1);

    MatmulSlice matmul_slice;
    for (int c = 0; c < col_num; c++) {
      matmul_slice.col_s_ = col_s + c * C64NUM;
      matmul_slice.col_e_ = NNACL_MIN(col_s + (c + 1) * C64NUM, matmul->compute_.col_);
      int row_split_point = 0;
      int row_remaining = row_remaining_tmp;
      int row_remaining_cut2 = row_remaining_cut2_tmp;
      if (c < col_num - tile_remaining) {
        for (int r = 0; r < row_num; r++) {
          matmul_slice.row_s_ = row_split_point;
          row_split_point += row_step;
          if (row_remaining > 0) {
            ++row_split_point;
            --row_remaining;
          }
          matmul_slice.row_e_ = NNACL_MIN(row_split_point, matmul->compute_.row_);
          matmul->matmul_slice_set_[row_thread][matmul->matmul_slice_count_[row_thread]++] = matmul_slice;
          row_thread++;
        }
      } else {
        for (int r = 0; r < row_num - 1; r++) {
          matmul_slice.row_s_ = row_split_point;
          row_split_point += row_step_cut2;
          if (row_remaining_cut2 > 0) {
            ++row_split_point;
            --row_remaining_cut2;
          }
          matmul_slice.row_e_ = NNACL_MIN(row_split_point, matmul->compute_.row_);
          matmul->matmul_slice_set_[row_thread][matmul->matmul_slice_count_[row_thread]++] = matmul_slice;
          row_thread++;
        }
      }
    }
  }
  // Neither batch nor column work was distributed: only the row_thread slice
  // owners have anything to do.
  if ((matmul->compute_.batch_stride_ == 0) && (matmul->compute_.block_col_unit_ == 0)) {
    matmul->base_.thread_nr_ = row_thread;
  }
}
177 
/* Choose the thread-cutting strategy and the matching parallel runner.
 * Shallow matrices (deep_ < 128) fall back to the base policy; otherwise a
 * GEPDOT / GEPM / batch-col-row GEMM path is picked by output shape. */
void MatmulAVX512GetThreadCuttingPolicy(MatmulStruct *matmul) {
  MatmulComputeParam *compute = &matmul->compute_;
  size_t total_cost =
    (size_t)(matmul->batch_) * (size_t)(compute->row_) * (size_t)(compute->col_) * (size_t)(compute->deep_);

  /* Thread update: never use more threads than the work can amortize. */
  matmul->base_.thread_nr_ = MSMAX(NNACL_MIN((int)(total_cost / MIN_CALC_COST), matmul->base_.thread_nr_), C1NUM);

  if (compute->deep_ < C128NUM) {
    MatmulBaseGetThreadCuttingPolicy(matmul);
    return;
  }

  for (int i = 0; i < SPLIT_COUNT; i++) {
    matmul->matmul_slice_count_[i] = 0;
  }

  if (compute->col_ == 1 && !matmul->a_const_) {
    /* Single output column: GEPDOT with the unpacked GEMV kernels. */
    MatmulAVX512BatchRowThreadCut(matmul);
    matmul->gemm_not_pack_fun_ = (compute->deep_ == 1) ? GemmIsNotPack : GemmIsNotPackOptimize;
    matmul->parallel_run_ = matmul->parallel_run_by_gepdot_;
    return;
  }

  if (compute->row_ == 1 && !matmul->b_const_ && compute->col_ <= C128NUM) {
    /* Single output row with a small, non-constant B: GEPM family. */
    MatmulAVX512BatchColThreadCut(matmul);
    if (compute->deep_ == 1) {
      matmul->parallel_run_ = matmul->parallel_run_by_row1_deep1_gepdot_;
      matmul->gemm_not_pack_fun_ =
        (matmul->matrix_c_.pack_ptr_ != NULL) ? Row1Deep1GemmIsNotPack : Row1Deep1NoBiasGemmIsNotPack;
      return;
    }
    matmul->parallel_run_ = matmul->parallel_run_by_gepm_;
    return;
  }

  /* General deep case: batch + column + row sliced GEMM. */
  MatmulAVX512BatchColRowSliceThreadCut(matmul);
  matmul->parallel_run_ = matmul->parallel_run_by_batch_col_row_gemm_;
}
218 
MatmulAVX512CheckThreadCuttingByRow(MatmulStruct * matmul)219 bool MatmulAVX512CheckThreadCuttingByRow(MatmulStruct *matmul) {
220   if (matmul->b_batch_ != C1NUM) {
221     return false;
222   }
223   if (matmul->compute_.row_num_ < matmul->base_.thread_nr_) {
224     return false;
225   }
226   if (matmul->compute_.col_ == 1) {
227     matmul->compute_.row_min_unit_ = C8NUM;
228     return true;
229   }
230   if (matmul->compute_.row_ == 1 && !matmul->b_const_ && matmul->compute_.col_ <= C128NUM) {
231     return false;
232   }
233   matmul->compute_.row_min_unit_ = C6NUM;
234   if (matmul->compute_.col_step_ < C48NUM) {
235     matmul->compute_.row_min_unit_ = C12NUM;
236   } else if (matmul->compute_.col_step_ < C64NUM) {
237     matmul->compute_.row_min_unit_ = C8NUM;
238   }
239   return NNACL_MIN(matmul->compute_.row_num_ / matmul->compute_.row_min_unit_, matmul->base_.thread_nr_) >
240          NNACL_MIN(matmul->compute_.col_step_ / matmul->compute_.col_min_unit_, matmul->base_.thread_nr_);
241 }
/* Set the AVX-512 packing functions, tile sizes and the output-alignment flag
 * from the transpose flags and matrix shape. */
void MatmulAVX512InitGlobalVariable(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;

  matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
  matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col64MajorParallel : RowMajor2Row64MajorParallel;
  matmul->matrix_a_.need_pack_ = param->a_transpose_;
  matmul->matrix_b_.need_pack_ = true;
  compute->row_tile_ = C1NUM;
  compute->col_tile_ = C16NUM;
  compute->col_min_unit_ = C64NUM;

  /* NOTE: when row_ == 1 but the inner condition fails, the flag is
   * deliberately left at its previous value. */
  if (compute->row_ == 1) {
    if (!matmul->b_const_ && compute->col_ <= C128NUM) {
      matmul->out_need_aligned_ = true;
    }
  } else if (compute->col_ == 1) {
    matmul->out_need_aligned_ = true;
  } else {
    matmul->out_need_aligned_ = false;
  }

  /* Deep inputs (>= 128) always use unaligned output. */
  if (compute->deep_ >= C128NUM) {
    matmul->out_need_aligned_ = false;
  }
}
/* Initialize pack sizes, alignment and derived compute parameters for the
 * AVX-512 kernels. Shallow matrices (deep_ < 128) defer to the base
 * implementation. For the two GEMV-like shapes (col_ == 1 with variable A, or
 * row_ == 1 with a small variable B) packing is mostly disabled and 1x1 tiles
 * are used.
 *
 * Returns NNACL_OK on success, NNACL_ERR on overflow or a pack-size mismatch
 * against already-packed constant matrices.
 *
 * Fix: the original recomputed row_align_ a second time after setting the
 * pack sizes; the value is identical (nothing in between touches row_ or
 * row_tile_), so the dead recomputation is removed. */
int MatmulAVX512InitParameter(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;

  if (compute->deep_ < C128NUM) {
    return MatmulBaseInitParameter(matmul);
  }

  matmul->init_global_varibale_(matmul);
  if (compute->col_ == 1 && !matmul->a_const_) {
    /* GEPDOT shape: no B packing, A packed only when it must be transposed. */
    matmul->out_need_aligned_ = false;
    compute->row_tile_ = 1;
    compute->col_tile_ = 1;
    matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
    matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
    matmul->matrix_a_.need_pack_ = param->a_transpose_ && compute->row_ != 1;
    matmul->matrix_b_.need_pack_ = false;
    matmul->pack_opt_ = false;
  } else if (compute->row_ == 1 && !matmul->b_const_ && compute->col_ <= C128NUM) {
    /* GEPM shape: no A packing, B packed only when it must be transposed. */
    matmul->out_need_aligned_ = false;
    compute->row_tile_ = 1;
    compute->col_tile_ = 1;
    matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
    matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
    matmul->matrix_a_.need_pack_ = false;
    matmul->matrix_b_.need_pack_ = param->b_transpose_;
    matmul->pack_opt_ = false;
  }

  compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_);
  compute->col_align_ = UP_ROUND(compute->col_, compute->col_tile_);

  /* Guard the pack-size products against int overflow before multiplying. */
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->row_align_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->row_align_, compute->deep_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->col_align_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->col_align_, compute->deep_, NNACL_ERR);

  int a_pack_size = matmul->a_batch_ * compute->row_align_ * compute->deep_;
  int b_pack_size = matmul->b_batch_ * compute->col_align_ * compute->deep_;
  /* A matrix packed under a previous configuration must match the new size. */
  if ((matmul->matrix_a_.has_packed_ && matmul->matrix_a_.pack_size_ != a_pack_size) ||
      (matmul->matrix_b_.has_packed_ && matmul->matrix_b_.pack_size_ != b_pack_size)) {
    return NNACL_ERR;
  }
  matmul->matrix_a_.pack_size_ = a_pack_size;
  matmul->matrix_b_.pack_size_ = b_pack_size;

  /* Alignment only matters when the column count is not already tile-sized. */
  matmul->out_need_aligned_ = (matmul->out_need_aligned_ && ((compute->col_ % compute->col_tile_) != 0));
  compute->col_step_ = matmul->out_need_aligned_ ? compute->col_align_ : compute->col_;
  NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(matmul->a_batch_, compute->row_), NNACL_ERR);
  compute->row_num_ = matmul->a_batch_ * compute->row_;
  return NNACL_OK;
}
315 
/* Row-sliced parallel runner: thread task_id computes output rows
 * [split_points_[task_id], split_points_[task_id + 1]) (the last thread runs
 * to row_num_). */
int MatmulAVX512ParallelRunByRow(MatmulStruct *matmul, int task_id) {
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;

  int row_begin = matmul->split_points_[task_id];
  int row_end = (task_id < matmul->base_.thread_nr_ - 1) ? matmul->split_points_[task_id + 1] : compute->row_num_;
  int rows = row_end - row_begin;
  if (rows <= 0) {
    return NNACL_OK;
  }

  const float *a = matmul->matrix_a_.pack_ptr_ + row_begin * compute->deep_;
  float *out = matmul->output_data_ + row_begin * compute->col_step_;

  if (compute->col_ == 1) {
    /* Single output column: unpacked GEMV kernel with a scalar bias. */
    float bias = (matmul->matrix_c_.pack_ptr_ != NULL) ? matmul->matrix_c_.pack_ptr_[0] : 0.0f;
    matmul->gemm_not_pack_fun_(a, matmul->matrix_b_.pack_ptr_, out, &bias, rows, compute->deep_, param->act_type_);
  } else if (matmul->out_need_aligned_) {
    MatMulAvx512Fp32(a, matmul->matrix_b_.pack_ptr_, out, matmul->matrix_c_.pack_ptr_, param->act_type_,
                     compute->deep_, compute->col_align_, compute->col_align_, rows);
  } else {
    MatMulMaskAvx512Fp32(a, matmul->matrix_b_.pack_ptr_, out, matmul->matrix_c_.pack_ptr_, param->act_type_,
                         compute->deep_, compute->col_, compute->col_, rows);
  }
  return NNACL_OK;
}
349 
/* Output-channel (column) sliced parallel runner: thread task_id computes
 * output columns [split_points_[task_id], next split point) of every batch. */
int MatmulAVX512ParallelRunByOC(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatmulComputeParam *compute = &matmul->compute_;
  ActType act = param->act_type_;

  // Column range owned by this thread; the last thread runs to col_step_.
  int start_oc = matmul->split_points_[task_id];
  int end_oc = compute->col_step_;
  if (task_id < (matmul->base_.thread_nr_ - 1)) {
    end_oc = matmul->split_points_[task_id + 1];
  }
  int compute_oc = end_oc - start_oc;
  if (compute_oc <= 0) {
    return NNACL_OK;
  }
  // Kernel selector: 0 = GEMM; C1NUM = packed GEMV (row_ == 1);
  // C2NUM = unpacked GEMV (row_ == 1, non-const small B).
  int func_flag = 0;
  if (matmul->compute_.row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }
  // Unpacked B is addressed by column directly; packed B is tiled so the
  // column offset scales by deep_.
  int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_;
  for (int i = 0; i < matmul->batch_; ++i) {
    float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_;
    float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride;
    float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc;
    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;

    if (func_flag == 0) {
      if (matmul->out_need_aligned_) {
        MatMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_align_, compute->row_);
      } else {
        MatMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_, compute->row_);
      }
    } else if (func_flag == C1NUM) {
      if (matmul->out_need_aligned_) {
        MatVecMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_align_);
      } else {
        MatVecMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_);
      }
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_);
    }
  }

  return NNACL_OK;
}
395 
/* Batch-sliced parallel runner: thread task_id processes batch_stride_
 * consecutive batches, each as a full row x col GEMM/GEMV. */
int MatmulAVX512ParallelRunByBatch(MatmulStruct *matmul, int task_id) {
  MatmulComputeParam *compute = &matmul->compute_;
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  ActType act = param->act_type_;

  int first = task_id * compute->batch_stride_;
  int last = NNACL_MIN(matmul->batch_, first + compute->batch_stride_);

  /* Kernel selector: 0 = GEMM; C1NUM = packed GEMV (row_ == 1);
   * C2NUM = unpacked GEMV (row_ == 1, non-const small B). */
  int func_flag = 0;
  if (compute->row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }

  float *bias = matmul->matrix_c_.pack_ptr_; /* NULL means no bias */
  for (int idx = first; idx < last; ++idx) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[idx] * compute->row_align_ * compute->deep_;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[idx] * compute->deep_ * compute->col_align_;
    float *c = matmul->output_data_ + idx * compute->row_ * compute->col_step_;

    if (func_flag == C2NUM) {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_);
    } else if (func_flag == C1NUM) {
      if (matmul->out_need_aligned_) {
        MatVecMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_);
      } else {
        MatVecMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_);
      }
    } else {
      if (matmul->out_need_aligned_) {
        MatMulAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_, compute->row_);
      } else {
        MatMulMaskAvx512Fp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_, compute->row_);
      }
    }
  }
  return NNACL_OK;
}
432 
/* Parallel runner for the GEPM path (row_ == 1, unpacked B): each thread
 * first handles batch_stride_ whole batches with the unpacked GEMV kernel,
 * then the batches left over after the batch cut are shared between threads
 * by column ranges from col_split_points_. */
int MatmulAVX512ParallelRunByGEPM(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);

  int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_;
  int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_;
  int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_;
  int matrix_col = matmul->compute_.col_step_;
  int matrix_deep = matmul->compute_.deep_;

  // by BatchCut
  int start_batch = task_id * matmul->compute_.batch_stride_;
  int end_batch = NNACL_MIN(matmul->batch_, start_batch + matmul->compute_.batch_stride_);

  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size;
    float *c = matmul->output_data_ + index * c_plane_size;

    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_;
    MatVecMulNoPackFp32(a, b, c, bias, param->act_type_, matrix_deep, matrix_col, matrix_col);
  }

  // by ColCut
  int col_split_points_size = matmul->col_split_points_size_;
  if (task_id < col_split_points_size) {
    // Column range owned by this thread; the last split runs to matrix_col.
    int start_oc = matmul->col_split_points_[task_id];
    int end_oc = matrix_col;
    if (task_id < (col_split_points_size - 1)) {
      end_oc = matmul->col_split_points_[task_id + 1];
    }
    int compute_oc = end_oc - start_oc;
    if (compute_oc <= 0) {
      return NNACL_OK;
    }

    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;
    // Batches not covered by the batch cut start at thread_nr_ * batch_stride_.
    for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) {
      float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size;
      float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc;
      float *c = matmul->output_data_ + i * c_plane_size + start_oc;
      MatVecMulNoPackFp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col);
    }
  }
  return NNACL_OK;
}
/* Parallel runner for the GEMM path with a three-stage split:
 *   1. BatchCut - batch_stride_ whole batches per thread;
 *   2. ColCut   - leftover batches, shared by column ranges;
 *   3. RowCut   - the tail column range, shared by row ranges. */
int MatmulAVX512ParallelRunByGEMM(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);

  int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_;
  int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_;
  int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_;
  int matrix_row = matmul->compute_.row_;
  int matrix_col = matmul->compute_.col_step_;
  int matrix_deep = matmul->compute_.deep_;

  // by BatchCut
  int start_batch = task_id * matmul->compute_.batch_stride_;
  int end_batch = start_batch + matmul->compute_.batch_stride_;
  float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_;
  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size;
    float *c = matmul->output_data_ + index * c_plane_size;
    MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, matrix_col, matrix_col, matrix_row);
  }

  // by ColCut: leftover batches, one column range per thread.
  int col_split_points_size = matmul->col_split_points_size_;
  if (task_id < col_split_points_size) {
    int start_oc = matmul->col_split_points_[task_id];
    int end_oc = matmul->col_split_points_[task_id + 1];
    int compute_oc = end_oc - start_oc;

    bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;
    if (compute_oc > 0) {
      for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) {
        float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size;
        float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc * matrix_deep;
        float *c = matmul->output_data_ + i * c_plane_size + start_oc;
        MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col, matrix_row);
      }
    }
  }

  // by RowCut: the columns past the last ColCut point, split by row.
  // NOTE(review): assumes col_split_points_[col_split_points_size] holds the
  // end of the column cut (sentinel set by the matching thread-cut policy) --
  // confirm against the policy that pairs with this runner.
  int start_oc = matmul->col_split_points_[col_split_points_size];
  int end_oc = matrix_col;
  int compute_oc = end_oc - start_oc;
  if (compute_oc <= 0) {
    return NNACL_OK;
  }

  int row_split_points_size = matmul->row_split_points_size_;
  if (task_id >= row_split_points_size) {
    return NNACL_OK;
  }
  int start_row = matmul->row_split_points_[task_id];
  int end_row = matmul->row_split_points_[task_id + 1];
  int row_num = end_row - start_row;
  if (row_num <= 0) {
    return NNACL_OK;
  }

  bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;
  for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) {
    float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size + start_row * matrix_deep;
    float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc * matrix_deep;
    float *c = matmul->output_data_ + i * c_plane_size + start_row * matrix_col + start_oc;
    MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col, row_num);
  }

  return NNACL_OK;
}
548 
/* Parallel runner for the GEPDOT path (col_ == 1, unpacked kernels): each
 * thread first handles batch_stride_ whole batches, then the batches left
 * over after the batch cut are shared between threads by row ranges from
 * row_split_points_.
 *
 * Returns NNACL_OK on success, NNACL_ERR on an invalid task_id.
 *
 * Fix: the row-split bounds are loop-invariant, so they are now computed once
 * before the leftover-batch loop instead of on every iteration. */
int MatmulAVX512ParallelRunByGEPDOT(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatmulComputeParam *compute = &matmul->compute_;

  /* gemm_not_pack_fun_ takes a single scalar bias value. */
  float bias = 0;
  if (matmul->matrix_c_.pack_ptr_ != NULL) {
    bias = matmul->matrix_c_.pack_ptr_[0];
  }
  int a_stride = compute->row_ * compute->deep_;
  int b_stride = compute->deep_ * compute->col_;

  /* by BatchCut: whole batches owned by this thread. */
  int start_batch = task_id * compute->batch_stride_;
  int end_batch = start_batch + compute->batch_stride_;
  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_stride;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_stride;
    float *c = matmul->output_data_ + index * compute->row_ * compute->col_;
    matmul->gemm_not_pack_fun_(a, b, c, &bias, compute->row_, compute->deep_, param->act_type_);
  }

  /* by RowCut: leftover batches, this thread's row range (hoisted out of the
   * loop - the split points do not change per batch). */
  if (task_id >= matmul->row_split_points_size_) {
    return NNACL_OK;
  }
  int start_row = matmul->row_split_points_[task_id];
  int end_row = matmul->row_split_points_[task_id + 1];
  int row_num = end_row - start_row;
  if (row_num <= 0) {
    return NNACL_OK;
  }
  /* Batches not covered by the batch cut start at thread_nr_ * batch_stride_. */
  for (int index = matmul->base_.thread_nr_ * compute->batch_stride_; index < matmul->batch_; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_stride + start_row * compute->deep_;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_stride;
    float *c = matmul->output_data_ + index * compute->row_ * compute->col_ + start_row * compute->col_step_;
    matmul->gemm_not_pack_fun_(a, b, c, &bias, row_num, compute->deep_, param->act_type_);
  }

  return NNACL_OK;
}
591 
/* Parallel runner for the row_ == 1, deep_ == 1 GEPDOT special case: each
 * thread first handles batch_stride_ whole batches, then the batches left
 * over after the batch cut are shared between threads by column ranges from
 * col_split_points_. gemm_not_pack_fun_ here is one of the Row1Deep1 kernels
 * selected by the cutting policy. */
int MatmulAVX512ParallelRunByRow1Deep1GEPDOT(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);

  int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_;
  int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_;
  int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_;
  int matrix_col = matmul->compute_.col_step_;
  int matrix_deep = matmul->compute_.deep_;

  // by BatchCut
  int start_batch = task_id * matmul->compute_.batch_stride_;
  int end_batch = NNACL_MIN(matmul->batch_, start_batch + matmul->compute_.batch_stride_);

  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size;
    float *c = matmul->output_data_ + index * c_plane_size;
    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_;
    matmul->gemm_not_pack_fun_(a, b, c, bias, matrix_col, matrix_deep, param->act_type_);
  }

  // by ColCut
  int col_split_points_size = matmul->col_split_points_size_;
  if (task_id < col_split_points_size) {
    // Column range owned by this thread; the last split runs to matrix_col.
    int start_oc = matmul->col_split_points_[task_id];
    int end_oc = matrix_col;
    if (task_id < (col_split_points_size - 1)) {
      end_oc = matmul->col_split_points_[task_id + 1];
    }
    int compute_oc = end_oc - start_oc;
    if (compute_oc <= 0) {
      return NNACL_OK;
    }

    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;
    // Batches not covered by the batch cut start at thread_nr_ * batch_stride_.
    for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) {
      float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size;
      float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc;
      float *c = matmul->output_data_ + i * c_plane_size + start_oc;
      matmul->gemm_not_pack_fun_(a, b, c, bias, compute_oc, matrix_deep, param->act_type_);
    }
  }
  return NNACL_OK;
}
637 
/* Parallel runner for the combined batch/col/row sliced GEMM path: each
 * thread first handles batch_stride_ whole batches, then runs its
 * precomputed (row, col) slices from matmul_slice_set_ over the batches
 * left over after the batch cut. */
int MatmulAVX512ParallelRunByBatchColRowGEMM(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);

  int a_plane_size = matmul->compute_.row_align_ * matmul->compute_.deep_;
  int b_plane_size = matmul->compute_.deep_ * matmul->compute_.col_align_;
  int c_plane_size = matmul->compute_.row_ * matmul->compute_.col_step_;
  int matrix_row = matmul->compute_.row_;
  int matrix_col = matmul->compute_.col_step_;
  int matrix_deep = matmul->compute_.deep_;

  // by BatchCut
  int start_batch = task_id * matmul->compute_.batch_stride_;
  int end_batch = start_batch + matmul->compute_.batch_stride_;
  float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_;
  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * a_plane_size;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * b_plane_size;
    float *c = matmul->output_data_ + index * c_plane_size;
    MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, matrix_col, matrix_col, matrix_row);
  }

  // Slices assigned to this thread by the thread-cut policy.
  MatmulSlice *matmul_slices = matmul->matmul_slice_set_[task_id];
  int slice_count = matmul->matmul_slice_count_[task_id];
  for (int s = 0; s < slice_count; s++) {
    MatmulSlice matmul_slice = matmul_slices[s];

    int start_oc = matmul_slice.col_s_;
    int end_oc = matmul_slice.col_e_;
    int compute_oc = end_oc - start_oc;
    // NOTE(review): an empty slice ends processing of ALL remaining slices
    // (return, not continue) -- presumably slices are ordered so later ones
    // are empty too; confirm against the slice-building policy.
    if (compute_oc <= 0) {
      return NNACL_OK;
    }

    int start_row = matmul_slice.row_s_;
    int end_row = matmul_slice.row_e_;
    int row_num = end_row - start_row;
    if (row_num <= 0) {
      return NNACL_OK;
    }

    bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;
    // Batches not covered by the batch cut start at thread_nr_ * batch_stride_.
    for (int i = matmul->base_.thread_nr_ * matmul->compute_.batch_stride_; i < matmul->batch_; ++i) {
      float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * a_plane_size + start_row * matrix_deep;
      float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * b_plane_size + start_oc * matrix_deep;
      float *c = matmul->output_data_ + i * c_plane_size + start_row * matrix_col + start_oc;
      MatMulMaskAvx512Fp32(a, b, c, bias, param->act_type_, matrix_deep, compute_oc, matrix_col, row_num);
    }
  }
  return NNACL_OK;
}
689 
CreateMatmulAVX512()690 KernelBase *CreateMatmulAVX512() {
691   MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase();
692   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul);
693   matmul->matmul_type_ = kNotImplemented;
694   matmul->check_thread_cutting_by_row_ = MatmulAVX512CheckThreadCuttingByRow;
695   matmul->get_thread_cutting_policy_ = MatmulAVX512GetThreadCuttingPolicy;
696   matmul->init_parameter_ = MatmulAVX512InitParameter;
697   matmul->init_global_varibale_ = MatmulAVX512InitGlobalVariable;
698   matmul->parallel_run_by_oc_ = MatmulAVX512ParallelRunByOC;
699   matmul->parallel_run_by_row_ = MatmulAVX512ParallelRunByRow;
700   matmul->parallel_run_by_batch_ = MatmulAVX512ParallelRunByBatch;
701   matmul->parallel_run_by_gemm_ = MatmulAVX512ParallelRunByGEMM;
702   matmul->parallel_run_by_gepm_ = MatmulAVX512ParallelRunByGEPM;
703   matmul->parallel_run_by_gepdot_ = MatmulAVX512ParallelRunByGEPDOT;
704   matmul->parallel_run_by_batch_col_row_gemm_ = MatmulAVX512ParallelRunByBatchColRowGEMM;
705   matmul->parallel_run_by_row1_deep1_gepdot_ = MatmulAVX512ParallelRunByRow1Deep1GEPDOT;
706   return (KernelBase *)matmul;
707 }
708 #endif
709