• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef NNACL_MATMUL_H_
18 #define NNACL_MATMUL_H_
19 
20 #include "nnacl/op_base.h"
21 
/*
 * Pointer to an int8 x int8 matmul kernel that stores int32 results in dst.
 * Unlike MATMUL_OPT_R_FUNC, this signature carries no shift/multiplier arguments.
 *
 * NOTE(review): row_4 / col_4 / deep_16 presumably are the row, col and depth
 * extents rounded up to multiples of 4, 4 and 16 (cf. row_4_ / deep_16_ in
 * MatMulParameter) — confirm against the implementing kernels.
 * NOTE(review): input_sum is presumably a precomputed per-row sum of `a` used
 * for zero-point correction — TODO confirm.
 */
typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int32_t *dst,
                                   int row_4, int col_4, int deep_16,
                                   const int32_t *input_sum, const int32_t *bias);
24 
/*
 * Pointer to an int8 matmul kernel that writes int8 results to dst.
 * Besides the operands it receives bias / input_sum arrays and what look like
 * requantization arguments (left_shift, right_shift, multiplier, output_zp).
 *
 * NOTE(review): assumed meanings, to be confirmed with the implementing kernels:
 *   - deep_4      : depth rounded up to a multiple of 4
 *   - stride      : row stride of dst
 *   - mini / maxi : clamp range applied to the int8 output
 *   - per_channel : non-zero when the shift/multiplier arrays are per output channel
 */
typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst,
                                  size_t row, size_t col, size_t deep_4, size_t stride,
                                  const int32_t *input_sum, const int32_t *bias,
                                  const int32_t *left_shift, const int32_t *right_shift,
                                  const int32_t *multiplier, int32_t output_zp,
                                  int32_t mini, int32_t maxi, size_t per_channel);
29 
/*
 * Pointer to an int8 matmul kernel with the same parameter list as
 * MATMUL_OPT_R_FUNC, followed by one extra argument: filter_zp, a pointer to
 * int32 values.
 *
 * NOTE(review): filter_zp presumably holds the weight zero point(s), and "DP"
 * presumably names a dot-product-instruction variant — confirm with the kernels.
 */
typedef void (*MATMUL_OPT_DP_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst,
                                   size_t row, size_t col, size_t deep_4, size_t stride,
                                   const int32_t *input_sum, const int32_t *bias,
                                   const int32_t *left_shift, const int32_t *right_shift,
                                   const int32_t *multiplier, int32_t output_zp,
                                   int32_t mini, int32_t maxi, size_t per_channel,
                                   const int32_t *filter_zp);
35 
/*
 * Output-format selector with explicitly fixed values 0..3.
 * NOTE(review): the names suggest output memory layouts (C8 blocking, NHWC,
 * tiled C8, NC4HW4) — confirm with the kernels that consume it.
 */
enum OutType {
  OutType_C8 = 0,
  OutType_Nhwc = 1,
  OutType_TileC8 = 2,
  OutType_NC4HW4 = 3
};
typedef enum OutType OutType;
37 
/*
 * Identifies which matmul implementation variant is in use.
 * Value 0 is reserved for the base op (kNotImplemented); the remaining
 * variants are numbered consecutively from 1.
 */
enum MatmulType {
  kNotImplemented = 0, /* reserved for the base op */
  kMatmulInt8Cpu = 1,
  kMatmulDynamicInt8Cpu = 2,
  kMatmulDynamicSdotInt8Cpu = 3,
  kMatmulFp32BaseCpu = 4,
  kMatmulFp32Arm64Cpu = 5
};
typedef enum MatmulType MatmulType;
47 
48 typedef struct MatMulParameter {
49   // Primitive parameter
50   OpParameter op_parameter_;
51   bool has_bias_;
52   bool use_axis_;
53   bool a_transpose_; /* false :  row-major  */
54   bool b_transpose_; /* true  :  col-major  */
55   ActType act_type_;
56 
57   // other parameter
58   int row_;
59   int col_;
60   int row_4_;
61   int row_16_;
62   int row_align_;
63   int col_8_;
64   int col_align_;
65   int deep_;
66   int deep_4_;
67   int deep_16_;
68   int deep_align_;
69   int batch;
70   bool a_const_;
71   bool b_const_;
72   int axis_;
73   MatmulType matmul_type_;
74 } MatMulParameter;
75 
76 typedef struct MatmulQuantParameter {
77   QuantArg input_;
78   QuantArg weight_;
79   QuantArg output_;
80   int32_t out_act_min_;
81   int32_t out_act_max_;
82   float *filter_scale_;
83   int32_t *filter_zp_;
84   int32_t *left_shift_;
85   int32_t *right_shift_;
86   int32_t *quant_multiplier_;
87 } MatmulQuantParameter;
88 
/*
 * Scale / zero-point pointers for dynamically quantized MatMul: one pair for the
 * input and one pair for the filter.
 * NOTE(review): array lengths (per-tensor vs per-row / per-channel) and ownership
 * are not visible here — confirm with the users of this struct.
 */
struct MatmulDynamicQuantParameter {
  float *input_scale_;
  int32_t *input_zp_;
  float *filter_scale_;
  int32_t *filter_zp_;
};
typedef struct MatmulDynamicQuantParameter MatmulDynamicQuantParameter;
95 
96 #endif  // NNACL_MATMUL_H_
97