/*
 * Copyright (c) 2021-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/cpu/kernels/softmax/generic/neon/impl.h"
#include "src/core/NEON/NEMath.h"
#include "src/core/NEON/wrapper/wrapper.h"
#include "support/SaturateCast.h"

namespace arm_compute
{
namespace cpu
{
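// Computes the maximum value along the X dimension of each row covered by the window
// and writes one scalar result per row to the output tensor.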
template <typename T>
void neon_logits_1d_max(const ITensor *in, ITensor *out, const Window &window)
{
    /** SIMD vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;

    constexpr int window_step_x  = 16 / sizeof(T);
    const auto    window_start_x = static_cast<int>(window.x().start());
    const auto    window_end_x   = static_cast<int>(window.x().end());

    Window win{ window };
    win.set(Window::DimX, Window::Dimension(0, 1, 1));
    Iterator input(in, win);
    Iterator output(out, win);

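    // Number of pairwise reduction stages needed to collapse the half-width vector down to a single lane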
    const int sum_stages = log2(window_step_x / 2);
    execute_window_loop(win, [&](const Coordinates &)
    {
        // Get pointers
        const auto in_ptr  = reinterpret_cast<const T *>(input.ptr());
        const auto out_ptr = reinterpret_cast<T *>(output.ptr());

        // Init max value
        auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
        int  x       = window_start_x;

        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const auto current_value = wrapper::vloadq(in_ptr + x);
            vec_max                  = wrapper::vmax(vec_max, current_value);
        }
        auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max));

        for(int i = 0; i < sum_stages; ++i)
        {
            carry_max = wrapper::vpmax(carry_max, carry_max);
        }
        T max_val = wrapper::vgetlane(carry_max, 0);

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            max_val = *(in_ptr + x) > max_val ? *(in_ptr + x) : max_val;
        }

        *out_ptr = max_val;
    },
    input, output);
}

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
template void neon_logits_1d_max<float16_t>(const ITensor *in, ITensor *out, const Window &window);
#endif //defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
template void neon_logits_1d_max<float>(const ITensor *in, ITensor *out, const Window &window);
template void neon_logits_1d_max<qasymm8_signed_t>(const ITensor *in, ITensor *out, const Window &window);
template void neon_logits_1d_max<qasymm8_t>(const ITensor *in, ITensor *out, const Window &window);

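// Computes a (log-)softmax along each row of a QASYMM8/QASYMM8_SIGNED input: the per-row maximum
// is subtracted, the differences are scaled by beta and the input quantization scale, the
// exponentials are accumulated in the float temporary buffer, and the result is normalized and
// requantized into the output tensor.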
template <typename T>
void neon_softmax_logits_1d_quantized(const ITensor *in, const ITensor *max, void *const tmp,
                                      ITensor *out, float beta, bool is_log, const Window &window)
{
    static_assert(std::is_same<T, qasymm8_t>::value
                  || std::is_same<T, qasymm8_signed_t>::value,
                  "quantized type should be either qasymm8_t or qasymm8_signed_t.");

    const int start_x     = in->info()->valid_region().anchor.x();
    const int input_width = in->info()->valid_region().shape.x();

    const float scale_beta     = -beta * in->info()->quantization_info().uniform().scale;
    const auto  scale_beta_vec = vdupq_n_f32(scale_beta);

    Iterator      in_it(in, window);
    Iterator      max_it(max, window);
    Iterator      out_it(out, window);
    constexpr int vec_size = 16;

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<float *>(tmp);

        float sum{};
        float sum_inversed{};

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
            const auto vec_max = wrapper::vdup_n(max_val, wrapper::traits::vector_128_tag{});

            /* Init sum to zero */
            float32x4x4_t vec_sum =
            {
                vdupq_n_f32(0.f),
                vdupq_n_f32(0.f),
                vdupq_n_f32(0.f),
                vdupq_n_f32(0.f),
            };

            /* Loop over row and compute exponentials and sum */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                auto vec_elements     = wrapper::vloadq(in_ptr + x);
                vec_elements          = wrapper::vqsub(vec_max, vec_elements);
                auto vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);

                if(is_log)
                {
                    vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
                    vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
                    vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
                    vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
                    vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
                    vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
                    vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
                    vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
                }
                else
                {
                    vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
                    vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
                    vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
                    vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
                    vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
                    vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
                    vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
                    vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
                }

                vst4q_f32(tmp_ptr + x, vec_elements_flt);
            }

            /* Reduce sum */
            const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]), vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
            auto       sum_res     = vpadd_f32(vget_high_f32(sum_16_byte), vget_low_f32(sum_16_byte));
            sum_res                = vpadd_f32(sum_res, sum_res);
            sum                    = wrapper::vgetlane(sum_res, 0);

            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                float element{};
                if(is_log)
                {
                    element = (max_val - in_ptr[x]) * scale_beta;
                    sum += std::exp(element);
                }
                else
                {
                    element = std::exp((max_val - in_ptr[x]) * scale_beta);
                    sum += element;
                }

                tmp_ptr[x] = element;
            }

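            // In the non-log case the output uses a fixed quantization scale of 1/256, so that scale
            // is folded into the reciprocal of the sum here; the signed variant is additionally
            // re-centred by -128 during normalization below.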
            if(!is_log)
            {
                sum_inversed = 256.f / sum;
            }
            else
            {
                sum = std::log(sum);
            }
        }

        /* Normalize exponentials */
        {
            constexpr bool is_qasymm8_signed = std::is_same<T, qasymm8_signed_t>::value;
            /* Loop over row and compute softmax */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                using int_vec_type   = wrapper::traits::neon_vector_t<T, 16>;
                float32x4x4_t vec_in = vld4q_f32(tmp_ptr + x);
                int_vec_type  normalized_value{};
                if(is_log)
                {
                    const float32x4x4_t sub =
                    {
                        vsubq_f32(vec_in.val[0], vdupq_n_f32(sum)),
                        vsubq_f32(vec_in.val[1], vdupq_n_f32(sum)),
                        vsubq_f32(vec_in.val[2], vdupq_n_f32(sum)),
                        vsubq_f32(vec_in.val[3], vdupq_n_f32(sum)),
                    };
                    normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(sub);
                }
                else
                {
                    float32x4x4_t mul =
                    {
                        vmulq_f32(vec_in.val[0], vdupq_n_f32(sum_inversed)),
                        vmulq_f32(vec_in.val[1], vdupq_n_f32(sum_inversed)),
                        vmulq_f32(vec_in.val[2], vdupq_n_f32(sum_inversed)),
                        vmulq_f32(vec_in.val[3], vdupq_n_f32(sum_inversed)),
                    };

                    if(is_qasymm8_signed)
                    {
                        const auto offset_vec = wrapper::vdup_n(128.f, wrapper::traits::vector_128_tag{});
                        mul.val[0]            = wrapper::vsub(mul.val[0], offset_vec);
                        mul.val[1]            = wrapper::vsub(mul.val[1], offset_vec);
                        mul.val[2]            = wrapper::vsub(mul.val[2], offset_vec);
                        mul.val[3]            = wrapper::vsub(mul.val[3], offset_vec);
                    }

                    normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(mul);
                }
                wrapper::vstore(out_ptr + x, normalized_value);
            }
            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                if(is_log)
                {
                    out_ptr[x] = utils::cast::saturate_cast<T>(tmp_ptr[x] - sum);
                }
                else
                {
                    out_ptr[x] = utils::cast::saturate_cast<T>((tmp_ptr[x] * sum_inversed) - (is_qasymm8_signed ? 128.f : 0));
                }
            }
        }
    },
    in_it, max_it, out_it);
}

template void neon_softmax_logits_1d_quantized<qasymm8_signed_t>(const ITensor *in, const ITensor *max, void *const tmp,
                                                                 ITensor *out, float beta, bool is_log, const Window &window);
template void neon_softmax_logits_1d_quantized<qasymm8_t>(const ITensor *in, const ITensor *max, void *const tmp,
                                                          ITensor *out, float beta, bool is_log, const Window &window);
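
// Computes a (log-)softmax along each row of an FP32/FP16 input: the per-row maximum is
// subtracted, the differences are scaled by beta, the exponentials are stored in the temporary
// buffer, and the result is normalized (or shifted by log(sum) in the log-softmax case).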
template <typename T>
void neon_softmax_logits_1d_float(const ITensor *in, const ITensor *max, void *const tmp,
                                  ITensor *out, const float beta, bool is_log, const Window &window)
{
    const int start_x     = in->info()->valid_region().anchor.x();
    const int input_width = in->info()->valid_region().shape.x();

    Iterator in_it(in, window);
    Iterator max_it(max, window);
    Iterator out_it(out, window);

    /** SIMD vector tag type. */
    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;

    constexpr int vec_size   = 16 / sizeof(T);
    const int     sum_stages = log2(vec_size / 2);

    execute_window_loop(window, [&](const Coordinates &)
    {
        /* Get pointers */
        const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
        const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
        const auto tmp_ptr = reinterpret_cast<T *>(tmp);

        T sum{};
        T sum_inversed{};

        /* Compute exponentials and sum */
        {
            /* Get max value */
            const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
            const auto vec_max = wrapper::vdup_n(max_val, ExactTagType{});

            /* Init sum to zero */
            auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});

            /* Loop over row and compute exponentials and sum */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                auto vec_elements = wrapper::vloadq(in_ptr + x);
                vec_elements      = wrapper::vsub(vec_elements, vec_max);
                if(is_log)
                {
                    vec_elements = wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{}));
                    vec_sum      = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
                }
                else
                {
                    vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{})));
                    vec_sum      = wrapper::vadd(vec_sum, vec_elements);
                }
                wrapper::vstore(tmp_ptr + x, vec_elements);
            }

            /* Reduce sum */
            auto sum_res = wrapper::vpadd(wrapper::vgethigh(vec_sum), wrapper::vgetlow(vec_sum));
            for(int i = 0; i < sum_stages; ++i)
            {
                sum_res = wrapper::vpadd(sum_res, sum_res);
            }
            sum = wrapper::vgetlane(sum_res, 0);

            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                T element{};

                if(is_log)
                {
                    element = (in_ptr[x] - max_val) * beta;
                    sum += std::exp(element);
                }
                else
                {
                    element = std::exp((in_ptr[x] - max_val) * beta);
                    sum += element;
                }
                tmp_ptr[x] = element;
            }

            if(!is_log)
            {
                sum_inversed = T(1) / sum;
            }
            else
            {
                sum = static_cast<T>(std::log(sum));
            }
        }

        /* Normalize exponentials */
        {
            /* Loop over row and compute softmax */
            int x = 0;
            for(; x <= (input_width - vec_size); x += vec_size)
            {
                auto vec_in           = wrapper::vloadq(tmp_ptr + x);
                auto normalized_value = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
                if(is_log)
                {
                    normalized_value = wrapper::vsub(vec_in, wrapper::vdup_n(static_cast<T>(sum), ExactTagType{}));
                }
                else
                {
                    normalized_value = wrapper::vmul(vec_in, wrapper::vdup_n(static_cast<T>(sum_inversed), ExactTagType{}));
                }
                wrapper::vstore(out_ptr + x, normalized_value);
            }
            /* Run remaining elements */
            for(; x < input_width; ++x)
            {
                if(is_log)
                {
                    out_ptr[x] = tmp_ptr[x] - sum;
                }
                else
                {
                    out_ptr[x] = tmp_ptr[x] * sum_inversed;
                }
            }
        }
    },
    in_it, max_it, out_it);
}
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
template void neon_softmax_logits_1d_float<float16_t>(const ITensor *in, const ITensor *max, void *const tmp,
                                                      ITensor *out, const float beta, bool is_log, const Window &window);
#endif //defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
template void neon_softmax_logits_1d_float<float>(const ITensor *in, const ITensor *max, void *const tmp,
                                                  ITensor *out, const float beta, bool is_log, const Window &window);
} // namespace cpu
} // namespace arm_compute