• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2018-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_DETAIL_NEACTIVATION_FUNCTION_DETAIL_H
25 #define ARM_COMPUTE_DETAIL_NEACTIVATION_FUNCTION_DETAIL_H
26 
27 #include "src/core/NEON/wrapper/wrapper.h"
28 
29 namespace arm_compute
30 {
31 namespace detail
32 {
33 /** Dummy activation object */
34 template <typename T, int S>
35 struct dummy
36 {
37     /** SIMD vector type. */
38     using ExactType = typename wrapper::traits::neon_vector<T, S>::type;
39 
40     /** Construct a dummy activation object.
41      *
42      * @param[in] act_info Activation layer information.
43      */
dummydummy44     explicit dummy(ActivationLayerInfo act_info)
45     {
46         ARM_COMPUTE_UNUSED(act_info);
47     }
48 
49     /** Run activation function.
50      *
51      * @param[in] vval Vector of values.
52      */
operatordummy53     void operator()(ExactType &vval)
54     {
55         ARM_COMPUTE_UNUSED(vval);
56     }
57 
58     /** Run activation function.
59      *
60      * @param[in] val Scalar value.
61      */
operatordummy62     void operator()(T &val)
63     {
64         ARM_COMPUTE_UNUSED(val);
65     }
66 };
67 /** Linear activation object */
68 template <typename T, int S>
69 struct linear
70 {
71     /** SIMD vector type. */
72     using ExactType = typename wrapper::traits::neon_vector<T, S>::type;
73     /** SIMD vector tag type. */
74     using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
75 
76     /** Construct a Linear activation object.
77      *
78      * @param[in] act_info Activation layer information.
79      */
linearlinear80     explicit linear(ActivationLayerInfo act_info)
81         : alpha(act_info.a()),
82           beta(act_info.b()),
83           valpha(wrapper::vdup_n(static_cast<T>(alpha), ExactTagType{})),
84           vbeta(wrapper::vdup_n(static_cast<T>(beta), ExactTagType{}))
85     {
86     }
87 
88     /** Run activation function.
89      *
90      * @param[in] vval Vector of values.
91      */
operatorlinear92     void operator()(ExactType &vval)
93     {
94         vval = wrapper::vmla(vbeta, vval, valpha);
95     }
96 
97     /** Run activation function.
98      *
99      * @param[in] val Scalar value.
100      */
operatorlinear101     void operator()(T &val)
102     {
103         val = alpha * val + beta;
104     }
105 
106     const T         alpha;  /**< Scalar alpha */
107     const T         beta;   /**< Scalar alpha */
108     const ExactType valpha; /**< Vector of alphas. */
109     const ExactType vbeta;  /**< Vector of betas. */
110 };
111 /** Square activation object */
112 template <typename T, int S>
113 struct square
114 {
115     /** SIMD vector type. */
116     using ExactType = typename wrapper::traits::neon_vector<T, S>::type;
117     /** SIMD vector tag type. */
118     using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
119 
120     /** Construct a Square activation object.
121      *
122      * @param[in] act_info Activation layer information.
123      */
squaresquare124     explicit square(ActivationLayerInfo act_info)
125     {
126         ARM_COMPUTE_UNUSED(act_info);
127     }
128 
129     /** Run activation function.
130      *
131      * @param[in] vval Vector of values.
132      */
operatorsquare133     void operator()(ExactType &vval)
134     {
135         vval = wrapper::vmul(vval, vval);
136     }
137 
138     /** Run activation function.
139      *
140      * @param[in] val Scalar value.
141      */
operatorsquare142     void operator()(T &val)
143     {
144         val = val * val;
145     }
146 };
147 /** Logistic activation object */
148 template <typename T, int S>
149 struct logistic
150 {
151     /** SIMD vector type. */
152     using ExactType = typename wrapper::traits::neon_vector<T, S>::type;
153     /** SIMD vector tag type. */
154     using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
155 
156     /** Construct a Logistic activation object.
157      *
158      * @param[in] act_info Activation layer information.
159      */
logisticlogistic160     explicit logistic(ActivationLayerInfo act_info)
161         : vone(wrapper::vdup_n(static_cast<T>(1), ExactTagType{}))
162     {
163         ARM_COMPUTE_UNUSED(act_info);
164     }
165 
166     /** Run activation function.
167      *
168      * @param[in] vval Vector of values.
169      */
operatorlogistic170     void operator()(ExactType &vval)
171     {
172         vval = wrapper::vinv(wrapper::vadd(vone, wrapper::vexpq(wrapper::vneg(vval))));
173     }
174 
175     /** Run activation function.
176      *
177      * @param[in] val Scalar value.
178      */
operatorlogistic179     void operator()(T &val)
180     {
181         val = 1 / (1 + std::exp(-val));
182     }
183 
184     /** Vector of ones. */
185     const ExactType vone;
186 };
187 /** RELU activation object */
188 template <typename T, int S>
189 struct relu
190 {
191     /** SIMD vector type. */
192     using ExactType = typename wrapper::traits::neon_vector<T, S>::type;
193     /** SIMD vector tag type. */
194     using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
195 
196     /** Construct a RELU activation object.
197      *
198      * @param[in] act_info Activation layer information.
199      */
relurelu200     explicit relu(ActivationLayerInfo act_info)
201         : vzero(wrapper::vdup_n(static_cast<T>(0), ExactTagType{}))
202     {
203         ARM_COMPUTE_UNUSED(act_info);
204     }
205 
206     /** Run activation function.
207      *
208      * @param[in] vval Vector of values.
209      */
operatorrelu210     void operator()(ExactType &vval)
211     {
212         vval = wrapper::vmax(vzero, vval);
213     }
214 
215     /** Run activation function.
216      *
217      * @param[in] val Scalar value.
218      */
operatorrelu219     void operator()(T &val)
220     {
221         val = std::max(static_cast<T>(0), val);
222     }
223 
224     /** Vector of zeroes. */
225     const ExactType vzero;
226 };
227 /** Bounded RELU activation object */
228 template <typename T, int S>
229 struct brelu
230 {
231     /** SIMD vector type. */
232     using ExactType = typename wrapper::traits::neon_vector<T, S>::type;
233     /** SIMD vector tag type. */
234     using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
235 
236     /** Construct a bounded RELU activation object.
237      *
238      * @param[in] act_info Activation layer information.
239      */
brelubrelu240     explicit brelu(ActivationLayerInfo act_info)
241         : alpha(act_info.a()),
242           vzero(wrapper::vdup_n(static_cast<T>(0), ExactTagType{})),
243           valpha(wrapper::vdup_n(static_cast<T>(act_info.a()), ExactTagType{}))
244     {
245     }
246 
247     /** Run activation function.
248      *
249      * @param[in] vval Vector of values.
250      */
operatorbrelu251     void operator()(ExactType &vval)
252     {
253         vval = wrapper::vmin(valpha, wrapper::vmax(vzero, vval));
254     }
255 
256     /** Run activation function.
257      *
258      * @param[in] val Scalar value.
259      */
operatorbrelu260     void operator()(T &val)
261     {
262         val = std::min(alpha, std::max(static_cast<T>(0), val));
263     }
264 
265     const T         alpha;  /** Scalar alpha */
266     const ExactType vzero;  /** Vector of zeroes. */
267     const ExactType valpha; /** Vector of alphas. */
268 };
269 /** Lower-Upper Bounded RELU activation object */
270 template <typename T, int S>
271 struct lubrelu
272 {
273     /** SIMD vector type. */
274     using ExactType = typename wrapper::traits::neon_vector<T, S>::type;
275     /** SIMD vector tag type. */
276     using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
277 
278     /** Construct a lower-upper bounded RELU activation object.
279      *
280      * @param[in] act_info Activation layer information.
281      */
lubrelulubrelu282     explicit lubrelu(ActivationLayerInfo act_info)
283         : alpha(act_info.a()),
284           beta(act_info.b()),
285           valpha(wrapper::vdup_n(static_cast<T>(act_info.a()), ExactTagType{})),
286           vbeta(wrapper::vdup_n(static_cast<T>(act_info.b()), ExactTagType{}))
287     {
288     }
289 
290     /** Run activation function.
291      *
292      * @param[in] vval Vector of values.
293      */
operatorlubrelu294     void operator()(ExactType &vval)
295     {
296         vval = wrapper::vmin(valpha, wrapper::vmax(vbeta, vval));
297     }
298 
299     /** Run activation function.
300      *
301      * @param[in] val Scalar value.
302      */
operatorlubrelu303     void operator()(T &val)
304     {
305         val = std::min(alpha, std::max(beta, val));
306     }
307 
308     const T         alpha;  /**< Scalar alpha */
309     const T         beta;   /**< Scalar alpha */
310     const ExactType valpha; /** Vector of alphas. */
311     const ExactType vbeta;  /** Vector of betas. */
312 };
313 } // namespace detail
314 } // namespace arm_compute
315 #endif /* ARM_COMPUTE_DETAIL_NEACTIVATION_FUNCTION_DETAIL_H */
316