/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_
#define MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_
#include <math.h>
#include "nnacl/intrinsics/ms_simd_instructions.h"

#if defined(ENABLE_ARM82_A32)
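// AArch32 NEON has no fp16 vector divide, so ms_vdivq_f16/ms_vdiv_f16 approximate 1/in2 with a
// reciprocal estimate (VRECPE) refined by two Newton-Raphson steps (VRECPS + VMUL), then multiply by in1.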
static inline float16x8_t ms_vdivq_f16(float16x8_t in1, float16x8_t in2) {
  float16x8_t dst;
  asm volatile(
    "vrecpe.f16 q14, %3\n"
    "vrecps.f16 q15, %3, q14\n"
    "vmul.f16 q14, q15, q14\n"
    "vrecps.f16 q15, %3, q14\n"
    "vmul.f16 q14, q15, q14\n"
    "vmul.f16 %0, %2, q14\n"
    : "=w"(dst)
    : "0"(dst), "w"(in1), "w"(in2)
    : "q14", "q15");
  return dst;
}

static inline float16x4_t ms_vdiv_f16(float16x4_t in1, float16x4_t in2) {
  float16x4_t dst;
  asm volatile(
    "vrecpe.f16 d14, %3\n"
    "vrecps.f16 d16, %3, d14\n"
    "vmul.f16 d14, d16, d14\n"
    "vrecps.f16 d16, %3, d14\n"
    "vmul.f16 d14, d16, d14\n"
    "vmul.f16 %0, %2, d14\n"
    : "=w"(dst)
    : "0"(dst), "w"(in1), "w"(in2)
    : "d14", "d16");
  return dst;
}

static inline float ms_vaddvq_f32(float32x4_t in) {
  // vaddvq_f32 is not supported on ARMv8.2 AArch32, and there is no single instruction that reduces all lanes.
  return in[0] + in[1] + in[2] + in[3];
}

static inline float16_t ms_vmaxvq_f16(float16x8_t in) {
  // vmaxvq_f16 is not supported on ARMv8.2 AArch32, and there is no single instruction that reduces all lanes.
  float16_t dst = in[0];
  for (int i = 1; i < 8; ++i) {
    dst = dst > in[i] ? dst : in[i];
  }
  return dst;
}

// Widen four fp16 lanes to fp32 (VCVT.F32.F16).
static inline float32x4_t ms_vcvt_f32_f16(float16x4_t in) {
  float32x4_t dst;
  asm volatile("vcvt.f32.f16 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :);
  return dst;
}

// Narrow four fp32 lanes to fp16 (VCVT.F16.F32).
static inline float16x4_t ms_vcvt_f16_f32(float32x4_t in) {
  float16x4_t dst;
  asm volatile("vcvt.f16.f32 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :);
  return dst;
}

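// The MS_* macros below map the fp16 helpers onto the AArch32 fallbacks defined above; the #else
// branch maps the same names onto the corresponding native NEON intrinsics.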
#define MS_CVT_F32_F16(src) ms_vcvt_f32_f16(src)
#define MS_CVT_F16_F32(src) ms_vcvt_f16_f32(src)
#define MS_DIV_F16(src1, src2) ms_vdiv_f16(src1, src2)
#define MS_DIVQ_F16(src1, src2) ms_vdivq_f16(src1, src2)
#define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_f16(src1, src2, vdupq_n_f16(src3))
#define MS_MAXVQ_F16(src) ms_vmaxvq_f16(src)
#define MS_ADDVQ_F32(src) ms_vaddvq_f32(src)
#else
#define MS_CVT_F32_F16(src) vcvt_f32_f16(src)
#define MS_CVT_F16_F32(src) vcvt_f16_f32(src)
#define MS_DIV_F16(src1, src2) vdiv_f16(src1, src2)
#define MS_DIVQ_F16(src1, src2) vdivq_f16(src1, src2)
#define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_n_f16(src1, src2, src3)
#define MS_MAXVQ_F16(src) vmaxvq_f16(src)
#define MS_ADDVQ_F32(src) vaddvq_f32(src)
#endif

#define MS_FLOAT16X8 float16x8_t
#define MS_FLOAT16X4 float16x4_t
#define MS_MOVQ_F16 vmovq_n_f16
#define MS_STQ_F16 vst1q_f16
#define MS_ST_F16 vst1_f16
#define MS_MINQ_F16 vminq_f16
#define MS_MAXQ_F16 vmaxq_f16
#define MS_LDQ_F16 vld1q_f16
#define MS_LD_F16 vld1_f16
#define MS_ADDQ_F16 vaddq_f16
#define MS_SUBQ_F16 vsubq_f16
#define MS_MULQ_F16 vmulq_f16
#define MS_FMAQ_F16 vfmaq_f16
#define MS_MULQ_N_F16(vector, scalar) vmulq_n_f16(vector, scalar)
#define MS_CMPGTQ_F16(src1, src2) vcgtq_f16(src1, src2)
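
// Illustrative sketch only (not part of the library API): an element-wise fp16 add written against
// the MS_* macros above. The function name is hypothetical.
static inline void MS_ExampleElementAddFp16(const float16_t *a, const float16_t *b, float16_t *out, int len) {
  int i = 0;
  // Process eight fp16 lanes per iteration.
  for (; i <= len - 8; i += 8) {
    MS_STQ_F16(out + i, MS_ADDQ_F16(MS_LDQ_F16(a + i), MS_LDQ_F16(b + i)));
  }
  // Scalar tail for the remaining elements.
  for (; i < len; ++i) {
    out[i] = a[i] + b[i];
  }
}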

// Compute tanh on eight fp16 lanes by widening each half to fp32, applying MS_TANHX4_F32, and narrowing back.
static inline float16x8_t MS_TANHX8_F16(float16x8_t src) {
  float32x4_t src_low = MS_CVT_F32_F16(vget_low_f16(src));
  float32x4_t src_high = MS_CVT_F32_F16(vget_high_f16(src));
  return vcombine_f16(MS_CVT_F16_F32(MS_TANHX4_F32(src_low)), MS_CVT_F16_F32(MS_TANHX4_F32(src_high)));
}

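// The erf and sqrt helpers below fall back to per-lane scalar calls from <math.h>; each lane is
// promoted to fp32 for the computation and narrowed back to fp16 on assignment.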
static inline float16x8_t MS_ERFX8_F16(float16x8_t src) {
  float16x8_t dst;
  dst[0] = erff(src[0]);
  dst[1] = erff(src[1]);
  dst[2] = erff(src[2]);
  dst[3] = erff(src[3]);
  dst[4] = erff(src[4]);
  dst[5] = erff(src[5]);
  dst[6] = erff(src[6]);
  dst[7] = erff(src[7]);
  return dst;
}

static inline float16x8_t MS_SQRTFX8_F16(float16x8_t src) {
  float16x8_t dst;
  dst[0] = sqrtf(src[0]);
  dst[1] = sqrtf(src[1]);
  dst[2] = sqrtf(src[2]);
  dst[3] = sqrtf(src[3]);
  dst[4] = sqrtf(src[4]);
  dst[5] = sqrtf(src[5]);
  dst[6] = sqrtf(src[6]);
  dst[7] = sqrtf(src[7]);
  return dst;
}

static inline float16x4_t MS_SQRTFX4_F16(float16x4_t src) {
  float16x4_t dst;
  dst[0] = sqrtf(src[0]);
  dst[1] = sqrtf(src[1]);
  dst[2] = sqrtf(src[2]);
  dst[3] = sqrtf(src[3]);
  return dst;
}

// Widening multiply-accumulate: sum + x * dy, with the fp16 inputs promoted to fp32 before vmlaq_f32.
static inline float32x4_t MS_VMLAL_F16(float16x4_t x, float16x4_t dy, float32x4_t sum) {
  float32x4_t x_fp32 = MS_CVT_F32_F16(x);
  float32x4_t dy_fp32 = MS_CVT_F32_F16(dy);
  return vmlaq_f32(sum, x_fp32, dy_fp32);
}

#endif  // MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_