/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef NNACL_KERNEL_MATMUL_STRUCT_H_
#define NNACL_KERNEL_MATMUL_STRUCT_H_

#include "nnacl/kernel.h"
#include "nnacl/matmul_parameter.h"

// Maximum number of per-thread split entries; bounded by the thread-pool limit.
#define SPLIT_COUNT MAX_THREAD_NUM

// Bookkeeping for one matrix operand (A, B or C/bias): whether it must be
// packed into the tiled layout the kernels expect, and the associated buffers.
typedef struct MatrixInfo {
  bool need_pack_;         // matrix must be repacked before compute
  bool has_packed_;        // only valid for constant, only do once throughout the process.
  bool origin_need_free_;  // true when failing to infer shape, false in conv1x1 free in convolution delegate
  int pack_size_;          // element count of the packed buffer
  float *origin_ptr_;      // only valid for constant, which is synchronized with the 'has_origin'.
  float *pack_ptr_;        // packed (tiled) data consumed by the compute kernels
} MatrixInfo;

// Half-open work range [row_s_, row_e_) x [col_s_, col_e_) assigned to one
// task. NOTE(review): start/end inclusiveness is not visible here — confirm
// against the parallel-run implementations.
typedef struct MatmulSlice {
  int row_s_;
  int row_e_;
  int col_s_;
  int col_e_;
} MatmulSlice;

// Shape and tiling parameters for one matmul invocation.
// NOTE(review): row_/col_/deep_ presumably map to M/N/K of C[MxN] = A[MxK] *
// B[KxN] — confirm against the kernel implementations.
typedef struct MatmulComputeParam {
  int row_;            // output rows (M)
  int col_;            // output cols (N)
  int deep_;           // reduction depth (K)
  int row_align_;      // row_ rounded up to the row tile
  int col_align_;      // col_ rounded up to the col tile
  int deep_align_;     // deep_ rounded up to the deep tile
  int row_num_;
  int col_tile_;       // hardware-specific column tile width
  int row_tile_;       // hardware-specific row tile height
  int col_step_;
  int row_min_unit_;   // smallest row granule when splitting work
  int col_min_unit_;   // smallest col granule when splitting work
  int batch_stride_;
  int pack_b_stride_;
  int block_col_unit_;
} MatmulComputeParam;

// Main state of the float matmul kernel: shape/tiling parameters, operand
// buffers, and the function pointers selected at init time for the current
// architecture and cutting policy.
typedef struct MatmulStruct {
  KernelBase base_;              // common kernel interface; must stay first
  MatmulComputeParam compute_;
  MatmulType matmul_type_;

  /* model pool optimize */
  int model_thread_nr_;

  /* batch-matmul broadcast */
  int batch_;    // broadcasted batch count of the output
  int a_batch_;  // batch count of operand A before broadcast
  int b_batch_;  // batch count of operand B before broadcast
  int *a_offset_; /* batch_ size */
  int *b_offset_; /* batch_ size */

  int split_points_[SPLIT_COUNT];  // per-thread work split boundaries

  float *output_data_;
  float *pack_b_src_;  // source of B used when packing
  float *pack_b_dst_;  // destination buffer for packed B

  bool a_const_;            // A is a constant tensor (pack once)
  bool b_const_;            // B is a constant tensor (pack once)
  bool bias_need_repack_;
  bool infer_shape_;        // shapes were successfully inferred
  bool pack_opt_;
  bool is_sharing_pack_;    // packed weight is shared via shaing_manager_
  bool out_need_aligned_;   // output needs an aligned staging buffer
  bool weight_is_packed_;   // weight already arrives in packed layout
  bool support_mul_batch_cut_by_row_;

  MatrixInfo matrix_a_;
  MatrixInfo matrix_c_placeholder_unused_never_emit_;
  MatrixInfo matrix_b_;
  MatrixInfo matrix_c_;

  // Row-range packing kernels for A and B, chosen per architecture.
  void (*matrix_a_pack_fun_)(const float *src_ptr, float *dst_ptr, int row, int col, int start_row, int end_row);
  void (*matrix_b_pack_fun_)(const float *src_ptr, float *dst_ptr, int row, int col, int start_row, int end_row);

  // Full-matrix pack drivers; return an NNACL error code.
  int (*pack_matrix_a_impl_opt_)(struct MatmulStruct *matmul);
  int (*pack_matrix_a_impl_)(struct MatmulStruct *matmul);
  int (*pack_matrix_b_impl_)(struct MatmulStruct *matmul);

  int (*init_parameter_)(struct MatmulStruct *matmul);
  // NOTE(review): "varibale" is a typo for "variable"; renaming requires
  // updating every architecture-specific implementation that sets it.
  void (*init_global_varibale_)(struct MatmulStruct *matmul);

  // Thread-cutting policy selection (row-wise vs. column/batch-wise split).
  bool (*check_thread_cutting_by_row_)(struct MatmulStruct *matmul);
  void (*get_thread_cutting_policy_)(struct MatmulStruct *matmul);
  void (*get_thread_cutting_info_by_row_)(struct MatmulStruct *matmul);

  // NOTE(review): "shaing" looks like a typo for "sharing"; the name is part
  // of the public struct layout, so fixing it needs a repo-wide rename.
  void *shaing_manager_;  // opaque manager for sharing packed weights
  void *(*get_sharing_weight_)(void *manager, const void *tensor_data, const size_t size, bool *is_packed);
  void (*free_sharing_weight_)(void *manager, void *tensor_data);

  // Fallback GEMM used when packing is skipped entirely.
  void (*gemm_not_pack_fun_)(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type);

  // Per-task entry points; parallel_run_ dispatches to one of the others.
  int (*parallel_run_)(struct MatmulStruct *matmul, int task_id);
  int (*parallel_run_by_row_)(struct MatmulStruct *matmul, int task_id);
  int (*parallel_run_by_oc_)(struct MatmulStruct *matmul, int task_id);
  int (*parallel_run_by_batch_)(struct MatmulStruct *matmul, int task_id);
  int (*parallel_run_not_pack_by_batch_)(struct MatmulStruct *matmul, int task_id);

  /* optimize for avx512 */
  int col_split_points_size_;
  int row_split_points_size_;
  int col_split_points_[SPLIT_COUNT];
  int row_split_points_[SPLIT_COUNT];
  int matmul_slice_count_[SPLIT_COUNT];                      // slices per task
  MatmulSlice matmul_slice_set_[SPLIT_COUNT][SPLIT_COUNT];   // [task][slice]
  int (*parallel_run_by_gemm_)(struct MatmulStruct *matmul, int task_id);
  int (*parallel_run_by_gepm_)(struct MatmulStruct *matmul, int task_id);
  int (*parallel_run_by_gepdot_)(struct MatmulStruct *matmul, int task_id);
  int (*parallel_run_by_batch_col_row_gemm_)(struct MatmulStruct *matmul, int task_id);
  int (*parallel_run_by_row1_deep1_gepdot_)(struct MatmulStruct *matmul, int task_id);
} MatmulStruct;

#endif  // NNACL_KERNEL_MATMUL_STRUCT_H_