• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_NNACL_FP16_MATMUL_H_
18 #define MINDSPORE_NNACL_FP16_MATMUL_H_
19 
20 #include <float.h>
21 #include <string.h>
22 #include "nnacl/errorcode.h"
23 #include "nnacl/matmul_parameter.h"
24 #include "nnacl/op_base.h"
25 #include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
26 #include "nnacl/fp16/pack_fp16.h"
27 
/*
 * Element-wise epilogue helpers for the fp16 matmul kernels.
 *
 * Each macro expands to exactly one `do { ... } while (0);` statement. This means
 * it stays a single statement when used as the body of an unbraced `if`, and it
 * cannot capture a following `else`.
 *
 * The trailing semicolon is part of the expansion because the original macros
 * ended in `;`, so existing call sites may invoke them without one. Writing an
 * extra `;` at the call site is also accepted; it becomes an empty statement.
 *
 * `value` is evaluated more than once and must be an lvalue. MSMAX/MSMIN and the
 * ActType_* constants are expected to come from the nnacl headers included above.
 */

/* Adds bias[c] to `value` when `bias` is non-NULL; otherwise leaves `value` unchanged. */
#define ADD_BIAS(value, bias, c)       \
  do {                                 \
    if ((bias) != NULL) {              \
      (value) = (value) + (bias)[c];   \
    }                                  \
  } while (0);

/* When act_type is ActType_Relu, replaces `value` with MSMAX(0.0f, value). */
#define DO_RELU(value, act_type)       \
  do {                                 \
    if ((act_type) == ActType_Relu) {  \
      (value) = MSMAX(0.0f, (value));  \
    }                                  \
  } while (0);

/*
 * When act_type is ActType_Relu6, clamps `value` to at most 6.0f and then to at
 * least 0.0f. The order (MSMIN first, then MSMAX) is the same as before; both
 * steps now sit under one guard so they are applied or skipped together.
 */
#define DO_RELU6(value, act_type)      \
  do {                                 \
    if ((act_type) == ActType_Relu6) { \
      (value) = MSMIN(6.0f, (value));  \
      (value) = MSMAX(0.0f, (value));  \
    }                                  \
  } while (0);
37 
38 #ifdef __cplusplus
39 extern "C" {
40 #endif
41 void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
42                     int deep, int row, int col, int stride, int write_mode);
43 
44 void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
45                     int deep, int row, int col, int stride, int write_mode);
46 
/*
 * Architecture-specific declarations. Only one branch is active per build:
 * ENABLE_ARM64 takes precedence; otherwise ENABLE_ARM82_A32 selects the AArch32
 * variants. The names suggest NEON / Armv8.2 fp16 implementations (confirm in the
 * build files).
 */
#ifdef ENABLE_ARM64
/* Packing helpers; the "N" block width is not visible from this header. */
void RowMajor2ColNMajorFp16(const float16_t *src, float16_t *dst_ptr, int row, int col);

void RowMajor2RowNMajorFp16(const float16_t *src, float16_t *dst, int row, int col);

void MatMul12x16Fp16Opt(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
                        int deep, int row, int col, size_t stride, size_t out_type);

/* NOTE(review): write_nhwc is `bool` here but `size_t` in the variants below; the
 * types must match the (presumably assembly) definitions, so they are left as-is. */
void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
                      size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc);

void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
                         size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc);

void MatmulBaseFp16Neon(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
                        size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc);

#ifdef ENABLE_DEBUG
/* Only declared in debug builds. */
void MatmulBaseFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
                    size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc);
#endif

/* Matrix-vector / vector-matrix kernels: these take no `row` argument. */
void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
                         int depth, int col);

void VecMatmulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, int depth,
                   int col);
void VecMatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
                         int depth, int col);
/*
 * Use defined() to match the #ifdef above. A bare `#elif ENABLE_ARM82_A32` tests
 * the macro's value instead of its presence: it is a preprocessor error when the
 * macro is defined with an empty replacement list, and it triggers -Wundef when
 * the macro is not defined.
 */
#elif defined(ENABLE_ARM82_A32)
void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
                       int deep, int row, int col, int stride, int write_mode);

void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
                      int depth, int col);

void MatVecMulA32NeonFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
                          int depth, int col);
#endif
85 
86 void MatMul12x16Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
87                      int deep, int row, int col, size_t stride, size_t out_type);
88 
89 void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
90                 int depth, int row, int col, int stride, int out_type);
91 
92 void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
93                    int depth, int col);
94 
95 void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16);
96 
97 void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
98 
99 void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
100 
101 void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src);
102 
103 void RowMajor2Col12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src);
104 
105 void RowMajor2Row16MajorFp16Opt(const float16_t *src, float16_t *dst, int row, int col);
106 
107 void RowMajor2Row16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src);
108 
109 void RowMajor2Row12MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src);
110 
111 void RowMajor2Row8MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src);
112 
113 void RowMajor2Col8MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src);
114 
115 void RowMajor2ColMajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src);
116 
117 #ifdef __cplusplus
118 }
119 #endif
120 
121 #endif  // MINDSPORE_NNACL_FP16_MATMUL_H_
122