/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifdef ENABLE_ARM64
#include "nnacl/kernel/matmul_arm64.h"
#include "nnacl/kernel/matmul_base.h"
#include "nnacl/fp32/matmul_fp32.h"
#include "nnacl/fp32/pack_fp32.h"
#include "nnacl/fp32/pack_fp32_opt.h"

typedef struct MatrixAPack {
  int64_t points_[MAX_THREAD_NUM];
  int64_t unit_num_;
  int thread_;
  int deep_;
  int row_;
  int col_;
  MatrixInfo *matrix_a_;
  float *src_ptr_;
  bool a_transpose_;
} MatrixAPack;

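/* Per-thread packing callback: packs this task's slice of matrix A into the
 * 12-row-major layout used by the ARM64 kernels. The trailing float
 * parameters are unused; they only satisfy the ParallelLaunch callback
 * signature. */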
int MatmulARM64PackMatrixAImplOptPack(void *cdata, int task_id, float l, float r) {
  MatrixAPack *pack = (MatrixAPack *)cdata;
  int64_t start = pack->points_[task_id];
  int64_t end = pack->unit_num_;
  if (task_id < pack->thread_ - 1) {
    end = pack->points_[task_id + 1];
  }

  if (pack->a_transpose_) {
    RowMajor2Row12MajorOpt(pack->src_ptr_, pack->matrix_a_->pack_ptr_, pack->deep_, pack->row_, start, end);
  } else {
    RowMajor2Col12MajorOpt(pack->src_ptr_, pack->matrix_a_->pack_ptr_, pack->row_, pack->deep_, start, end);
  }
  return NNACL_OK;
}

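/* Packs matrix A using multiple threads. Work is measured in pack units
 * (a_batch_ * ceil(row_ / 12) * deep_ in total) and split so that each
 * thread handles at least kPackAMinUnitNum units, which keeps small
 * tensors single-threaded. */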
int MatmulARM64PackMatrixAImplOpt(MatmulStruct *matmul) {
  int64_t kPackAMinUnitNum = 1 << 13;
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  float *src_ptr = matmul->matrix_a_.origin_ptr_ != NULL ? matmul->matrix_a_.origin_ptr_
                                                         : (float *)(matmul->base_.in_[FIRST_INPUT]->data_);
  NNACL_CHECK_TRUE_RET(src_ptr != NULL, NNACL_ERR);
  NNACL_CHECK_TRUE_RET(matmul->matrix_a_.pack_ptr_ != NULL, NNACL_ERR);

  MatrixAPack pack;
  pack.src_ptr_ = src_ptr;
  pack.matrix_a_ = &matmul->matrix_a_;
  pack.deep_ = matmul->compute_.deep_;
  pack.col_ = matmul->compute_.col_;
  pack.row_ = matmul->compute_.row_;
  pack.a_transpose_ = param->a_transpose_;
  pack.unit_num_ = matmul->a_batch_ * UP_DIV(matmul->compute_.row_, C12NUM) * matmul->compute_.deep_;
  pack.thread_ = MSMIN(matmul->base_.thread_nr_, UP_DIV(pack.unit_num_, kPackAMinUnitNum));
  if (pack.thread_ < 1) {
    pack.thread_ = 1;
  }
  int64_t block_size = pack.unit_num_ / pack.thread_;
  int64_t remain_size = pack.unit_num_ - block_size * pack.thread_;
  int64_t start = 0;
  size_t count = 0;
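  /* Distribute unit_num_ across the threads: each thread takes block_size
   * units, and the first remain_size threads take one extra unit.
   * points_[i] records the starting unit of thread i. */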
  while (start < pack.unit_num_) {
    pack.points_[count++] = start;
    start += block_size;
    if (remain_size > 0) {
      ++start;
      --remain_size;
    }
  }
  pack.thread_ = count;

  if (pack.thread_ == 1) {
    return MatmulARM64PackMatrixAImplOptPack(&pack, 0, 0, 1);
  }
  return matmul->base_.env_->ParallelLaunch(matmul->base_.env_->thread_pool_, MatmulARM64PackMatrixAImplOptPack, &pack,
                                            pack.thread_);
}

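/* Decides whether the work should be cut across threads by output row.
 * Row cutting requires a single B batch and is enabled, with a minimum
 * row unit of 4, when the batch count already covers every thread or the
 * output has a single column. */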
bool MatmulARM64CheckThreadCuttingByRow(MatmulStruct *matmul) {
  if (matmul->b_batch_ != C1NUM) {
    return false;
  }
  if (matmul->batch_ >= matmul->base_.thread_nr_ || matmul->compute_.col_ == 1) {
    matmul->compute_.row_min_unit_ = C4NUM;
    return true;
  }
  return false;
}
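
/* Configures the ARM64-specific globals: A is packed into 12-row tiles and
 * B into 8-column tiles to match the 12x8 tiling of the fp32 kernels, and
 * the pack functions are chosen according to the transpose flags. */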
void MatmulARM64InitGlobalVariable(MatmulStruct *matmul) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  matmul->pack_opt_ = true;
  matmul->compute_.row_tile_ = C12NUM;
  matmul->compute_.col_tile_ = C8NUM;
  matmul->compute_.col_min_unit_ = C8NUM;
  matmul->matrix_a_.need_pack_ = true;
  matmul->matrix_b_.need_pack_ = !matmul->weight_is_packed_;
  matmul->matrix_a_pack_fun_ = param->a_transpose_ ? RowMajor2Row12MajorParallel : RowMajor2Col12MajorParallel;
  matmul->matrix_b_pack_fun_ = param->b_transpose_ ? RowMajor2Col8MajorParallel : RowMajor2Row8MajorParallel;
}

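/* Batch-parallel runner: each task computes a contiguous range of batches.
 * When row_ == 1, func_flag picks a matrix-vector kernel instead of the
 * full GEMM: the no-pack variant for non-constant B with col_ <= 128, the
 * packed variant otherwise. */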
int MatmulARM64ParallelRunByBatch(MatmulStruct *matmul, int task_id) {
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;
  ActType act = param->act_type_;

  int start_batch = task_id * compute->batch_stride_;
  int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_);
  int func_flag = 0;
  if (compute->row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }

  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_;
    float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_;
    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_;

    if (func_flag == 0) {
      MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute->col_step_, compute->col_, OutType_Nhwc);
    } else if (func_flag == C1NUM) {
      MatVecMulPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_);
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_);
    }
  }
  return NNACL_OK;
}

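/* Row-parallel runner: split_points_ holds each task's starting row; the
 * last task runs through row_num_. Tasks with an empty row range return
 * immediately. */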
int MatmulARM64ParallelRunByRow(MatmulStruct *matmul, int task_id) {
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);

  int start_row = matmul->split_points_[task_id];
  int end_row = matmul->compute_.row_num_;
  if (task_id < (matmul->base_.thread_nr_ - 1)) {
    end_row = matmul->split_points_[task_id + 1];
  }
  int row_num = end_row - start_row;
  if (row_num <= 0) {
    return NNACL_OK;
  }
  GemmIsNotPackByRow(matmul->matrix_a_.pack_ptr_, matmul->matrix_b_.pack_ptr_, matmul->output_data_,
                     matmul->matrix_c_.pack_ptr_, start_row, end_row, matmul->compute_.deep_, param->act_type_);
  return NNACL_OK;
}

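/* Output-channel-parallel runner: each task computes the columns
 * [start_oc, end_oc) of every batch. b_stride skips the columns owned by
 * earlier tasks; for the no-pack vector kernel B is unpacked, so the
 * stride is start_oc rather than start_oc * deep_. */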
int MatmulARM64ParallelRunByOC(MatmulStruct *matmul, int task_id) {
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;
  ActType act = param->act_type_;

  int start_oc = matmul->split_points_[task_id];
  int end_oc = compute->col_step_;
  if (task_id < (matmul->base_.thread_nr_ - 1)) {
    end_oc = matmul->split_points_[task_id + 1];
  }
  int compute_oc = end_oc - start_oc;
  if (compute_oc <= 0) {
    return NNACL_OK;
  }
  int func_flag = 0;
  if (compute->row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }
  int b_stride = func_flag == C2NUM ? start_oc : start_oc * compute->deep_;

  for (int i = 0; i < matmul->batch_; ++i) {
    float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[i] * compute->row_align_ * compute->deep_;
    float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[i] * compute->deep_ * compute->col_align_ + b_stride;
    float *c = matmul->output_data_ + i * compute->row_ * compute->col_step_ + start_oc;
    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + start_oc;

    if (func_flag == 0) {
      MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute_oc, compute->col_, OutType_Nhwc);
    } else if (func_flag == C1NUM) {
      MatVecMulPackFp32(a, b, c, bias, act, compute->deep_, compute_oc);
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute_oc, compute->col_step_);
    }
  }
  return NNACL_OK;
}

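/* Factory: creates the base matmul kernel and overrides its hooks with the
 * ARM64 implementations above. */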
KernelBase *CreateMatmulARM64() {
  MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase();
  NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul);
  matmul->matmul_type_ = kMatmulFp32Arm64Cpu;
  matmul->check_thread_cutting_by_row_ = MatmulARM64CheckThreadCuttingByRow;
  matmul->init_global_varibale_ = MatmulARM64InitGlobalVariable;
  matmul->parallel_run_by_oc_ = MatmulARM64ParallelRunByOC;
  matmul->parallel_run_by_row_ = MatmulARM64ParallelRunByRow;
  matmul->parallel_run_by_batch_ = MatmulARM64ParallelRunByBatch;
  matmul->pack_matrix_a_impl_opt_ = MatmulARM64PackMatrixAImplOpt;
  return (KernelBase *)matmul;
}
#endif