/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifdef ENABLE_AVX
#include "nnacl/kernel/matmul_avx.h"
#include "nnacl/kernel/matmul_base.h"
#include "nnacl/fp32/matmul_fp32.h"
#include "nnacl/fp32/pack_fp32.h"

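/* Configure the AVX-specific globals: a 1 x 8 (row x col) kernel tile with a 32-column
 * minimum split unit, aligned output, and pack functions for matrices A and B chosen
 * from their transpose flags (B is always packed, A only when transposed). */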
void MatmulAVXInitGlobalVariable(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  matmul->compute_.row_tile_ = C1NUM;
  matmul->compute_.col_tile_ = C8NUM;
  matmul->compute_.col_min_unit_ = C32NUM;
  matmul->out_need_aligned_ = true;
  matmul->matrix_b_.need_pack_ = true;
  matmul->matrix_a_.need_pack_ = param->a_transpose_;
  matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2ColMajorParallel : RowMajor2RowMajorParallel;
  matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col32MajorParallel : RowMajor2Row32MajorParallel;
}

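/* Parallel runner that splits work by batch: each task handles a batch_stride_-sized
 * slice of the batches. func_flag selects the kernel: 0 -> packed GEMM, 1 -> packed
 * GEMV (single output row), 2 -> no-pack GEMV (single row, non-constant B, col_ <= 128). */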
int MatmulAVXParallelRunByBatch(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)matmul->base_.param_;
  MatmulComputeParam *compute = (MatmulComputeParam *)&matmul->compute_;

  int start_batch = task_id * compute->batch_stride_;
  int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_);
  int func_flag = 0;
  if (matmul->compute_.row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }

  ActType act = param->act_type_;
  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_;
    float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_;

    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_;
    if (func_flag == 0) {
      MatMulAvxFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_, compute->row_);
    } else if (func_flag == C1NUM) {
      MatVecMulAvxFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_align_);
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_);
    }
  }
  return NNACL_OK;
}

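/* Parallel runner that splits the output rows across threads via split_points_.
 * A single-column output falls back to gemm_not_pack_fun_ with a scalar bias;
 * otherwise MatMulAvxFp32 runs on this task's row slice. */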
int MatmulAVXParallelRunByRow(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatmulComputeParam *compute = (MatmulComputeParam *)&matmul->compute_;

  int start_row = matmul->split_points_[task_id];
  int end_row = compute->row_num_;
  if (task_id < (matmul->base_.thread_nr_ - 1)) {
    end_row = matmul->split_points_[task_id + 1];
  }
  int row_num = end_row - start_row;
  if (row_num <= 0) {
    return NNACL_OK;
  }
  const float *input = matmul->matrix_a_.pack_ptr_ + start_row * compute->deep_;
  float *output = matmul->output_data_ + start_row * compute->col_align_;
  if (compute->col_ == 1) {
    float bias = 0;
    if (matmul->matrix_c_.pack_ptr_ != NULL) {
      bias = matmul->matrix_c_.pack_ptr_[0];
    }
    matmul->gemm_not_pack_fun_(input, matmul->matrix_b_.pack_ptr_, output, &bias, row_num, compute->deep_,
                               param->act_type_);
  } else {
    MatMulAvxFp32(input, matmul->matrix_b_.pack_ptr_, output, matmul->matrix_c_.pack_ptr_, param->act_type_,
                  compute->deep_, compute->col_align_, compute->col_align_, row_num);
  }
  return NNACL_OK;
}

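/* Parallel runner that splits the output columns (OC) across threads via split_points_.
 * Each task walks every batch and, as in the batch runner, dispatches to the packed GEMM,
 * packed GEMV, or no-pack GEMV kernel according to func_flag; b_stride advances B to this
 * task's column slice (column-blocked offset when packed, plain column offset otherwise). */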
int MatmulAVXParallelRunByOC(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatmulComputeParam *compute = (MatmulComputeParam *)&matmul->compute_;
  ActType act = param->act_type_;

  int start_oc = matmul->split_points_[task_id];
  int end_oc = compute->col_step_;
  if (task_id < (matmul->base_.thread_nr_ - 1)) {
    end_oc = matmul->split_points_[task_id + 1];
  }
  int compute_oc = end_oc - start_oc;
  if (compute_oc <= 0) {
    return NNACL_OK;
  }
  int func_flag = 0;
  if (compute->row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }
  int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_;

  for (int i = 0; i < matmul->batch_; ++i) {
    float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_;
    float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride;
    float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc;
    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;

    if (func_flag == 0) {
      MatMulAvxFp32(a, b, c, bias, param->act_type_, compute->deep_, compute_oc, compute->col_align_, compute->row_);
    } else if (func_flag == C1NUM) {
      MatVecMulAvxFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_align_);
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_);
    }
  }
  return NNACL_OK;
}

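/* Decide whether to split work by output rows instead of by output columns.
 * Row cutting requires a single B batch and at least one row per thread; a single-column
 * output always uses it, while the single-row no-pack GEMV case never does. row_min_unit_
 * is tuned from col_step_ (narrower outputs get larger row units), and row cutting wins
 * only if it offers more usable parallelism than column cutting. */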
bool MatmulAVXCheckThreadCuttingByRow(MatmulStruct *matmul) {
  if (matmul->b_batch_ != C1NUM) {
    return false;
  }
  if (matmul->compute_.row_num_ < matmul->base_.thread_nr_) {
    return false;
  }
  if (matmul->compute_.col_ == 1) {
    matmul->compute_.row_min_unit_ = C4NUM;
    return true;
  }
  if (matmul->compute_.row_ == 1 && !matmul->b_const_ && matmul->compute_.col_ <= C128NUM) {
    return false;
  }
  matmul->compute_.row_min_unit_ = C3NUM;
  if (matmul->compute_.col_step_ < C16NUM) {
    matmul->compute_.row_min_unit_ = C8NUM;
  } else if (matmul->compute_.col_step_ < C24NUM) {
    matmul->compute_.row_min_unit_ = C6NUM;
  } else if (matmul->compute_.col_step_ < C32NUM) {
    matmul->compute_.row_min_unit_ = C4NUM;
  }
  return MSMIN(matmul->compute_.row_num_ / matmul->compute_.row_min_unit_, matmul->base_.thread_nr_) >
         MSMIN(matmul->compute_.col_step_ / matmul->compute_.col_min_unit_, matmul->base_.thread_nr_);
}

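/* Create the AVX matmul kernel: build the common base kernel and override the
 * AVX-specific hooks (tile setup, the batch/row/OC parallel runners, and the
 * row-cutting check). */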
KernelBase *CreateMatmulAVX() {
  MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase();
  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul);
  matmul->matmul_type_ = kNotImplemented;
  matmul->init_global_varibale_ = MatmulAVXInitGlobalVariable;
  matmul->parallel_run_by_batch_ = MatmulAVXParallelRunByBatch;
  matmul->parallel_run_by_row_ = MatmulAVXParallelRunByRow;
  matmul->parallel_run_by_oc_ = MatmulAVXParallelRunByOC;
  matmul->check_thread_cutting_by_row_ = MatmulAVXCheckThreadCuttingByRow;
  return (KernelBase *)matmul;
}
#endif