1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <float.h>
17 #include "nnacl/fp32/activation_fp32.h"
18 #include "nnacl/fp32/exp_fp32.h"
19 #include "nnacl/errorcode.h"
20 
Fp32Relu(const float * src,int length,float * dst)21 int Fp32Relu(const float *src, int length, float *dst) {
22   int i = 0;
23 #if defined(ENABLE_AVX)
24   MS_FLOAT32X8 zero_8 = MS_MOV256_F32(0.0f);
25   for (; i <= length - 8; i += 8) {
26     MS_ST256_F32(dst + i, MS_MAX256_F32(MS_LD256_F32(src + i), zero_8));
27   }
28 #endif
29 
30 #if defined(ENABLE_SSE) || defined(ENABLE_ARM)
31   MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f);
32   for (; i <= length - 4; i += 4) {
33     MS_STQ_F32(dst + i, MS_MAXQ_F32(MS_LDQ_F32(src + i), zero));
34   }
35 #endif
36   for (; i < length; ++i) {
37     dst[i] = src[i] > 0 ? src[i] : 0;
38   }
39   return NNACL_OK;
40 }
41 
Fp32Relu6(const float * src,int length,float * dst)42 int Fp32Relu6(const float *src, int length, float *dst) {
43   int i = 0;
44 
45 #if defined(ENABLE_AVX)
46   MS_FLOAT32X8 zero_8 = MS_MOV256_F32(0.0f);
47   MS_FLOAT32X8 six_8 = MS_MOV256_F32(6.0f);
48   for (; i <= length - 8; i += 8) {
49     MS_FLOAT32X8 dst_tmp = MS_MAX256_F32(MS_LD256_F32(src + i), zero_8);
50     dst_tmp = MS_MIN256_F32(dst_tmp, six_8);
51     MS_ST256_F32(dst + i, dst_tmp);
52   }
53 #endif
54 
55 #if defined(ENABLE_ARM) || defined(ENABLE_SSE)
56   MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f);
57   MS_FLOAT32X4 six = MS_MOVQ_F32(6.0f);
58   for (; i <= length - 4; i += 4) {
59     MS_FLOAT32X4 dst_tmp = MS_MAXQ_F32(MS_LDQ_F32(src + i), zero);
60     dst_tmp = MS_MINQ_F32(dst_tmp, six);
61     MS_STQ_F32(dst + i, dst_tmp);
62   }
63 #endif
64   for (; i < length; ++i) {
65     if (src[i] < 0) {
66       dst[i] = 0;
67     } else {
68       dst[i] = src[i] > 6.0f ? 6.0f : src[i];
69     }
70   }
71   return NNACL_OK;
72 }
73 
LRelu(const float * src,int length,float * dst,float alpha)74 int LRelu(const float *src, int length, float *dst, float alpha) {
75   int i = 0;
76 #if defined(ENABLE_AVX)
77   for (; i <= length - 8; i += 8) {
78     MS_FLOAT32X8 src_tmp = MS_LD256_F32(src + i);
79     MS_FLOAT32X8 mul_tmp = MS_MUL256_N_F32(src_tmp, alpha);
80     MS_FLOAT32X8 mask = MS_CMP256_F32(src_tmp, MS_MOV256_F32(0.0f), 30);
81     MS_ST256_F32(dst + i, MS_BLEND256_F32(mul_tmp, src_tmp, mask));
82   }
83 #endif
84 
85 #if defined(ENABLE_ARM) || defined(ENABLE_SSE)
86   for (; i <= length - 4; i += 4) {
87     MS_FLOAT32X4 src_tmp = MS_LDQ_F32(src + i);
88     MS_FLOAT32X4 mul_tmp = MS_MULQ_N_F32(src_tmp, alpha);
89 #ifdef ENABLE_ARM
90     MS_UINT32X4 mask = MS_CMPGTQ_F32(src_tmp, MS_MOVQ_F32(0.0f));
91 #else
92     MS_FLOAT32X4 mask = MS_CMPGTQ_F32(src_tmp, MS_MOVQ_F32(0.0f));
93 #endif
94     MS_STQ_F32(dst + i, MS_BLENDQ_F32(mul_tmp, src_tmp, mask));
95   }
96 #endif
97   for (; i < length; ++i) {
98     dst[i] = src[i] > 0 ? src[i] : (src[i] * alpha);
99   }
100   return NNACL_OK;
101 }
102 
Sigmoid(const float * src,int length,float * dst)103 int Sigmoid(const float *src, int length, float *dst) {
104   int i = 0;
105 #if defined(ENABLE_AVX)
106   for (; i <= length - 8; i += 8) {
107     simd_exp_avx(MS_SUB256_F32(MS_MOV256_F32(0.0f), (MS_LD256_F32(src + i))), dst + i);
108     MS_ST256_F32(dst + i,
109                  MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_ADD256_F32(MS_MOV256_F32(1.0f), MS_LD256_F32(dst + i))));
110   }
111 #endif
112 
113 #if defined(ENABLE_ARM) || defined(ENABLE_SSE)
114   for (; i <= length - 4; i += 4) {
115     simd_exp(MS_SUBQ_F32(MS_MOVQ_F32(0.0f), MS_LDQ_F32(src + i)), dst + i);
116     MS_STQ_F32(dst + i, MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_ADDQ_F32(MS_MOVQ_F32(1.0f), MS_LDQ_F32(dst + i))));
117   }
118 #endif
119   for (; i < length; ++i) {
120     single_exp(-src[i], dst + i);
121     dst[i] = 1.0f / (1.0f + dst[i]);
122   }
123   return NNACL_OK;
124 }
125 
float TanhOpt(float src) {
  // Rational (Padé-style) approximation of tanh(x), saturated to +/-1
  // outside [-5, 5] where tanh is already within float precision of +/-1.
  if (src > 5.0) {
    return 1.0f;
  }
  if (src < -5.0) {
    return -1.0f;
  }
  const float x2 = src * src;
  const float numerator = (((x2 + 378.0f) * x2 + 17325.0f) * x2 + 135135.0f) * src;
  const float denominator = ((28.0f * x2 + 3150.0f) * x2 + 62370.0f) * x2 + 135135.0f;
  return numerator / denominator;
}
138 
Tanh(const float * src,int length,float * dst)139 int Tanh(const float *src, int length, float *dst) {
140   int i = 0;
141 #if defined(ENABLE_AVX)
142   for (; i <= length - 8; i += 8) {
143     MS_FLOAT32X8 input = MS_LD256_F32(src + i);
144     MS_ST256_F32(dst + i, MS_TANHX8_F32(input));
145   }
146 #endif
147 
148 #if defined(ENABLE_ARM) || defined(ENABLE_SSE)
149   for (; i <= length - 4; i += 4) {
150     MS_FLOAT32X4 input = MS_LDQ_F32(src + i);
151     MS_STQ_F32(dst + i, MS_TANHX4_F32(input));
152   }
153 #endif
154   for (; i < length; ++i) {
155     dst[i] = TanhOpt(src[i]);
156   }
157   return NNACL_OK;
158 }
159 
Swish(const float * src,int length,float * dst)160 int Swish(const float *src, int length, float *dst) {
161   int ret = Sigmoid(src, length, dst);
162   if (ret != NNACL_OK) {
163     return NNACL_ERR;
164   }
165   int index = 0;
166 #if defined(ENABLE_AVX)
167   for (; index <= length - 8; index += 8) {
168     MS_FLOAT32X8 src_value = MS_LD256_F32(src + index);
169     MS_FLOAT32X8 sigmoid_value = MS_LD256_F32(dst + index);
170     MS_FLOAT32X8 result = MS_MUL256_F32(src_value, sigmoid_value);
171     MS_ST256_F32(dst + index, result);
172   }
173 #endif
174 
175 #if defined(ENABLE_ARM) || defined(ENABLE_SSE)
176   for (; index <= length - 4; index += 4) {
177     MS_FLOAT32X4 src_value = MS_LDQ_F32(src + index);
178     MS_FLOAT32X4 sigmoid_value = MS_LDQ_F32(dst + index);
179     MS_FLOAT32X4 result = MS_MULQ_F32(src_value, sigmoid_value);
180     MS_STQ_F32(dst + index, result);
181   }
182 #endif
183   for (; index < length; ++index) {
184     dst[index] = src[index] * dst[index];
185   }
186   return NNACL_OK;
187 }
188 
HSwish(const float * src,int length,float * dst)189 int HSwish(const float *src, int length, float *dst) {
190   for (int i = 0; i < length; ++i) {
191     float in = src[i];
192     float relu6 = MSMIN(MSMAX(in + 3, 0), 6);
193     dst[i] = in * relu6 / 6;
194   }
195   return NNACL_OK;
196 }
197 
HSigmoid(const float * src,int length,float * dst)198 int HSigmoid(const float *src, int length, float *dst) {
199   for (int i = 0; i < length; ++i) {
200     float relu6 = MSMIN(MSMAX(src[i] + 3, 0), 6);
201     dst[i] = relu6 / 6;
202   }
203   return NNACL_OK;
204 }
205 
HardTanh(const float * src,int length,float * dst,float min_val,float max_val)206 int HardTanh(const float *src, int length, float *dst, float min_val, float max_val) {
207   if (max_val <= min_val) {
208     return NNACL_ERR;
209   }
210   int i = 0;
211   if (min_val == FLT_MIN) {
212     for (i = 0; i < length; ++i) {
213       dst[i] = src[i] > max_val ? max_val : src[i];
214     }
215   } else if (max_val == FLT_MAX) {
216     for (i = 0; i < length; ++i) {
217       dst[i] = src[i] < min_val ? min_val : src[i];
218     }
219   } else {
220     for (i = 0; i < length; ++i) {
221       dst[i] = src[i] < min_val ? min_val : (src[i] > max_val ? max_val : src[i]);
222     }
223   }
224   return NNACL_OK;
225 }
226 
Gelu(const float * src,int length,float * dst,bool approximate)227 int Gelu(const float *src, int length, float *dst, bool approximate) {
228   if (src == NULL || dst == NULL) {
229     return NNACL_ERR;
230   }
231   int i = 0;
232   if (approximate) {
233     // dst = 0.5 * x * (1 + tanh((2 / pi) ^ 0.5 * (x + 0.044715x^3)))
234 #if defined(ENABLE_AVX)
235     MS_FLOAT32X8 para1 = MS_MOV256_F32(0.79788456080287f);
236     MS_FLOAT32X8 para2 = MS_MOV256_F32(0.035677408136f);
237     MS_FLOAT32X8 para3 = MS_MOV256_F32(1.0f);
238     MS_FLOAT32X8 para4 = MS_MOV256_F32(0.5f);
239     int C8 = DOWN_ROUND(length, C8NUM);
240     for (; i < C8; i += C8NUM) {
241       MS_FLOAT32X8 in = MS_LD256_F32(src + i);
242       const MS_FLOAT32X8 res = MS_MUL256_F32(
243         MS_MUL256_F32(para4, in),
244         MS_ADD256_F32(
245           para3, MS_TANHX8_F32(MS_MUL256_F32(MS_ADD256_F32(para1, MS_MUL256_F32(MS_MUL256_F32(para2, in), in)), in))));
246       MS_ST256_F32(dst + i, res);
247     }
248 #endif
249 #if defined(ENABLE_SSE) || defined(ENABLE_ARM)
250     MS_FLOAT32X4 para5 = MS_MOVQ_F32(0.79788456080287f);
251     MS_FLOAT32X4 para6 = MS_MOVQ_F32(0.035677408136f);
252     MS_FLOAT32X4 para7 = MS_MOVQ_F32(1.0f);
253     MS_FLOAT32X4 para8 = MS_MOVQ_F32(0.5f);
254     int C4 = DOWN_ROUND(length, C4NUM);
255     for (; i < C4; i += C4NUM) {
256       MS_FLOAT32X4 in = MS_LDQ_F32(src + i);
257       MS_FLOAT32X4 res = MS_MULQ_F32(
258         MS_MULQ_F32(para8, in),
259         MS_ADDQ_F32(para7,
260                     MS_TANHX4_F32(MS_MULQ_F32(MS_ADDQ_F32(para5, MS_MULQ_F32(MS_MULQ_F32(para6, in), in)), in))));
261       MS_STQ_F32(dst + i, res);
262     }
263 #endif
264     for (; i < length; i++) {
265       dst[i] = 0.5 * src[i] * (1.0 + TanhOpt((0.79788456080287f + 0.035677408136f * src[i] * src[i]) * src[i]));
266     }
267   } else {
268 #if defined(ENABLE_AVX) || defined(ENABLE_SSE) || defined(ENABLE_ARM)
269     MS_FLOAT32X4 para1 = MS_MOVQ_F32(1.4142135623730951f);
270     MS_FLOAT32X4 para2 = MS_MOVQ_F32(1.0f);
271     MS_FLOAT32X4 para3 = MS_MOVQ_F32(0.5f);
272     int C4 = DOWN_ROUND(length, C4NUM);
273     for (; i < C4; i += C4NUM) {
274       MS_FLOAT32X4 in = MS_LDQ_F32(src + i);
275       MS_FLOAT32X4 res = MS_MULQ_F32(MS_MULQ_F32(para3, in), MS_ADDQ_F32(para2, MS_ERFX4_F32(MS_DIVQ_F32(in, para1))));
276       MS_STQ_F32(dst + i, res);
277     }
278 #endif
279     for (; i < length; i++) {
280       dst[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951f));
281     }
282   }
283   return NNACL_OK;
284 }
285 
Softplus(const float * src,int length,float * dst)286 int Softplus(const float *src, int length, float *dst) {
287   int i = 0;
288   for (; i < length; ++i) {
289     single_exp(src[i], dst + i);
290     dst[i] = log1p(dst[i]);
291   }
292   return NNACL_OK;
293 }
294 
Elu(const float * src,int length,float * dst,float alpha)295 int Elu(const float *src, int length, float *dst, float alpha) {
296   int i = 0;
297 #if defined(ENABLE_ARM)
298   MS_FLOAT32X4 one = MS_MOVQ_F32(1.0f);
299   for (; i <= length - 4; i += 4) {
300     MS_FLOAT32X4 src_tmp = MS_LDQ_F32(src + i);
301     MS_FLOAT32X4 exp_tmp = VexpFp32(src_tmp);  // exp(x)
302     exp_tmp = MS_SUBQ_F32(exp_tmp, one);       // exp(x) - 1
303     MS_FLOAT32X4 elu_tmp = MS_MULQ_N_F32(exp_tmp, alpha);
304     MS_UINT32X4 mask = MS_CMPGTQ_F32(src_tmp, MS_MOVQ_F32(0.0f));
305     MS_STQ_F32(dst + i, MS_BLENDQ_F32(elu_tmp, src_tmp, mask));
306   }
307 #endif
308   for (; i < length; ++i) {
309     dst[i] = src[i] > 0 ? src[i] : (expm1(src[i]) * alpha);
310   }
311   return NNACL_OK;
312 }
313