• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifdef ENABLE_ARM32
18 #include "nnacl/kernel/matmul_arm32.h"
19 #include "nnacl/kernel/matmul_base.h"
20 #include "nnacl/fp32/pack_fp32.h"
21 #include "nnacl/fp32/matmul_fp32.h"
22 
/* Configure the ARM32-specific matmul globals: tile geometry and the
 * pack routines used to re-layout matrices A and B before compute. */
void MatmulARM32InitGlobalVariable(MatmulStruct *matmul) {
  MatMulParameter *params = (MatMulParameter *)(matmul->base_.param_);

  /* ARM32 kernels consume 12-row by 4-column tiles. */
  matmul->compute_.row_tile_ = C12NUM;
  matmul->compute_.col_tile_ = C4NUM;
  matmul->compute_.col_min_unit_ = C4NUM;

  /* Both operands are re-packed into the tiled layout. */
  matmul->matrix_a_.need_pack_ = true;
  matmul->matrix_b_.need_pack_ = true;

  /* Pick the pack routine according to each matrix's transpose flag. */
  if (params->a_transpose_) {
    matmul->matrix_a_pack_fun_ = RowMajor2Row12MajorParallel;
  } else {
    matmul->matrix_a_pack_fun_ = RowMajor2Col12MajorParallel;
  }
  if (params->b_transpose_) {
    matmul->matrix_b_pack_fun_ = RowMajor2Col4MajorParallel;
  } else {
    matmul->matrix_b_pack_fun_ = RowMajor2Row4MajorParallel;
  }
}
33 
/* Batch-parallel matmul worker: task `task_id` processes the contiguous
 * batch range [task_id * batch_stride_, +batch_stride_), clamped to batch_.
 * Returns NNACL_OK on success, NNACL_ERR on an out-of-range task id.
 *
 * Fixes vs. previous revision:
 *  - validate task_id, matching MatmulARM32ParallelRunByOC;
 *  - the bias pointer was computed with a redundant ternary
 *    ((p == NULL) ? NULL : p) inside the loop; it is loop-invariant,
 *    so compute it once, directly. */
int MatmulARM32ParallelRunByBatch(MatmulStruct *matmul, int task_id) {
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;
  ActType act = param->act_type_;

  int start_batch = task_id * compute->batch_stride_;
  int end_batch = MSMIN(matmul->batch_, start_batch + compute->batch_stride_);

  /* Kernel selector: 0 = tiled matmul; 1 = packed mat-vec (row_ == 1);
   * 2 = unpacked mat-vec (row_ == 1, non-const B with col_ <= C128NUM). */
  int func_flag = 0;
  if (compute->row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }

  /* Bias is shared across batches; NULL when no bias tensor was packed. */
  float *bias = matmul->matrix_c_.pack_ptr_;

  for (int index = start_batch; index < end_batch; ++index) {
    const float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[index] * compute->row_align_ * compute->deep_;
    const float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[index] * compute->deep_ * compute->col_align_;
    float *c = matmul->output_data_ + index * compute->row_ * compute->col_step_;

    if (func_flag == 0) {
      MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, compute->col_step_, compute->col_, OutType_Nhwc);
    } else if (func_flag == C1NUM) {
      MatVecMulFp32Block4(a, b, c, bias, act, compute->deep_, compute->col_step_);
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, compute->col_step_, compute->col_step_);
    }
  }
  return NNACL_OK;
}
62 
/* Output-channel-parallel matmul worker: task `task_id` handles the output
 * column span [split_points_[task_id], split_points_[task_id + 1]) — the
 * last task runs to col_step_ — across every batch.
 * Returns NNACL_OK on success, NNACL_ERR on an out-of-range task id. */
int MatmulARM32ParallelRunByOC(MatmulStruct *matmul, int task_id) {
  NNACL_CHECK_FALSE(task_id < 0 || task_id >= matmul->base_.thread_nr_, NNACL_ERR);
  MatMulParameter *param = (MatMulParameter *)(matmul->base_.param_);
  MatmulComputeParam *compute = &matmul->compute_;
  ActType act = param->act_type_;

  int oc_begin = matmul->split_points_[task_id];
  int oc_end =
    (task_id < matmul->base_.thread_nr_ - 1) ? matmul->split_points_[task_id + 1] : compute->col_step_;
  int oc_count = oc_end - oc_begin;
  if (oc_count <= 0) {
    /* Nothing was assigned to this task. */
    return NNACL_OK;
  }

  /* Kernel selector: 0 = tiled matmul; 1 = packed mat-vec (row_ == 1);
   * 2 = unpacked mat-vec (row_ == 1, non-const B with col_ <= C128NUM). */
  int func_flag = 0;
  if (compute->row_ == 1) {
    func_flag += (!matmul->b_const_ && compute->col_ <= C128NUM) ? C2NUM : C1NUM;
  }
  /* The unpacked mat-vec indexes B by column only; the packed layouts
   * advance by deep_ elements per column. */
  int b_stride = (func_flag == C2NUM) ? oc_begin : oc_begin * compute->deep_;

  for (int batch = 0; batch < matmul->batch_; ++batch) {
    float *a = matmul->matrix_a_.pack_ptr_ + matmul->a_offset_[batch] * compute->row_align_ * compute->deep_;
    float *b = matmul->matrix_b_.pack_ptr_ + matmul->b_offset_[batch] * compute->deep_ * compute->col_align_ + b_stride;
    float *c = matmul->output_data_ + batch * compute->row_ * compute->col_step_ + oc_begin;
    float *bias = (matmul->matrix_c_.pack_ptr_ == NULL) ? NULL : matmul->matrix_c_.pack_ptr_ + oc_begin;

    if (func_flag == 0) {
      MatMulOpt(a, b, c, bias, act, compute->deep_, compute->row_, oc_count, compute->col_, OutType_Nhwc);
    } else if (func_flag == C1NUM) {
      MatVecMulFp32Block4(a, b, c, bias, act, compute->deep_, oc_count);
    } else {
      MatVecMulNoPackFp32(a, b, c, bias, act, compute->deep_, oc_count, compute->col_step_);
    }
  }
  return NNACL_OK;
}
100 
CreateMatmulARM32()101 KernelBase *CreateMatmulARM32() {
102   MatmulStruct *matmul = (MatmulStruct *)CreateMatmulBase();
103   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(matmul);
104   matmul->matmul_type_ = kNotImplemented;
105   matmul->init_global_varibale_ = MatmulARM32InitGlobalVariable;
106   matmul->parallel_run_by_batch_ = MatmulARM32ParallelRunByBatch;
107   matmul->parallel_run_by_oc_ = MatmulARM32ParallelRunByOC;
108   return (KernelBase *)matmul;
109 }
110 #endif
111