• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef NNACL_KERNEL_MATMUL_STRUCT_H_
17 #define NNACL_KERNEL_MATMUL_STRUCT_H_
18 
19 #include "nnacl/kernel.h"
20 #include "nnacl/matmul_parameter.h"
21 
22 #define SPLIT_COUNT MAX_THREAD_NUM
23 
/**
 * Describes one operand (or the output) matrix of a matmul: whether it must be
 * repacked into a kernel-friendly layout, and the buffers holding the original
 * and packed data.
 */
typedef struct MatrixInfo {
  bool need_pack_;         // true when the data must be repacked before the compute kernels can use it
  bool has_packed_;        // only valid for constant, only do once throughout the process.
  bool origin_need_free_;  // true when failing to infer shape, false in conv1x1 free in convolution delegate
  int pack_size_;          // size of the packed buffer; units (elements vs bytes) set by the pack routines — TODO confirm
  float *origin_ptr_;      // only valid for constant, which is synchronized with the 'has_origin'.
  float *pack_ptr_;        // packed (re-laid-out) data actually consumed by the matmul kernels
} MatrixInfo;
32 
/**
 * A rectangular sub-block of the output matrix assigned to one worker:
 * row range [row_s_, row_e_] x column range [col_s_, col_e_].
 * NOTE(review): whether the end indices are inclusive or one-past-the-end is
 * not visible here — confirm against the slice producers/consumers.
 */
typedef struct MatmulSlice {
  int row_s_;  // start row of the slice
  int row_e_;  // end row of the slice
  int col_s_;  // start column of the slice
  int col_e_;  // end column of the slice
} MatmulSlice;
39 
/**
 * Shape and tiling parameters for one matmul computation.
 * Naming follows the usual GEMM convention: C[row_ x col_] = A[row_ x deep_] * B[deep_ x col_],
 * i.e. deep_ is the reduction (K) dimension.
 */
typedef struct MatmulComputeParam {
  int row_;             // M: rows of A / rows of C
  int col_;             // N: columns of B / columns of C
  int deep_;            // K: columns of A / rows of B (reduction dimension)
  int row_align_;       // row_ rounded up to the pack/tile alignment
  int col_align_;       // col_ rounded up to the pack/tile alignment
  int deep_align_;      // deep_ rounded up to the pack/tile alignment
  int row_num_;         // total row work count — presumably rows across batches; TODO confirm
  int col_tile_;        // column tile width used by the SIMD micro-kernel
  int row_tile_;        // row tile height used by the SIMD micro-kernel
  int col_step_;        // columns processed per step — exact relation to col_tile_ set by the kernels
  int row_min_unit_;    // minimum row granularity when splitting work across threads
  int col_min_unit_;    // minimum column granularity when splitting work across threads
  int batch_stride_;    // per-batch element stride — which buffer it applies to is set by the kernels
  int pack_b_stride_;   // stride within the packed B buffer
  int block_col_unit_;  // column block unit for thread cutting — TODO confirm exact semantics
} MatmulComputeParam;
57 
/**
 * State and dispatch table for one matmul kernel instance.
 *
 * Extends KernelBase (must remain the first member so the struct can be used
 * where a KernelBase* is expected) with shape/tiling parameters, batch
 * broadcast bookkeeping, operand pack state, and per-backend function
 * pointers. The function pointers are filled in by architecture-specific
 * initialization code outside this header; all take the owning MatmulStruct.
 */
