1 /*
2 * Copyright (c) 2021-2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25 #include "src/cpu/kernels/softmax/generic/sve/impl.h"
26 #include "src/core/NEON/wrapper/intrinsics/intrinsics.h"
27
28 namespace arm_compute
29 {
30 namespace cpu
31 {
32 template <typename ScalarType>
sve_logits_1d_max(const ITensor * in,ITensor * out,const Window & window)33 void sve_logits_1d_max(const ITensor *in, ITensor *out, const Window &window)
34 {
35 const auto all_true_pg = wrapper::svptrue<ScalarType>();
36 const auto window_start_x = static_cast<int>(window.x().start());
37 const auto window_end_x = static_cast<int>(window.x().end());
38
39 Window win{ window };
40 win.set(Window::DimX, Window::Dimension(0, 1, 1));
41 Iterator input(in, win);
42 Iterator output(out, win);
43
44 execute_window_loop(win, [&](const Coordinates &)
45 {
46 // Get pointers
47 const auto in_ptr = reinterpret_cast<const ScalarType *>(input.ptr());
48 const auto out_ptr = reinterpret_cast<ScalarType *>(output.ptr());
49
50 // Init max value
51 auto vec_max = wrapper::svdup_n(support::cpp11::lowest<ScalarType>());
52
53 int x = window_start_x;
54 svbool_t pg = wrapper::svwhilelt<ScalarType>(x, window_end_x);
55 do
56 {
57 const auto current_value = svld1(pg, in_ptr + x);
58 vec_max = svmax_m(pg, vec_max, current_value);
59
60 x += wrapper::svcnt<ScalarType>();
61 pg = wrapper::svwhilelt<ScalarType>(x, window_end_x);
62 }
63 while(svptest_any(all_true_pg, pg));
64
65 auto max_val = svmaxv(all_true_pg, vec_max);
66
67 *out_ptr = max_val;
68 },
69 input, output);
70 }
71
72 template <typename ScalarType>
sve_softmax_logits_1d_float(const ITensor * in,const ITensor * max,void * const tmp,ITensor * out,const float beta,bool is_log,const Window & window)73 void sve_softmax_logits_1d_float(const ITensor *in, const ITensor *max, void *const tmp,
74 ITensor *out, const float beta, bool is_log, const Window &window)
75 {
76 const int start_x = in->info()->valid_region().anchor.x();
77 const int input_width = in->info()->valid_region().shape.x();
78
79 Iterator in_it(in, window);
80 Iterator max_it(max, window);
81 Iterator out_it(out, window);
82
83 const auto all_true_pg = wrapper::svptrue<ScalarType>();
84
85 execute_window_loop(window, [&](const Coordinates &)
86 {
87 /* Get pointers */
88 const auto in_ptr = reinterpret_cast<const ScalarType *>(in_it.ptr()) + start_x;
89 const auto out_ptr = reinterpret_cast<ScalarType *>(out_it.ptr()) + start_x;
90 const auto tmp_ptr = reinterpret_cast<ScalarType *>(tmp);
91
92 ScalarType sum{ 0 };
93
94 /* Compute exponentials and sum */
95 {
96 /* Get max value */
97 const auto max_val = *reinterpret_cast<const ScalarType *>(max_it.ptr());
98 const auto vec_max = wrapper::svdup_n(max_val);
99 const auto vec_beta = wrapper::svdup_n(static_cast<ScalarType>(beta));
100
101 /* Init sum to zero */
102 auto vec_sum = wrapper::svdup_n(static_cast<ScalarType>(0));
103
104 /* Loop over row and compute exponentials and sum */
105 int x = 0;
106 svbool_t pg = wrapper::svwhilelt<ScalarType>(x, input_width);
107 do
108 {
109 auto vec_elements = svld1(pg, in_ptr + x);
110 vec_elements = svmul_z(pg, svsub_z(pg, vec_elements, vec_max), vec_beta);
111 if(!is_log)
112 {
113 vec_elements = wrapper::svexp_z(pg, vec_elements);
114 vec_sum = svadd_m(pg, vec_sum, vec_elements);
115 }
116 svst1(pg, tmp_ptr + x, vec_elements);
117
118 if(is_log)
119 {
120 vec_sum = svadd_m(pg, vec_sum, wrapper::svexp_z(pg, vec_elements));
121 }
122
123 x += wrapper::svcnt<ScalarType>();
124 pg = wrapper::svwhilelt<ScalarType>(x, input_width);
125 }
126 while(svptest_any(all_true_pg, pg));
127
128 /* Reduce sum */
129 sum = svaddv(all_true_pg, vec_sum);
130
131 if(is_log)
132 {
133 sum = static_cast<ScalarType>(std::log(sum));
134 }
135 else
136 {
137 sum = ScalarType(1) / sum;
138 }
139 }
140
141 /* Normalize exponentials */
142 {
143 /* Loop over row and compute softmax */
144 int x = 0;
145 svbool_t pg = wrapper::svwhilelt<ScalarType>(x, input_width);
146 do
147 {
148 auto vec_in = svld1(pg, tmp_ptr + x);
149 auto normalized_value = wrapper::svdup_n(static_cast<ScalarType>(0));
150 if(is_log)
151 {
152 normalized_value = svsub_z(pg, vec_in, wrapper::svdup_n(static_cast<ScalarType>(sum)));
153 }
154 else
155 {
156 normalized_value = svmul_z(pg, vec_in, wrapper::svdup_n(static_cast<ScalarType>(sum)));
157 }
158 svst1(pg, out_ptr + x, normalized_value);
159
160 x += wrapper::svcnt<ScalarType>();
161 pg = wrapper::svwhilelt<ScalarType>(x, input_width);
162 }
163 while(svptest_any(all_true_pg, pg));
164 }
165 },
166 in_it, max_it, out_it);
167 }
168
169 template void sve_logits_1d_max<float>(const ITensor *in, ITensor *out, const Window &window);
170 template void sve_logits_1d_max<float16_t>(const ITensor *in, ITensor *out, const Window &window);
171 template void sve_logits_1d_max<qasymm8_t>(const ITensor *in, ITensor *out, const Window &window);
172 template void sve_logits_1d_max<qasymm8_signed_t>(const ITensor *in, ITensor *out, const Window &window);
173
174 template void sve_softmax_logits_1d_float<float>(const ITensor *in, const ITensor *max, void *const tmp,
175 ITensor *out, const float beta, bool is_log, const Window &window);
176 template void sve_softmax_logits_1d_float<float16_t>(const ITensor *in, const ITensor *max, void *const tmp,
177 ITensor *out, const float beta, bool is_log, const Window &window);
178 } // namespace cpu
179 } // namespace arm_compute
180