/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <float.h>
#include "nnacl/fp32/activation_fp32.h"
#include "nnacl/fp32/exp_fp32.h"
#include "nnacl/errorcode.h"
20
Fp32Relu(const float * src,int length,float * dst)21 int Fp32Relu(const float *src, int length, float *dst) {
22 int i = 0;
23 #if defined(ENABLE_AVX)
24 MS_FLOAT32X8 zero_8 = MS_MOV256_F32(0.0f);
25 for (; i <= length - 8; i += 8) {
26 MS_ST256_F32(dst + i, MS_MAX256_F32(MS_LD256_F32(src + i), zero_8));
27 }
28 #endif
29
30 #if defined(ENABLE_SSE) || defined(ENABLE_ARM)
31 MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f);
32 for (; i <= length - 4; i += 4) {
33 MS_STQ_F32(dst + i, MS_MAXQ_F32(MS_LDQ_F32(src + i), zero));
34 }
35 #endif
36 for (; i < length; ++i) {
37 dst[i] = src[i] > 0 ? src[i] : 0;
38 }
39 return NNACL_OK;
40 }
41
Fp32Relu6(const float * src,int length,float * dst)42 int Fp32Relu6(const float *src, int length, float *dst) {
43 int i = 0;
44
45 #if defined(ENABLE_AVX)
46 MS_FLOAT32X8 zero_8 = MS_MOV256_F32(0.0f);
47 MS_FLOAT32X8 six_8 = MS_MOV256_F32(6.0f);
48 for (; i <= length - 8; i += 8) {
49 MS_FLOAT32X8 dst_tmp = MS_MAX256_F32(MS_LD256_F32(src + i), zero_8);
50 dst_tmp = MS_MIN256_F32(dst_tmp, six_8);
51 MS_ST256_F32(dst + i, dst_tmp);
52 }
53 #endif
54
55 #if defined(ENABLE_ARM) || defined(ENABLE_SSE)
56 MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f);
57 MS_FLOAT32X4 six = MS_MOVQ_F32(6.0f);
58 for (; i <= length - 4; i += 4) {
59 MS_FLOAT32X4 dst_tmp = MS_MAXQ_F32(MS_LDQ_F32(src + i), zero);
60 dst_tmp = MS_MINQ_F32(dst_tmp, six);
61 MS_STQ_F32(dst + i, dst_tmp);
62 }
63 #endif
64 for (; i < length; ++i) {
65 if (src[i] < 0) {
66 dst[i] = 0;
67 } else {
68 dst[i] = src[i] > 6.0f ? 6.0f : src[i];
69 }
70 }
71 return NNACL_OK;
72 }
73
// Leaky ReLU: dst[i] = src[i] if src[i] > 0, else alpha * src[i].
int LRelu(const float *src, int length, float *dst, float alpha) {
  int i = 0;
#if defined(ENABLE_AVX)
  for (; i <= length - 8; i += 8) {
    MS_FLOAT32X8 src_tmp = MS_LD256_F32(src + i);
    MS_FLOAT32X8 mul_tmp = MS_MUL256_N_F32(src_tmp, alpha);
    // Predicate 30 == _CMP_GT_OQ (ordered, non-signalling greater-than).
    MS_FLOAT32X8 mask = MS_CMP256_F32(src_tmp, MS_MOV256_F32(0.0f), 30);
    // Blend: take src where mask set (x > 0), alpha*src elsewhere.
    MS_ST256_F32(dst + i, MS_BLEND256_F32(mul_tmp, src_tmp, mask));
  }
#endif

#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
  for (; i <= length - 4; i += 4) {
    MS_FLOAT32X4 src_tmp = MS_LDQ_F32(src + i);
    MS_FLOAT32X4 mul_tmp = MS_MULQ_N_F32(src_tmp, alpha);
#ifdef ENABLE_ARM
    // NEON comparisons produce an integer lane mask.
    MS_UINT32X4 mask = MS_CMPGTQ_F32(src_tmp, MS_MOVQ_F32(0.0f));
#else
    // SSE comparisons produce an all-ones/all-zeros float mask.
    MS_FLOAT32X4 mask = MS_CMPGTQ_F32(src_tmp, MS_MOVQ_F32(0.0f));
#endif
    MS_STQ_F32(dst + i, MS_BLENDQ_F32(mul_tmp, src_tmp, mask));
  }
#endif
  // Scalar tail.
  for (; i < length; ++i) {
    dst[i] = src[i] > 0 ? src[i] : (src[i] * alpha);
  }
  return NNACL_OK;
}
102
Sigmoid(const float * src,int length,float * dst)103 int Sigmoid(const float *src, int length, float *dst) {
104 int i = 0;
105 #if defined(ENABLE_AVX)
106 for (; i <= length - 8; i += 8) {
107 simd_exp_avx(MS_SUB256_F32(MS_MOV256_F32(0.0f), (MS_LD256_F32(src + i))), dst + i);
108 MS_ST256_F32(dst + i,
109 MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_ADD256_F32(MS_MOV256_F32(1.0f), MS_LD256_F32(dst + i))));
110 }
111 #endif
112
113 #if defined(ENABLE_ARM) || defined(ENABLE_SSE)
114 for (; i <= length - 4; i += 4) {
115 simd_exp(MS_SUBQ_F32(MS_MOVQ_F32(0.0f), MS_LDQ_F32(src + i)), dst + i);
116 MS_STQ_F32(dst + i, MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_ADDQ_F32(MS_MOVQ_F32(1.0f), MS_LDQ_F32(dst + i))));
117 }
118 #endif
119 for (; i < length; ++i) {
120 single_exp(-src[i], dst + i);
121 dst[i] = 1.0f / (1.0f + dst[i]);
122 }
123 return NNACL_OK;
124 }
125
// Fast scalar tanh approximation (rational/Pade-style polynomial ratio).
// Saturates to +/-1 outside [-5, 5], where tanh is 1 to within float precision.
// Fix: use float literals (5.0f) in the saturation comparisons — the original
// compared against double literals (5.0), forcing a float->double promotion on
// every call in this hot scalar path.
float TanhOpt(float src) {
  if (src > 5.0f) {
    return 1.0f;
  }
  if (src < -5.0f) {
    return -1.0f;
  }
  float square = src * src;
  // Numerator: x * (x^6 + 378x^4 + 17325x^2 + 135135), evaluated by Horner's rule.
  float a = (((square + 378.0f) * square + 17325.0f) * square + 135135.0f) * src;
  // Denominator: 28x^6 + 3150x^4 + 62370x^2 + 135135.
  float b = ((28.0f * square + 3150.0f) * square + 62370.0f) * square + 135135.0f;
  return a / b;
}
138
Tanh(const float * src,int length,float * dst)139 int Tanh(const float *src, int length, float *dst) {
140 int i = 0;
141 #if defined(ENABLE_AVX)
142 for (; i <= length - 8; i += 8) {
143 MS_FLOAT32X8 input = MS_LD256_F32(src + i);
144 MS_ST256_F32(dst + i, MS_TANHX8_F32(input));
145 }
146 #endif
147
148 #if defined(ENABLE_ARM) || defined(ENABLE_SSE)
149 for (; i <= length - 4; i += 4) {
150 MS_FLOAT32X4 input = MS_LDQ_F32(src + i);
151 MS_STQ_F32(dst + i, MS_TANHX4_F32(input));
152 }
153 #endif
154 for (; i < length; ++i) {
155 dst[i] = TanhOpt(src[i]);
156 }
157 return NNACL_OK;
158 }
159
Swish(const float * src,int length,float * dst)160 int Swish(const float *src, int length, float *dst) {
161 int ret = Sigmoid(src, length, dst);
162 if (ret != NNACL_OK) {
163 return NNACL_ERR;
164 }
165 int index = 0;
166 #if defined(ENABLE_AVX)
167 for (; index <= length - 8; index += 8) {
168 MS_FLOAT32X8 src_value = MS_LD256_F32(src + index);
169 MS_FLOAT32X8 sigmoid_value = MS_LD256_F32(dst + index);
170 MS_FLOAT32X8 result = MS_MUL256_F32(src_value, sigmoid_value);
171 MS_ST256_F32(dst + index, result);
172 }
173 #endif
174
175 #if defined(ENABLE_ARM) || defined(ENABLE_SSE)
176 for (; index <= length - 4; index += 4) {
177 MS_FLOAT32X4 src_value = MS_LDQ_F32(src + index);
178 MS_FLOAT32X4 sigmoid_value = MS_LDQ_F32(dst + index);
179 MS_FLOAT32X4 result = MS_MULQ_F32(src_value, sigmoid_value);
180 MS_STQ_F32(dst + index, result);
181 }
182 #endif
183 for (; index < length; ++index) {
184 dst[index] = src[index] * dst[index];
185 }
186 return NNACL_OK;
187 }
188
HSwish(const float * src,int length,float * dst)189 int HSwish(const float *src, int length, float *dst) {
190 for (int i = 0; i < length; ++i) {
191 float in = src[i];
192 float relu6 = MSMIN(MSMAX(in + 3, 0), 6);
193 dst[i] = in * relu6 / 6;
194 }
195 return NNACL_OK;
196 }
197
HSigmoid(const float * src,int length,float * dst)198 int HSigmoid(const float *src, int length, float *dst) {
199 for (int i = 0; i < length; ++i) {
200 float relu6 = MSMIN(MSMAX(src[i] + 3, 0), 6);
201 dst[i] = relu6 / 6;
202 }
203 return NNACL_OK;
204 }
205
// Clamp each element to [min_val, max_val].
// NOTE(review): FLT_MIN / FLT_MAX act as "unbounded" sentinels below. FLT_MIN is the
// smallest *positive* normal float (not the most negative float), so callers presumably
// pass exactly FLT_MIN to mean "no lower bound" — confirm against the op parameter
// population code.
int HardTanh(const float *src, int length, float *dst, float min_val, float max_val) {
  if (max_val <= min_val) {
    // Degenerate or inverted range.
    return NNACL_ERR;
  }
  int i = 0;
  if (min_val == FLT_MIN) {
    // Sentinel: no lower bound, clamp above only.
    for (i = 0; i < length; ++i) {
      dst[i] = src[i] > max_val ? max_val : src[i];
    }
  } else if (max_val == FLT_MAX) {
    // Sentinel: no upper bound, clamp below only.
    for (i = 0; i < length; ++i) {
      dst[i] = src[i] < min_val ? min_val : src[i];
    }
  } else {
    // Two-sided clamp.
    for (i = 0; i < length; ++i) {
      dst[i] = src[i] < min_val ? min_val : (src[i] > max_val ? max_val : src[i]);
    }
  }
  return NNACL_OK;
}
226
// GELU activation (Gaussian Error Linear Unit).
// approximate == true : tanh approximation
//   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
//   where 0.79788456080287 = sqrt(2/pi) and 0.035677408136 = 0.044715 * sqrt(2/pi).
// approximate == false: exact form gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
// Returns NNACL_ERR on NULL pointers, NNACL_OK otherwise.
int Gelu(const float *src, int length, float *dst, bool approximate) {
  if (src == NULL || dst == NULL) {
    return NNACL_ERR;
  }
  int i = 0;
  if (approximate) {
    // dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715x^3)))
#if defined(ENABLE_AVX)
    MS_FLOAT32X8 para1 = MS_MOV256_F32(0.79788456080287f);  // sqrt(2/pi)
    MS_FLOAT32X8 para2 = MS_MOV256_F32(0.035677408136f);    // 0.044715 * sqrt(2/pi)
    MS_FLOAT32X8 para3 = MS_MOV256_F32(1.0f);
    MS_FLOAT32X8 para4 = MS_MOV256_F32(0.5f);
    int C8 = DOWN_ROUND(length, C8NUM);
    for (; i < C8; i += C8NUM) {
      MS_FLOAT32X8 in = MS_LD256_F32(src + i);
      // 0.5 * x * (1 + tanh((para1 + para2 * x^2) * x))
      const MS_FLOAT32X8 res = MS_MUL256_F32(
        MS_MUL256_F32(para4, in),
        MS_ADD256_F32(
          para3, MS_TANHX8_F32(MS_MUL256_F32(MS_ADD256_F32(para1, MS_MUL256_F32(MS_MUL256_F32(para2, in), in)), in))));
      MS_ST256_F32(dst + i, res);
    }
#endif
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
    MS_FLOAT32X4 para5 = MS_MOVQ_F32(0.79788456080287f);  // sqrt(2/pi)
    MS_FLOAT32X4 para6 = MS_MOVQ_F32(0.035677408136f);    // 0.044715 * sqrt(2/pi)
    MS_FLOAT32X4 para7 = MS_MOVQ_F32(1.0f);
    MS_FLOAT32X4 para8 = MS_MOVQ_F32(0.5f);
    int C4 = DOWN_ROUND(length, C4NUM);
    for (; i < C4; i += C4NUM) {
      MS_FLOAT32X4 in = MS_LDQ_F32(src + i);
      MS_FLOAT32X4 res = MS_MULQ_F32(
        MS_MULQ_F32(para8, in),
        MS_ADDQ_F32(para7,
                    MS_TANHX4_F32(MS_MULQ_F32(MS_ADDQ_F32(para5, MS_MULQ_F32(MS_MULQ_F32(para6, in), in)), in))));
      MS_STQ_F32(dst + i, res);
    }
#endif
    // Scalar tail, using the same constants via TanhOpt.
    for (; i < length; i++) {
      dst[i] = 0.5 * src[i] * (1.0 + TanhOpt((0.79788456080287f + 0.035677408136f * src[i] * src[i]) * src[i]));
    }
  } else {
#if defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM)
    // NOTE(review): this branch uses 128-bit (X4) vectors even when only ENABLE_AVX
    // is defined — assumes the MS_*Q macros are also available in AVX builds; confirm.
    MS_FLOAT32X4 para1 = MS_MOVQ_F32(1.4142135623730951f);  // sqrt(2)
    MS_FLOAT32X4 para2 = MS_MOVQ_F32(1.0f);
    MS_FLOAT32X4 para3 = MS_MOVQ_F32(0.5f);
    int C4 = DOWN_ROUND(length, C4NUM);
    for (; i < C4; i += C4NUM) {
      MS_FLOAT32X4 in = MS_LDQ_F32(src + i);
      // 0.5 * x * (1 + erf(x / sqrt(2)))
      MS_FLOAT32X4 res = MS_MULQ_F32(MS_MULQ_F32(para3, in), MS_ADDQ_F32(para2, MS_ERFX4_F32(MS_DIVQ_F32(in, para1))));
      MS_STQ_F32(dst + i, res);
    }
#endif
    // Scalar tail with libm erf.
    for (; i < length; i++) {
      dst[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951f));
    }
  }
  return NNACL_OK;
}
285
Softplus(const float * src,int length,float * dst)286 int Softplus(const float *src, int length, float *dst) {
287 int i = 0;
288 for (; i < length; ++i) {
289 single_exp(src[i], dst + i);
290 dst[i] = log1p(dst[i]);
291 }
292 return NNACL_OK;
293 }
294
// ELU activation: dst[i] = x when x > 0, else alpha * (e^x - 1).
int Elu(const float *src, int length, float *dst, float alpha) {
  int i = 0;
#if defined(ENABLE_ARM)
  MS_FLOAT32X4 one = MS_MOVQ_F32(1.0f);
  for (; i <= length - 4; i += 4) {
    MS_FLOAT32X4 src_tmp = MS_LDQ_F32(src + i);
    MS_FLOAT32X4 exp_tmp = VexpFp32(src_tmp);  // exp(x)
    exp_tmp = MS_SUBQ_F32(exp_tmp, one);       // exp(x) - 1
    MS_FLOAT32X4 elu_tmp = MS_MULQ_N_F32(exp_tmp, alpha);
    // Lane mask: set where x > 0; blend picks src there, alpha*(exp(x)-1) elsewhere.
    MS_UINT32X4 mask = MS_CMPGTQ_F32(src_tmp, MS_MOVQ_F32(0.0f));
    MS_STQ_F32(dst + i, MS_BLENDQ_F32(elu_tmp, src_tmp, mask));
  }
#endif
  // Scalar tail; expm1 computes e^x - 1 accurately for small |x|.
  for (; i < length; ++i) {
    dst[i] = src[i] > 0 ? src[i] : (expm1(src[i]) * alpha);
  }
  return NNACL_OK;
}
313