1 /** 2 * Copyright 2020-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef NNACL_MATMUL_H_ 18 #define NNACL_MATMUL_H_ 19 20 #include "nnacl/op_base.h" 21 22 typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int32_t *dst, int row_4, int col_4, int deep_16, 23 const int32_t *input_sum, const int32_t *bias); 24 25 typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, 26 size_t stride, const int32_t *input_sum, const int32_t *bias, 27 const int32_t *left_shift, const int32_t *right_shift, const int32_t *multiplier, 28 int32_t output_zp, int32_t mini, int32_t maxi, size_t per_channel); 29 30 typedef void (*MATMUL_OPT_DP_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, 31 size_t stride, const int32_t *input_sum, const int32_t *bias, 32 const int32_t *left_shift, const int32_t *right_shift, const int32_t *multiplier, 33 int32_t output_zp, int32_t mini, int32_t maxi, size_t per_channel, 34 const int32_t *filter_zp); 35 36 typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2, OutType_NC4HW4 = 3 } OutType; 37 38 typedef enum MatmulType { 39 // reserve 0 for base op 40 kNotImplemented = 0, 41 kMatmulInt8Cpu, 42 kMatmulDynamicInt8Cpu, 43 kMatmulDynamicSdotInt8Cpu, 44 kMatmulFp32BaseCpu, 45 kMatmulFp32Arm64Cpu, 46 } MatmulType; 47 48 typedef struct MatMulParameter { 49 // Primitive parameter 50 OpParameter op_parameter_; 51 bool has_bias_; 52 bool use_axis_; 53 bool a_transpose_; /* false : row-major */ 54 bool b_transpose_; /* true : col-major */ 55 ActType act_type_; 56 57 // other parameter 58 int row_; 59 int col_; 60 int row_4_; 61 int row_16_; 62 int row_align_; 63 int col_8_; 64 int col_align_; 65 int deep_; 66 int deep_4_; 67 int deep_16_; 68 int deep_align_; 69 int batch; 70 bool a_const_; 71 bool b_const_; 72 int axis_; 73 MatmulType matmul_type_; 74 } MatMulParameter; 75 76 typedef struct MatmulQuantParameter { 77 QuantArg input_; 78 QuantArg weight_; 79 QuantArg output_; 80 int32_t out_act_min_; 81 int32_t out_act_max_; 82 float *filter_scale_; 83 int32_t *filter_zp_; 84 int32_t *left_shift_; 85 int32_t *right_shift_; 86 int32_t *quant_multiplier_; 87 } MatmulQuantParameter; 88 89 typedef struct MatmulDynamicQuantParameter { 90 float *input_scale_; 91 int32_t *input_zp_; 92 float *filter_scale_; 93 int32_t *filter_zp_; 94 } MatmulDynamicQuantParameter; 95 96 #endif // NNACL_MATMUL_H_ 97