typedef struct MatmulStruct {
  KernelBase base_;             // common kernel header; must stay first for base-pointer casts
  MatmulComputeParam compute_;  // shape and tiling parameters (see MatmulComputeParam)
  MatmulType matmul_type_;      // variant selector (defined in nnacl/matmul_parameter.h)

  /* model pool optimize */
  int model_thread_nr_;  // thread count override used by the model-pool optimization

  /* batch-matmul broadcast */
  int batch_;     // broadcast output batch count
  int a_batch_;   // batch count of operand A before broadcast
  int b_batch_;   // batch count of operand B before broadcast
  int *a_offset_; /* batch_ size */
  int *b_offset_; /* batch_ size */

  int split_points_[SPLIT_COUNT];  // per-thread work boundaries; one entry per thread (SPLIT_COUNT = MAX_THREAD_NUM)

  float *output_data_;  // destination buffer for the current run
  float *pack_b_src_;   // source of B data for packing
  float *pack_b_dst_;   // destination of packed B data

  bool a_const_;                       // A is a constant tensor (can be packed once)
  bool b_const_;                       // B is a constant tensor (can be packed once)
  bool bias_need_repack_;              // bias layout must be rebuilt before compute
  bool infer_shape_;                   // shapes were successfully inferred
  bool pack_opt_;                      // use the optimized A-pack path (pack_matrix_a_impl_opt_)
  bool is_sharing_pack_;               // packed weights come from the sharing manager below
  bool out_need_aligned_;              // output requires an aligned intermediate buffer
  bool weight_is_packed_;              // weights arrive pre-packed; skip packing
  bool support_mul_batch_cut_by_row_;  // allow row-wise thread cutting for multi-batch

  MatrixInfo matrix_a_;  // operand A pack state
  MatrixInfo matrix_b_;  // operand B pack state
  MatrixInfo matrix_c_;  // output/bias pack state — TODO confirm which (name says C)

  /* row-range pack routines; pack [start_row, end_row) of a row x col matrix — confirm bound convention */
  void (*matrix_a_pack_fun_)(const float *src_ptr, float *dst_ptr, int row, int col, int start_row, int end_row);
  void (*matrix_b_pack_fun_)(const float *src_ptr, float *dst_ptr, int row, int col, int start_row, int end_row);

  /* whole-operand pack drivers; return an NNACL error code */
  int (*pack_matrix_a_impl_opt_)(struct MatmulStruct *matmul);
  int (*pack_matrix_a_impl_)(struct MatmulStruct *matmul);
  int (*pack_matrix_b_impl_)(struct MatmulStruct *matmul);

  int (*init_parameter_)(struct MatmulStruct *matmul);
  /* NOTE(review): "varibale" is a typo for "variable", but the name is part of
   * the public interface (set by arch-specific init code) — renaming would
   * break all users of this header. */
  void (*init_global_varibale_)(struct MatmulStruct *matmul);

  /* thread-cutting policy hooks */
  bool (*check_thread_cutting_by_row_)(struct MatmulStruct *matmul);
  void (*get_thread_cutting_policy_)(struct MatmulStruct *matmul);
  void (*get_thread_cutting_info_by_row_)(struct MatmulStruct *matmul);

  /* weight-sharing hooks: borrow/release packed weights from an external manager.
   * NOTE(review): "shaing" is a typo for "sharing"; kept for interface compatibility. */
  void *shaing_manager_;
  void *(*get_sharing_weight_)(void *manager, const void *tensor_data, const size_t size, bool *is_packed);
  void (*free_sharing_weight_)(void *manager, void *tensor_data);

  /* fallback GEMV/GEMM on unpacked data: c = a*b + bias with activation act_type */
  void (*gemm_not_pack_fun_)(const float *a, const float *b, float *c, const float *bias, int m, int k, int act_type);

  /* per-thread entry points; task_id selects the thread's slice, return NNACL error code */
  int (*parallel_run_)(struct MatmulStruct *matmul, int task_id);
  int (*parallel_run_by_row_)(struct MatmulStruct *matmul, int task_id);
  int (*parallel_run_by_oc_)(struct MatmulStruct *matmul, int task_id);
  int (*parallel_run_by_batch_)(struct MatmulStruct *matmul, int task_id);
  int (*parallel_run_not_pack_by_batch_)(struct MatmulStruct *matmul, int task_id);

  /* optimize for avx512 */
  int col_split_points_size_;                              // valid entries in col_split_points_
  int row_split_points_size_;                              // valid entries in row_split_points_
  int col_split_points_[SPLIT_COUNT];                      // column cut positions for 2-D thread tiling
  int row_split_points_[SPLIT_COUNT];                      // row cut positions for 2-D thread tiling
  int matmul_slice_count_[SPLIT_COUNT];                    // number of slices assigned to each thread
  MatmulSlice matmul_slice_set_[SPLIT_COUNT][SPLIT_COUNT]; // [thread][slice] work assignments
  /* gemm/gepm/gepdot follow the GotoBLAS panel-shape taxonomy — TODO confirm */
  int (*parallel_run_by_gemm_)(struct MatmulStruct *matmul, int task_id);
  int (*parallel_run_by_gepm_)(struct MatmulStruct *matmul, int task_id);
  int (*parallel_run_by_gepdot_)(struct MatmulStruct *matmul, int task_id);
  int (*parallel_run_by_batch_col_row_gemm_)(struct MatmulStruct *matmul, int task_id);
  int (*parallel_run_by_row1_deep1_gepdot_)(struct MatmulStruct *matmul, int task_id);
} MatmulStruct;
132 
133 #endif  // NNACL_KERNEL_MATMUL_STRUCT_H_
134