/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifdef ENABLE_AVX
#include "nnacl/kernel/matmul_avx.h"
#include "nnacl/kernel/matmul_base.h"
#include "nnacl/fp32/matmul_fp32.h"
#include "nnacl/fp32/pack_fp32.h"

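/* Configure AVX-specific tiling and packing: 1x8 output tiles, a 32-column
 * minimum split unit, aligned output, and pack routines for A and B chosen
 * from their transpose flags. */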
void MatmulAVXInitGlobalVariable(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  matmul->compute_.row_tile_ = C1NUM;
  matmul->compute_.col_tile_ = C8NUM;
  matmul->compute_.col_min_unit_ = C32NUM;
  matmul->out_need_aligned_ = true;
  matmul->matrix_b_.need_pack_ = true;
  matmul->matrix_a_.need_pack_ = param->a_transpose_;
  matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
  matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col32MajorParallel : RowMajor2Row32MajorParallel;
}

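/* Batch-parallel runner: each task computes a contiguous range of batches.
 * func_flag selects the kernel: 0 for the packed GEMM, C1NUM for the packed
 * matrix-vector kernel (single output row), and C2NUM for the no-pack
 * matrix-vector kernel (single row, non-constant B, col_ <= C128NUM). */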
int MatmulAVXParallelRunByBatch(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)matmul->base_.param_;
  MatmulComputeParam *compute = (MatmulComputeParam *)&matmul->compute_;

  int start_batch = task_id * compute->batch_stride_;
  int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_);
  int func_flag = 0;
  if (matmul->compute_.row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }

  ActType act = param->act_type_;
  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_;
    float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_;

    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_;
    if (func_flag == 0) {
      MatMulAvxFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_, compute->row_);
    } else if (func_flag == C1NUM) {
      MatVecMulAvxFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_);
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_);
    }
  }
  return NNACL_OK;
}

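/* Row-parallel runner: split_points_ gives each task a half-open row range
 * [start_row, end_row). A single output column falls back to the generic
 * gemm_not_pack_fun_; otherwise the AVX GEMM runs on the row slice. */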
int MatmulAVXParallelRunByRow(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatmulComputeParam *compute = (MatmulComputeParam *)&matmul->compute_;

  int start_row = matmul->split_points_[task_id];
  int end_row = compute->row_num_;
  if (task_id < (matmul->base_.thread_nr_ - 1)) {
    end_row = matmul->split_points_[task_id + 1];
  }
  int row_num = end_row - start_row;
  if (row_num <= 0) {
    return NNACL_OK;
  }
  const float *input = matmul->matrix_a_.pack_ptr_ + start_row * compute->deep_;
  float *output = matmul->output_data_ + start_row * compute->col_align_;
  if (compute->col_ == 1) {
    float bias = 0;
    if (matmul->matrix_c_.pack_ptr_ != NULL) {
      bias = matmul->matrix_c_.pack_ptr_[0];
    }
    matmul->gemm_not_pack_fun_(input, matmul->matrix_b_.pack_ptr_, output, &bias, row_num, compute->deep_,
                               param->act_type_);
  } else {
    MatMulAvxFp32(input, matmul->matrix_b_.pack_ptr_, output, matmul->matrix_c_.pack_ptr_, param->act_type_,
                  compute->deep_, compute->col_align_, compute->col_align_, row_num);
  }
  return NNACL_OK;
}

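/* Output-channel-parallel runner: split_points_ partitions the output columns.
 * b_stride offsets into packed B (or into raw B when the no-pack matrix-vector
 * kernel is selected), and each task walks every batch for its column slice. */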
int MatmulAVXParallelRunByOC(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatmulComputeParam *compute = (MatmulComputeParam *)&matmul->compute_;
  ActType act = param->act_type_;

  int start_oc = matmul->split_points_[task_id];
  int end_oc = compute->col_step_;
  if (task_id < (matmul->base_.thread_nr_ - 1)) {
    end_oc = matmul->split_points_[task_id + 1];
  }
  int compute_oc = end_oc - start_oc;
  if (compute_oc <= 0) {
    return NNACL_OK;
  }
  int func_flag = 0;
  if (compute->row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }
  int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_;

  for (int i = 0; i < matmul->batch_; ++i) {
    float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_;
    float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride;
    float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc;
    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;

    if (func_flag == 0) {
      MatMulAvxFp32(a, b, c, bias, param->act_type_, compute->deep_, compute_oc, compute->col_align_, compute->row_);
    } else if (func_flag == C1NUM) {
      MatVecMulAvxFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_align_);
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_);
    }
  }
  return NNACL_OK;
}

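/* Decide whether to split work by rows instead of output channels: only for a
 * single B batch with at least one row per thread. The minimum row block
 * (row_min_unit_) grows as col_step_ shrinks, and row cutting is chosen when
 * it yields more usable threads than column cutting. */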
bool MatmulAVXCheckThreadCuttingByRow(MatmulStruct *matmul) {
  if (matmul->b_batch_ != C1NUM) {
    return false;
  }
  if (matmul->compute_.row_num_ < matmul->base_.thread_nr_) {
    return false;
  }
  if (matmul->compute_.col_ == 1) {
    matmul->compute_.row_min_unit_ = C4NUM;
    return true;
  }
  if (matmul->compute_.row_ == 1 && !matmul->b_const_ && matmul->compute_.col_ <= C128NUM) {
    return false;
  }
  matmul->compute_.row_min_unit_ = C3NUM;
  if (matmul->compute_.col_step_ < C16NUM) {
    matmul->compute_.row_min_unit_ = C8NUM;
  } else if (matmul->compute_.col_step_ < C24NUM) {
    matmul->compute_.row_min_unit_ = C6NUM;
  } else if (matmul->compute_.col_step_ < C32NUM) {
    matmul->compute_.row_min_unit_ = C4NUM;
  }
  return MSMIN(matmul->compute_.row_num_ / matmul->compute_.row_min_unit_, matmul->base_.thread_nr_) >
         MSMIN(matmul->compute_.col_step_ / matmul->compute_.col_min_unit_, matmul->base_.thread_nr_);
}

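/* Factory: build the common matmul kernel and plug in the AVX-specific
 * initialization, parallel runners, and thread-cutting check. */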
KernelBase *CreateMatmulAVX() {
  MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase();
  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul);
  matmul->matmul_type_ = kNotImplemented;
  matmul->init_global_varibale_ = MatmulAVXInitGlobalVariable;
  matmul->parallel_run_by_batch_ = MatmulAVXParallelRunByBatch;
  matmul->parallel_run_by_row_ = MatmulAVXParallelRunByRow;
  matmul->parallel_run_by_oc_ = MatmulAVXParallelRunByOC;
  matmul->check_thread_cutting_by_row_ = MatmulAVXCheckThreadCuttingByRow;
  return (KernelBase *)matmul;
}
#endif