/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/kernel/matmul_base.h"
#include "nnacl/fp32/pack_fp32.h"
#include "nnacl/fp32/matmul_fp32.h"
#include "nnacl/tensor_c_utils.h"
#include "nnacl/op_base.h"

#define kNumDeepThreshold 512

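/* Thread-pool callback: unpacks the kernel from cdata and forwards to the parallel_run_
 * routine selected by the thread-cutting policy; the float parameters are unused here. */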
int MatmulFp32Run(void *cdata, int task_id, float l, float r) {
  NNACL_CHECK_NULL_RETURN_ERR(cdata);
  MatmulStruct *matmul = (MatmulStruct *)cdata;
  return matmul->parallel_run_(matmul, task_id);
}

void MatmulBaseFreeBatchOffset(MatmulStruct *matmul) {
  if (matmul->a_offset_ != NULL) {
    free(matmul->a_offset_);
    matmul->a_offset_ = NULL;
  }
  if (matmul->b_offset_ != NULL) {
    free(matmul->b_offset_);
    matmul->b_offset_ = NULL;
  }
}

int MatmulBaseMallocBatchOffset(MatmulStruct *matmul) {
  matmul->a_offset_ = malloc(matmul->batch_ * sizeof(int));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->a_offset_);
  memset(matmul->a_offset_, 0, matmul->batch_ * sizeof(int));

  matmul->b_offset_ = malloc(matmul->batch_ * sizeof(int));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->b_offset_);
  memset(matmul->b_offset_, 0, matmul->batch_ * sizeof(int));
  return NNACL_OK;
}

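/* Packs one stripe of matrix B per task: the work is split along the column dimension
 * when B is transposed and along the deep (reduction) dimension otherwise, using the
 * pack_b_stride_ computed in MatmulBasePackMatrixBImpl. */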
int MatmulBasePackMatrixBParallelRunByBatch(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;

  int start = task_id * compute->pack_b_stride_;
  if (param->b_transpose_) {
    int end = NNACL_MIN(matmul->compute_.col_, start + compute->pack_b_stride_);
    matmul->matrix_b_pack_fun_(matmul->pack_b_src_, matmul->pack_b_dst_, compute->col_, compute->deep_, start, end);
  } else {
    int end = NNACL_MIN(matmul->compute_.deep_, start + compute->pack_b_stride_);
    matmul->matrix_b_pack_fun_(matmul->pack_b_src_, matmul->pack_b_dst_, compute->deep_, compute->col_, start, end);
  }
  return NNACL_OK;
}

int MatmulFp32PackMatrixBRun(void *cdata, int task_id, float l, float r) {
  NNACL_CHECK_NULL_RETURN_ERR(cdata);
  MatmulStruct *matmul = (MatmulStruct *)cdata;
  return MatmulBasePackMatrixBParallelRunByBatch(matmul, task_id);
}

bool MatmulBaseCheckRowOptimalConditions(MatmulStruct *matmul) {
  return matmul->compute_.row_ == 1 &&
         !(matmul->support_mul_batch_cut_by_row_ && (matmul->a_batch_ > 1 && matmul->b_batch_ == 1));
}

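/* Derives tiling and packing requirements from the current shapes. The single-row and
 * single-column fast paths drop the tile sizes to 1 and avoid packing where possible;
 * the aligned pack sizes computed here must match any matrix that was already packed,
 * otherwise NNACL_ERR is returned. */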
int MatmulBaseInitParameter(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;

  matmul->init_global_varibale_(matmul);
  if (MatmulBaseCheckRowOptimalConditions(matmul)) {
    compute->row_tile_ = 1;
    matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
    matmul->matrix_a_.need_pack_ = false;
    matmul->pack_opt_ = false;
    if (!matmul->b_const_ && compute->col_ <= C128NUM) {
      compute->col_tile_ = 1;
      matmul->out_need_aligned_ = false;
      matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
      matmul->matrix_b_.need_pack_ = param->b_transpose_;
    }
  }
  if (compute->col_ == 1 && !matmul->a_const_) {
    matmul->out_need_aligned_ = false;
    compute->row_tile_ = 1;
    compute->col_tile_ = 1;
    matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
    matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
    matmul->matrix_a_.need_pack_ = param->a_transpose_ && compute->row_ != 1;
    matmul->matrix_b_.need_pack_ = false;
    matmul->pack_opt_ = false;
  }
  compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_);
  compute->col_align_ = UP_ROUND(compute->col_, compute->col_tile_);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->row_align_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->row_align_, compute->deep_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_, compute->col_align_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matmul->a_batch_ * compute->col_align_, compute->deep_, NNACL_ERR);
  int a_pack_size = matmul->a_batch_ * compute->row_align_ * compute->deep_;
  int b_pack_size = matmul->b_batch_ * compute->col_align_ * compute->deep_;
  if ((matmul->matrix_a_.has_packed_ && matmul->matrix_a_.pack_size_ != a_pack_size) ||
      (matmul->matrix_b_.has_packed_ && matmul->matrix_b_.pack_size_ != b_pack_size)) {
    return NNACL_ERR;
  }
  matmul->matrix_a_.pack_size_ = a_pack_size;
  matmul->matrix_b_.pack_size_ = b_pack_size;
  compute->row_align_ = UP_ROUND(compute->row_, compute->row_tile_);
  matmul->out_need_aligned_ = (matmul->out_need_aligned_ && ((compute->col_ % compute->col_tile_) != 0));
  compute->col_step_ = matmul->out_need_aligned_ ? compute->col_align_ : compute->col_;
  NNACL_CHECK_FALSE(INT_MUL_OVERFLOW(matmul->a_batch_, compute->row_), NNACL_ERR);
  compute->row_num_ = matmul->a_batch_ * compute->row_;
  return NNACL_OK;
}

int MatmulBasePackMatrixAImplOpt(MatmulStruct *matmul) { return NNACL_ERR; }

int MatmulBasePackMatrixAImpl(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  float *src_ptr = (matmul->matrix_a_.origin_ptr_ != NULL) ? (matmul->matrix_a_.origin_ptr_)
                                                           : (float *)(matmul->base_.in_[FIRST_INPUT]->data_);
  NNACL_CHECK_TRUE_RET(src_ptr != NULL, NNACL_ERR);
  NNACL_CHECK_TRUE_RET(matmul->matrix_a_.pack_ptr_ != NULL, NNACL_ERR);
  NNACL_CHECK_TRUE_RET(matmul->matrix_a_pack_fun_ != NULL, NNACL_ERR);
  for (int i = 0; i < matmul->a_batch_; i++) {
    const float *src = src_ptr + i * matmul->compute_.deep_ * matmul->compute_.row_;
    float *dst = matmul->matrix_a_.pack_ptr_ + i * matmul->compute_.deep_ * matmul->compute_.row_align_;
    if (param->a_transpose_) {
      matmul->matrix_a_pack_fun_(src, dst, matmul->compute_.deep_, matmul->compute_.row_, 0, matmul->compute_.deep_);
    } else {
      matmul->matrix_a_pack_fun_(src, dst, matmul->compute_.row_, matmul->compute_.deep_, 0, matmul->compute_.row_);
    }
  }
  return NNACL_OK;
}

int MatmulBasePackMatrixBImpl(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);

  float *src_ptr = matmul->matrix_b_.origin_ptr_ != NULL ? matmul->matrix_b_.origin_ptr_
                                                         : (float *)matmul->base_.in_[SECOND_INPUT]->data_;
  NNACL_CHECK_TRUE_RET(src_ptr != NULL, NNACL_ERR);
  NNACL_CHECK_TRUE_RET(matmul->matrix_b_.pack_ptr_ != NULL, NNACL_ERR);
  NNACL_CHECK_TRUE_RET(matmul->matrix_b_pack_fun_ != NULL, NNACL_ERR);

  for (int i = 0; i < matmul->b_batch_; i++) {
    if (param->b_transpose_) {
      matmul->compute_.pack_b_stride_ = UP_DIV(matmul->compute_.col_, matmul->base_.thread_nr_);
    } else {
      matmul->compute_.pack_b_stride_ = UP_DIV(matmul->compute_.deep_, matmul->base_.thread_nr_);
    }
    matmul->pack_b_src_ = src_ptr + i * matmul->compute_.deep_ * matmul->compute_.col_;
    matmul->pack_b_dst_ = matmul->matrix_b_.pack_ptr_ + i * matmul->compute_.deep_ * matmul->compute_.col_align_;
    int ret = matmul->base_.env_->ParallelLaunch(matmul->base_.env_->thread_pool_, MatmulFp32PackMatrixBRun, matmul,
                                                 matmul->base_.thread_nr_);
    NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
  }
  return NNACL_OK;
}

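/* Matrix A/B packing entry points: non-const inputs are packed into the train-session
 * workspace or a freshly allocated buffer on every run, while const inputs are packed
 * once, optionally through the weight-sharing manager, which may return an already
 * packed copy. */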
int MatmulBasePackMatrixA(MatmulStruct *matmul) {
  if (!matmul->a_const_) {
    if (!matmul->matrix_a_.need_pack_) {
      matmul->matrix_a_.pack_ptr_ = (float *)matmul->base_.in_[0]->data_;
      return NNACL_OK;
    }
    if (matmul->base_.train_session_) {
      matmul->matrix_a_.pack_ptr_ = (float *)(matmul->base_.workspace_);
    } else {
      matmul->matrix_a_.pack_ptr_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_,
                                                                        matmul->matrix_a_.pack_size_ * sizeof(float)));
    }
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->matrix_a_.pack_ptr_);
  } else {
    bool is_packed = false;
    void *data = NULL;
    size_t data_size = (size_t)(matmul->matrix_a_.pack_size_) * sizeof(float);
    if (matmul->is_sharing_pack_) {
      TensorC *a_matrix = matmul->base_.in_[FIRST_INPUT];
      data = matmul->get_sharing_weight_(matmul->shaing_manager_, a_matrix->data_, data_size, &is_packed);
    } else {
      data = malloc(data_size);
    }
    matmul->matrix_a_.pack_ptr_ = (float *)data;
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->matrix_a_.pack_ptr_);
    if (is_packed) {
      return NNACL_OK;
    }
  }
  if (matmul->pack_opt_) {
    /* valid in arm64 */
    return matmul->pack_matrix_a_impl_opt_(matmul);
  }
  return matmul->pack_matrix_a_impl_(matmul);
}

int MatmulBasePackMatrixB(MatmulStruct *matmul) {
  if (!matmul->b_const_) {
    if (!matmul->matrix_b_.need_pack_) {
      matmul->matrix_b_.pack_ptr_ = (float *)matmul->base_.in_[SECOND_INPUT]->data_;
      return NNACL_OK;
    }
    if (matmul->base_.train_session_) {
      matmul->matrix_b_.pack_ptr_ = (float *)(matmul->base_.workspace_) + matmul->matrix_a_.pack_size_;
    } else {
      matmul->matrix_b_.pack_ptr_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_,
                                                                        matmul->matrix_b_.pack_size_ * sizeof(float)));
    }
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matmul->matrix_b_.pack_ptr_);
  } else {
    if (!matmul->matrix_b_.need_pack_ && matmul->weight_is_packed_) {
      matmul->matrix_b_.pack_ptr_ = (float *)matmul->base_.in_[SECOND_INPUT]->data_;
      return NNACL_OK;
    }
    bool is_packed = false;
    void *data = NULL;
    size_t data_size = (size_t)(matmul->matrix_b_.pack_size_) * sizeof(float);
    if (matmul->is_sharing_pack_) {
      TensorC *b_matrix = matmul->base_.in_[SECOND_INPUT];
      data = matmul->get_sharing_weight_(matmul->shaing_manager_, b_matrix->data_, data_size, &is_packed);
    } else {
      data = malloc(data_size);
    }
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(data);
    matmul->matrix_b_.pack_ptr_ = (float *)data;
    if (is_packed) {
      return NNACL_OK;
    }
  }
  return matmul->pack_matrix_b_impl_(matmul);
}

int MatmulBaseBackupConstMatrix(MatmulStruct *matmul, MatrixInfo *matrix_info, int index) {
  NNACL_CHECK_TRUE_RET(index < (int)matmul->base_.in_size_, NNACL_ERR);
  size_t backup_size = (size_t)GetElementNum(matmul->base_.in_[index]) * sizeof(float);
  NNACL_CHECK_TRUE_RET(backup_size > 0, NNACL_ERR);
  matrix_info->origin_ptr_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_, backup_size));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matrix_info->origin_ptr_);
  void *src_ptr = matmul->base_.in_[index]->data_;
  NNACL_CHECK_NULL_RETURN_ERR(src_ptr);
  (void)memcpy(matrix_info->origin_ptr_, src_ptr, backup_size);
  matrix_info->origin_need_free_ = true;
  return NNACL_OK;
}

int MatmulBaseParallelRunByRow(MatmulStruct *matmul, int task_id) { return NNACL_ERR; }

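/* Batch-parallel kernel: each task handles a contiguous range of batches and selects the
 * compute routine via func_flag (0: packed GEMM MatMulOpt, 1: single-row times packed B,
 * 2: single-row times unpacked B when B is not const and col_ is small). */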
int MatmulBaseParallelRunByBatch(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;

  int start_batch = task_id * compute->batch_stride_;
  int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_);
  int func_flag = 0;
  if (compute->row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }

  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_;
    float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_;

    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_;
    if (func_flag == 0) {
      MatMulOpt(a, b, c, bias, param->act_type_, compute->deep_, compute->row_, compute->col_step_, compute->col_,
                OutType_Nhwc);
    } else if (func_flag == C1NUM) {
      MatVecMulFp32Block8(a, b, c, bias, param->act_type_, compute->deep_, compute->col_step_);
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, param->act_type_, compute->deep_, compute->col_step_, compute->col_step_);
    }
  }
  return NNACL_OK;
}

int MatmulBaseParallelRunIsNotPackByBatch(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  int start_batch = task_id * matmul->compute_.batch_stride_;
  int end_batch = MSMIN(matmul->batch_, start_batch + matmul->compute_.batch_stride_);
  float bias = 0;
  if (matmul->matrix_c_.pack_ptr_ != NULL) {
    bias = matmul->matrix_c_.pack_ptr_[0];
  }
  for (int index = start_batch; index < end_batch; ++index) {
    float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * matmul->compute_.row_ * matmul->compute_.deep_;
    float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * matmul->compute_.deep_ * matmul->compute_.col_;
    float *c = matmul->output_data_ + index * matmul->compute_.row_ * matmul->compute_.col_;
    matmul->gemm_not_pack_fun_(a, b, c, &bias, matmul->compute_.row_, matmul->compute_.deep_, param->act_type_);
  }
  return NNACL_OK;
}

void MatmulBaseGetThreadCuttingInfoByRow(MatmulStruct *matmul) {
  int row_step = NNACL_MAX(matmul->compute_.row_num_ / matmul->base_.thread_nr_, matmul->compute_.row_min_unit_);
  int row_remaining = matmul->compute_.row_num_ - row_step * matmul->base_.thread_nr_;

  int split_point = 0;
  int count = 0;
  while (split_point < matmul->compute_.row_num_) {
    matmul->split_points_[count++] = split_point;
    split_point += row_step;
    if (row_remaining > 0) {
      ++split_point;
      --row_remaining;
    }
  }
  matmul->base_.thread_nr_ = count;
}

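/* Output-channel parallel kernel: split_points_ holds each task's starting output column;
 * every task then walks all batches over its [start_oc, end_oc) slice. */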
int MatmulBaseParallelRunByOC(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatmulComputeParam *compute = &matmul->compute_;
  ActType act = param->act_type_;

  int start_oc = matmul->split_points_[task_id];
  int end_oc = compute->col_step_;
  if (task_id < (matmul->base_.thread_nr_ - 1)) {
    end_oc = matmul->split_points_[task_id + 1];
  }
  int compute_oc = end_oc - start_oc;
  if (compute_oc <= 0) {
    return NNACL_OK;
  }

  int func_flag = 0;
  if (compute->row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }
  int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_;

  for (int i = 0; i < matmul->batch_; ++i) {
    float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_;
    float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride;
    float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc;
    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;

    if (func_flag == 0) {
      MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute_oc, compute->col_, OutType_Nhwc);
    } else if (func_flag == C1NUM) {
      MatVecMulFp32Block8(a, b, c, bias, act, compute->deep_, compute_oc);
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_);
    }
  }
  return NNACL_OK;
}

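/* Chooses the parallelization strategy: cut by batch when there are enough batches
 * (or col_ == 1, with an unpacked GEMM fallback), cut by row when the row-cut check
 * succeeds, and otherwise cut by output-channel blocks of col_min_unit_ columns. */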
void MatmulBaseGetThreadCuttingPolicy(MatmulStruct *matmul) {
  if (matmul->compute_.deep_ < kNumDeepThreshold) {
    if (matmul->model_thread_nr_ != -1) {
      matmul->base_.thread_nr_ = matmul->model_thread_nr_;
    }
  }

  if ((matmul->a_batch_ >= matmul->base_.thread_nr_ &&
       (matmul->b_batch_ == matmul->a_batch_ || !matmul->support_mul_batch_cut_by_row_)) ||
      matmul->compute_.col_ == 1) {
    matmul->compute_.batch_stride_ = UP_DIV(matmul->batch_, matmul->base_.thread_nr_);
    matmul->parallel_run_ = matmul->parallel_run_by_batch_;
    if (matmul->compute_.col_ != 1 || matmul->a_const_) {
      return;
    }

    matmul->parallel_run_ = matmul->parallel_run_not_pack_by_batch_;
    if (matmul->compute_.deep_ == 1) {
      matmul->gemm_not_pack_fun_ = GemmIsNotPack;
    } else {
      matmul->gemm_not_pack_fun_ = GemmIsNotPackOptimize;
      if (matmul->check_thread_cutting_by_row_(matmul)) {
        matmul->parallel_run_ = matmul->parallel_run_by_row_;
        matmul->get_thread_cutting_info_by_row_(matmul);
      }
    }
    return;
  } else if ((matmul->a_batch_ >= matmul->base_.thread_nr_ && matmul->b_batch_ == 1) ||
             matmul->check_thread_cutting_by_row_(matmul)) {
    matmul->parallel_run_ = matmul->parallel_run_by_row_;
    matmul->get_thread_cutting_info_by_row_(matmul);
  } else {
    int total_col_unit = UP_DIV(matmul->compute_.col_align_, matmul->compute_.col_min_unit_);
    matmul->base_.thread_nr_ = MSMIN(matmul->base_.thread_nr_, total_col_unit);
    int block_col_unit = UP_DIV(total_col_unit, matmul->base_.thread_nr_);

    int count = 0;
    int split_point = 0;
    while (split_point < total_col_unit) {
      matmul->split_points_[count++] = (split_point * matmul->compute_.col_min_unit_);
      split_point += block_col_unit;
    }
    matmul->base_.thread_nr_ = count;
    matmul->parallel_run_ = matmul->parallel_run_by_oc_;
  }
  return;
}

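/* Packs the bias (matrix C): a scalar bias is broadcast across col_align_ elements and a
 * vector bias is copied and zero-padded up to col_align_. */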
int MatmulBasePackBiasMatrix(MatmulStruct *matmul) {
  if (matmul->base_.in_size_ != FOURTH_INPUT) {
    return NNACL_OK;
  }
  if (matmul->matrix_c_.has_packed_) {
    NNACL_CHECK_FALSE(matmul->matrix_c_.pack_size_ < matmul->compute_.col_align_, NNACL_ERR);
    return NNACL_OK;
  }
  TensorC *bias_tensor = matmul->base_.in_[THIRD_INPUT];
  float *bias_src = matmul->matrix_c_.origin_ptr_ != NULL ? matmul->matrix_c_.origin_ptr_ : (float *)bias_tensor->data_;
  NNACL_CHECK_NULL_RETURN_ERR(bias_src);

  int bias_num = GetElementNum(bias_tensor);
  NNACL_CHECK_TRUE_RET(bias_num > 0 && matmul->compute_.col_align_ >= bias_num, NNACL_ERR);

  matmul->matrix_c_.pack_size_ = matmul->compute_.col_align_;
  if (matmul->matrix_c_.pack_ptr_ == NULL) {
    matmul->matrix_c_.pack_ptr_ = (float *)(malloc(matmul->matrix_c_.pack_size_ * sizeof(float)));
  }
  NNACL_CHECK_NULL_RETURN_ERR(matmul->matrix_c_.pack_ptr_);

  if (bias_num == 1) {
    for (int i = 0; i < matmul->matrix_c_.pack_size_; ++i) {
      matmul->matrix_c_.pack_ptr_[i] = bias_src[0];
    }
  } else {
    (void)memcpy(matmul->matrix_c_.pack_ptr_, bias_src, bias_num * sizeof(float));
    (void)memset(matmul->matrix_c_.pack_ptr_ + bias_num, 0, (matmul->matrix_c_.pack_size_ - bias_num) * sizeof(float));
  }
  if (matmul->matrix_c_.origin_need_free_) {
    matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->matrix_c_.origin_ptr_);
    matmul->matrix_c_.origin_ptr_ = NULL;
    matmul->matrix_c_.origin_need_free_ = false;
  }
  return NNACL_OK;
}

int MatmulBaseInitTmpOutBuffer(MatmulStruct *matmul) {
  if (matmul->out_need_aligned_) {
    if (matmul->output_data_ != NULL) {
      matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->output_data_);
    }
    // avx need to malloc dst aligned to C8NUM
    // avx512 need to malloc dst aligned to C16NUM
    int out_channel = matmul->compute_.col_;
    NNACL_CHECK_ZERO_RETURN_ERR(matmul->compute_.col_tile_);
    int oc_block_num = UP_DIV(out_channel, matmul->compute_.col_tile_);
    int ele_num = matmul->batch_ * matmul->compute_.row_ * oc_block_num * matmul->compute_.col_tile_;
    int data_size = ele_num * (int)sizeof(float);
    matmul->output_data_ = (float *)(matmul->base_.env_->Alloc(matmul->base_.env_->allocator_, data_size));
    NNACL_CHECK_NULL_RETURN_ERR(matmul->output_data_);
  }
  return NNACL_OK;
}

void MatmulBaseInitGlobalVariable(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  matmul->matrix_a_.need_pack_ = true;
  matmul->matrix_b_.need_pack_ = !matmul->weight_is_packed_;
  matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row12MajorParallel : RowMajor2Col12MajorParallel;
  matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col8MajorParallel : RowMajor2Row8MajorParallel;
  matmul->compute_.row_tile_ = C12NUM;
  matmul->compute_.col_tile_ = C8NUM;
  matmul->compute_.col_min_unit_ = C8NUM;
  return;
}

bool MatmulBaseCheckThreadCuttingByRow() { return false; }

void MatmulBaseFreePackedMatrixA(KernelBase *self) {
  MatmulStruct *matmul = (MatmulStruct *)self;
  if (matmul->matrix_a_.need_pack_ && !matmul->base_.train_session_ && matmul->matrix_a_.pack_ptr_ != NULL) {
    self->env_->Free(self->env_->allocator_, matmul->matrix_a_.pack_ptr_);
  }
  matmul->matrix_a_.pack_ptr_ = NULL;
}

void MatmulBaseFreePackedMatrixB(KernelBase *self) {
  MatmulStruct *matmul = (MatmulStruct *)self;
  if (matmul->matrix_b_.need_pack_ && !matmul->base_.train_session_ && matmul->matrix_b_.pack_ptr_ != NULL) {
    matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->matrix_b_.pack_ptr_);
  }
  matmul->matrix_b_.pack_ptr_ = NULL;
}

int MatmulBaseResize(KernelBase *self) {
  MatmulStruct *matmul = (MatmulStruct *)self;

  int ret = matmul->init_parameter_(matmul);
  NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
  if (self->train_session_) {
    self->work_size_ = (matmul->matrix_a_.pack_size_ + matmul->matrix_b_.pack_size_) * (int)sizeof(float);
  }

  matmul->get_thread_cutting_policy_(matmul);
  if (!matmul->matrix_c_.has_packed_) {
    ret = MatmulBasePackBiasMatrix(matmul);
    NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
    if (!matmul->bias_need_repack_) {
      matmul->matrix_c_.has_packed_ = true;
    }
  }
  ret = MatmulBaseInitTmpOutBuffer(matmul);
  NNACL_CHECK_FALSE(ret != NNACL_OK, ret);

  return NNACL_OK;
}

int MatmulBaseRelease(struct KernelBase *self) {
  MatmulStruct *matmul = (MatmulStruct *)self;
  MatmulBaseFreeBatchOffset(matmul);

  if (matmul->out_need_aligned_ && matmul->output_data_ != NULL) {
    matmul->base_.env_->Free(matmul->base_.env_->allocator_, matmul->output_data_);
    matmul->output_data_ = NULL;
  }
  if (matmul->matrix_c_.pack_ptr_ != NULL) {
    free(matmul->matrix_c_.pack_ptr_);
    matmul->matrix_c_.pack_ptr_ = NULL;
  }
  if (matmul->a_const_) {
    if (matmul->is_sharing_pack_) {
      matmul->free_sharing_weight_(matmul->shaing_manager_, matmul->matrix_a_.pack_ptr_);
    } else {
      free(matmul->matrix_a_.pack_ptr_);
    }
  }
  if (matmul->b_const_) {
    if (!matmul->matrix_b_.need_pack_ && matmul->weight_is_packed_) {
      return NNACL_OK;
    }
    if (matmul->is_sharing_pack_) {
      matmul->free_sharing_weight_(matmul->shaing_manager_, matmul->matrix_b_.pack_ptr_);
    } else {
      free(matmul->matrix_b_.pack_ptr_);
    }
  }
  return NNACL_OK;
}

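/* Prepare validates input/output counts, data types and the activation type, initializes
 * the compute parameters, and pre-packs whichever of A/B is constant; a constant bias is
 * backed up so its original data stays available for bias packing during Resize. */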
int MatmulBasePrepare(struct KernelBase *self) {
  MatmulStruct *matmul = (MatmulStruct *)self;

  NNACL_CHECK_FALSE(matmul->base_.in_size_ < C2NUM, NNACL_INPUT_TENSOR_ERROR);
  NNACL_CHECK_FALSE(matmul->base_.out_size_ < 1, NNACL_OUTPUT_TENSOR_ERROR);
  NNACL_CHECK_FALSE(matmul->base_.in_[FIRST_INPUT]->data_type_ != kNumberTypeFloat32, NNACL_INPUT_TENSOR_ERROR);
  NNACL_CHECK_FALSE(matmul->base_.in_[SECOND_INPUT]->data_type_ != kNumberTypeFloat32, NNACL_INPUT_TENSOR_ERROR);

  if (matmul->base_.in_size_ == THREE_TENSOR) {
    NNACL_CHECK_TRUE_RET(matmul->base_.in_[THIRD_INPUT]->data_type_ == kNumberTypeFloat32, NNACL_MATMUL_BIAS_INVALID);
  }

  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(
    param->act_type_ != ActType_No && param->act_type_ != ActType_Relu && param->act_type_ != ActType_Relu6,
    NNACL_MATMUL_ACT_TYPE_INVALID);

  int ret = matmul->init_parameter_(matmul);
  NNACL_CHECK_FALSE(ret != NNACL_OK, ret);

  if (matmul->a_const_) {
    ret = MatmulBasePackMatrixA(matmul);
    NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
    matmul->matrix_a_.has_packed_ = true;
  }
  if (matmul->b_const_) {
    ret = MatmulBasePackMatrixB(matmul);
    NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
    matmul->matrix_b_.has_packed_ = true;
  }

  if (matmul->base_.in_size_ == THREE_TENSOR) {
    /* deal with const bias */
    bool bias_const = IsConst(self->in_[THIRD_INPUT]);
    if (!matmul->infer_shape_ && bias_const && !matmul->base_.train_session_ && matmul->matrix_c_.origin_ptr_ == NULL) {
      ret = MatmulBaseBackupConstMatrix(matmul, &matmul->matrix_c_, THIRD_INPUT);
      NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
    }
  }
  return NNACL_OK;
}

int MatmulBaseCompute(struct KernelBase *self) {
  MatmulStruct *matmul = (MatmulStruct *)self;

  float *out_data = (float *)(matmul->base_.out_[FIRST_INPUT]->data_);
  NNACL_CHECK_FALSE(out_data == NULL, NNACL_ERR);
  if (!matmul->out_need_aligned_) {
    matmul->output_data_ = out_data;
  }

  if (!matmul->a_const_) {
    int ret = MatmulBasePackMatrixA(matmul);
    NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
  }
  if (!matmul->b_const_) {
    int ret = MatmulBasePackMatrixB(matmul);
    NNACL_CHECK_FALSE(ret != NNACL_OK, ret);
  }

  NNACL_CHECK_NULL_RETURN_ERR(matmul->matrix_a_.pack_ptr_);
  NNACL_CHECK_NULL_RETURN_ERR(matmul->matrix_b_.pack_ptr_);

  int ret = self->env_->ParallelLaunch(self->env_->thread_pool_, MatmulFp32Run, self, self->thread_nr_);
  NNACL_CHECK_FALSE(ret != NNACL_OK, ret);

  if (matmul->out_need_aligned_) {
    PackNHWCXToNHWCFp32(matmul->output_data_, out_data, matmul->batch_, matmul->compute_.row_, matmul->compute_.col_,
                        matmul->compute_.col_tile_);
  } else {
    matmul->output_data_ = NULL;
  }
  if (!matmul->a_const_) {
    MatmulBaseFreePackedMatrixA(self);
  }

  if (!matmul->b_const_) {
    MatmulBaseFreePackedMatrixB(self);
  }
  return NNACL_OK;
}

void InitMatrixInfo(MatrixInfo *info) {
  info->need_pack_ = false;
  info->has_packed_ = false;
  info->origin_need_free_ = false;
  info->pack_size_ = -1;
  info->origin_ptr_ = NULL;
  info->pack_ptr_ = NULL;
}

KernelBase *CreateMatmulBase() {
  MatmulStruct *matmul = (MatmulStruct *)malloc(sizeof(MatmulStruct));
  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul);
  memset(matmul, 0, sizeof(MatmulStruct));
  matmul->base_.Prepare = MatmulBasePrepare;
  matmul->base_.Resize = MatmulBaseResize;
  matmul->base_.Release = MatmulBaseRelease;
  matmul->base_.Compute = MatmulBaseCompute;
  InitMatrixInfo(&(matmul->matrix_a_));
  InitMatrixInfo(&(matmul->matrix_b_));
  InitMatrixInfo(&(matmul->matrix_c_));
  matmul->is_sharing_pack_ = false;
  matmul->pack_opt_ = false;
  matmul->a_const_ = false;
  matmul->b_const_ = false;
  matmul->bias_need_repack_ = false;
  matmul->out_need_aligned_ = false;
  matmul->a_offset_ = NULL;
  matmul->b_offset_ = NULL;
  matmul->model_thread_nr_ = -1;
  matmul->support_mul_batch_cut_by_row_ = false;
  matmul->matmul_type_ = kMatmulFp32BaseCpu;
  matmul->get_thread_cutting_policy_ = MatmulBaseGetThreadCuttingPolicy;
  matmul->check_thread_cutting_by_row_ = MatmulBaseCheckThreadCuttingByRow;
  matmul->get_thread_cutting_info_by_row_ = MatmulBaseGetThreadCuttingInfoByRow;
  matmul->init_parameter_ = MatmulBaseInitParameter;
  matmul->init_global_varibale_ = MatmulBaseInitGlobalVariable;
  matmul->pack_matrix_a_impl_opt_ = MatmulBasePackMatrixAImplOpt;
  matmul->pack_matrix_a_impl_ = MatmulBasePackMatrixAImpl;
  matmul->pack_matrix_b_impl_ = MatmulBasePackMatrixBImpl;
  matmul->parallel_run_by_batch_ = MatmulBaseParallelRunByBatch;
  matmul->parallel_run_not_pack_by_batch_ = MatmulBaseParallelRunIsNotPackByBatch;
  matmul->parallel_run_by_oc_ = MatmulBaseParallelRunByOC;
  matmul->parallel_run_by_row_ = MatmulBaseParallelRunByRow;
  return (KernelBase *)matmul;
}
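
/*
 * Rough usage sketch (not part of the original file): how a caller might drive this
 * kernel through the KernelBase function pointers wired up in CreateMatmulBase(). The
 * tensor/parameter/environment setup shown here is assumed and is normally performed by
 * the surrounding framework; field types are abbreviated, see matmul_base.h / kernel.h.
 *
 *   KernelBase *kernel = CreateMatmulBase();
 *   kernel->param_ = param;            // MatMulParameter filled in by the caller
 *   kernel->in_ = inputs;              // A, B and an optional fp32 bias tensor
 *   kernel->in_size_ = input_num;
 *   kernel->out_ = outputs;
 *   kernel->out_size_ = output_num;
 *   kernel->env_ = env;                // provides Alloc/Free and ParallelLaunch
 *   kernel->thread_nr_ = thread_num;
 *
 *   int ret = kernel->Prepare(kernel);
 *   if (ret == NNACL_OK) ret = kernel->Resize(kernel);
 *   if (ret == NNACL_OK) ret = kernel->Compute(kernel);
 *   (void)kernel->Release(kernel);
 *   free(kernel);
 */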