/*
 * Copyright (c) 2019-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/core/NEON/kernels/NEFFTRadixStageKernel.h"

#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Window.h"
#include "src/core/NEON/wrapper/traits.h"
#include "src/core/NEON/wrapper/wrapper.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
#include "support/ToolchainSupport.h"

#include <arm_neon.h>
#include <cmath>
#include <complex>
#include <map>

namespace arm_compute
{
namespace
{
// PI constant (from cmath)
constexpr float kPi = float(M_PI);

// Constant used in the fft_3 kernel
constexpr float kSqrt3Div2 = 0.866025403784438f;

// Constants used in the fft_5 kernel
constexpr float kW5_0 = 0.30901699437494f;
constexpr float kW5_1 = 0.95105651629515f;
constexpr float kW5_2 = 0.80901699437494f;
constexpr float kW5_3 = 0.58778525229247f;

// Constants used in the fft_7 kernel
constexpr float kW7_0 = 0.62348980185873f;
constexpr float kW7_1 = 0.78183148246802f;
constexpr float kW7_2 = 0.22252093395631f;
constexpr float kW7_3 = 0.97492791218182f;
constexpr float kW7_4 = 0.90096886790241f;
constexpr float kW7_5 = 0.43388373911755f;

// Constant used in the fft_8 kernel
constexpr float kSqrt2Div2 = 0.707106781186548f;

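// Complex multiplication of two (real, imaginary) pairs held in float32x2_t lanes:
// returns { a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re }.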
float32x2_t c_mul_neon(float32x2_t a, float32x2_t b)
{
    using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;

    const float32x2_t mask = { -1.0, 1.0 };
    const float32x2_t tmp0 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
    const float32x2_t tmp1 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});

    float32x2_t res = wrapper::vmul(tmp0, b);

    b = wrapper::vrev64(b);
    b = wrapper::vmul(b, mask);
    res = wrapper::vmla(res, tmp1, b);

    return res;
}

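// Multiplication of a complex value by a purely imaginary constant (j * img_constant):
// returns { -a.im, a.re } scaled by img_constant.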
float32x2_t c_mul_neon_img(float32x2_t a, float img_constant)
{
    const float a_r = wrapper::vgetlane(a, 0);
    const float a_i = wrapper::vgetlane(a, 1);

    const auto out = wrapper::vmul(float32x2_t{ -a_i, a_r }, float32x2_t{ img_constant, img_constant });
    return out;
}

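// The reduce_sum_5/7/8 helpers below accumulate their complex operands pairwise in a
// balanced tree, which keeps the vector-add dependency chains short.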
float32x2_t reduce_sum_5(float32x2_t a, float32x2_t b, float32x2_t c, float32x2_t d, float32x2_t e)
{
    const auto t0 = wrapper::vadd(a, b);
    const auto t1 = wrapper::vadd(c, d);
    const auto t2 = wrapper::vadd(t0, t1);
    return wrapper::vadd(t2, e);
}

float32x2_t reduce_sum_7(float32x2_t x1, float32x2_t x2, float32x2_t x3, float32x2_t x4, float32x2_t x5, float32x2_t x6, float32x2_t x7)
{
    const auto t0 = wrapper::vadd(x1, x2);
    const auto t1 = wrapper::vadd(x3, x4);
    const auto t2 = wrapper::vadd(x5, x6);
    const auto t00 = wrapper::vadd(t0, t1);
    const auto t01 = wrapper::vadd(t2, x7);

    return wrapper::vadd(t00, t01);
}

float32x2_t reduce_sum_8(float32x2_t x1, float32x2_t x2, float32x2_t x3, float32x2_t x4, float32x2_t x5, float32x2_t x6, float32x2_t x7, float32x2_t x8)
{
    const auto t0 = wrapper::vadd(x1, x2);
    const auto t1 = wrapper::vadd(x3, x4);
    const auto t2 = wrapper::vadd(x5, x6);
    const auto t3 = wrapper::vadd(x7, x8);
    const auto t00 = wrapper::vadd(t0, t1);
    const auto t01 = wrapper::vadd(t2, t3);

    return wrapper::vadd(t00, t01);
}

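// Radix-2 butterfly: given the twiddle factor w, computes
// x' = x + w * y and y' = x - w * y in place.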
void fft_2(float32x2_t &x, float32x2_t &y, float32x2_t &w)
{
    float32x2_t a = x;
    float32x2_t b = c_mul_neon(w, y);

    x = wrapper::vadd(a, b);
    y = wrapper::vsub(a, b);
}

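// Radix-3 butterfly using the twiddles w and w^2; kSqrt3Div2 is sin(2 * pi / 3).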
void fft_3(float32x2_t &x, float32x2_t &y, float32x2_t &z, const float32x2_t &w, const float32x2_t &w2)
{
    float32x2_t a = x;
    float32x2_t b = c_mul_neon(w, y);
    float32x2_t c = c_mul_neon(w2, z);

    x = wrapper::vadd(a, b);
    x = wrapper::vadd(x, c);

    const auto v1 = wrapper::vmul(float32x2_t{ 0.5f, 0.5f }, wrapper::vadd(b, c));
    const auto v2 = c_mul_neon(float32x2_t{ 0.f, -kSqrt3Div2 }, wrapper::vsub(b, c));

    y = z = wrapper::vsub(a, v1);
    y = wrapper::vadd(y, v2);
    z = wrapper::vsub(z, v2);
}

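// Radix-4 butterfly: the twiddled inputs are combined with the hard-coded fourth
// roots of unity (1, -j, -1, j) via c_mul_neon_img and wrapper::vneg.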
void fft_4(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3)
{
    float32x2_t a = x1;
    float32x2_t b = c_mul_neon(w, x2);
    float32x2_t c = c_mul_neon(w2, x3);
    float32x2_t d = c_mul_neon(w3, x4);

    const auto x11 = wrapper::vadd(a, b);
    const auto x12 = wrapper::vadd(c, d);
    x1 = wrapper::vadd(x11, x12);

    const auto x21 = wrapper::vadd(a, c_mul_neon_img(b, -1));
    const auto x22 = wrapper::vadd(wrapper::vneg(c), c_mul_neon_img(d, 1.f));
    x2 = wrapper::vadd(x21, x22);

    const auto x31 = wrapper::vadd(a, wrapper::vneg(b));
    const auto x32 = wrapper::vadd(c, wrapper::vneg(d));
    x3 = wrapper::vadd(x31, x32);

    const auto x41 = wrapper::vadd(a, c_mul_neon_img(b, 1));
    const auto x42 = wrapper::vadd(wrapper::vneg(c), c_mul_neon_img(d, -1));
    x4 = wrapper::vadd(x41, x42);
}

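// Radix-5 butterfly: the kW5_* constants are cos/sin of 2*pi/5 and 4*pi/5, so each
// float32x2_t coefficient below is a fifth root of unity applied to a twiddled input.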
void fft_5(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3, const float32x2_t &w4)
{
    const auto a = x1;
    const auto b = c_mul_neon(w, x2);
    const auto c = c_mul_neon(w2, x3);
    const auto d = c_mul_neon(w3, x4);
    const auto e = c_mul_neon(w4, x5);

    const auto b0 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, b);
    const auto b1 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, b);
    const auto b2 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, b);
    const auto b3 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, b);

    const auto c0 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, c);
    const auto c1 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, c);
    const auto c2 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, c);
    const auto c3 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, c);

    const auto d0 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, d);
    const auto d1 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, d);
    const auto d2 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, d);
    const auto d3 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, d);

    const auto e0 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, e);
    const auto e1 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, e);
    const auto e2 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, e);
    const auto e3 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, e);

    x1 = reduce_sum_5(a, b, c, d, e);
    x2 = reduce_sum_5(a, b0, c0, d0, e0);
    x3 = reduce_sum_5(a, b1, c1, d1, e1);
    x4 = reduce_sum_5(a, b2, c2, d2, e2);
    x5 = reduce_sum_5(a, b3, c3, d3, e3);
}

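// Radix-7 butterfly: the kW7_* constants are cos/sin of 2*pi/7, 4*pi/7 and 6*pi/7,
// i.e. the seventh roots of unity used as fixed coefficients for each output row.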
void fft_7(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3,
           const float32x2_t &w4,
           const float32x2_t &w5, const float32x2_t &w6)
{
    const auto a = x1;
    const auto b = c_mul_neon(w, x2);
    const auto c = c_mul_neon(w2, x3);
    const auto d = c_mul_neon(w3, x4);
    const auto e = c_mul_neon(w4, x5);
    const auto f = c_mul_neon(w5, x6);
    const auto g = c_mul_neon(w6, x7);

    const auto b0 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, b);
    const auto b1 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, b);
    const auto b2 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, b);
    const auto b3 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, b);
    const auto b4 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, b);
    const auto b5 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, b);

    const auto c0 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, c);
    const auto c1 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, c);
    const auto c2 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, c);
    const auto c3 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, c);
    const auto c4 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, c);
    const auto c5 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, c);

    const auto d0 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, d);
    const auto d1 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, d);
    const auto d2 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, d);
    const auto d3 = c_mul_neon(float32x2_t{ -kW7_2, +kW7_3 }, d);
    const auto d4 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, d);
    const auto d5 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, d);

    const auto e0 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, e);
    const auto e1 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, e);
    const auto e2 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, e);
    const auto e3 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, e);
    const auto e4 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, e);
    const auto e5 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, e);

    const auto f0 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, f);
    const auto f1 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, f);
    const auto f2 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, f);
    const auto f3 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, f);
    const auto f4 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, f);
    const auto f5 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, f);

    const auto g0 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, g);
    const auto g1 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, g);
    const auto g2 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, g);
    const auto g3 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, g);
    const auto g4 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, g);
    const auto g5 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, g);

    x1 = reduce_sum_7(a, b, c, d, e, f, g);
    x2 = reduce_sum_7(a, b0, c0, d0, e0, f0, g0);
    x3 = reduce_sum_7(a, b1, c1, d1, e1, f1, g1);
    x4 = reduce_sum_7(a, b2, c2, d2, e2, f2, g2);
    x5 = reduce_sum_7(a, b3, c3, d3, e3, f3, g3);
    x6 = reduce_sum_7(a, b4, c4, d4, e4, f4, g4);
    x7 = reduce_sum_7(a, b5, c5, d5, e5, f5, g5);
}

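// Radix-8 butterfly: the coefficient pairs are the eighth roots of unity
// (+/-1, +/-j and +/-kSqrt2Div2 +/- j*kSqrt2Div2) applied to the twiddled inputs.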
void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, float32x2_t &x8, const float32x2_t &w, const float32x2_t &w2,
           const float32x2_t &w3,
           const float32x2_t &w4, const float32x2_t &w5, const float32x2_t &w6,
           const float32x2_t &w7)
{
    const auto a = x1;
    const auto b = c_mul_neon(w, x2);
    const auto c = c_mul_neon(w2, x3);
    const auto d = c_mul_neon(w3, x4);
    const auto e = c_mul_neon(w4, x5);
    const auto f = c_mul_neon(w5, x6);
    const auto g = c_mul_neon(w6, x7);
    const auto h = c_mul_neon(w7, x8);

    const auto b0 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, b);
    const auto b1 = c_mul_neon(float32x2_t{ 0, -1 }, b);
    const auto b2 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, b);
    const auto b3 = c_mul_neon(float32x2_t{ -1, 0 }, b);
    const auto b4 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, b);
    const auto b5 = c_mul_neon(float32x2_t{ 0, 1 }, b);
    const auto b6 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, b);

    const auto c0 = c_mul_neon(float32x2_t{ 0, -1 }, c);
    const auto c1 = c_mul_neon(float32x2_t{ -1, 0 }, c);
    const auto c2 = c_mul_neon(float32x2_t{ 0, 1 }, c);
    const auto c3 = c_mul_neon(float32x2_t{ 1, 0 }, c);
    const auto c4 = c_mul_neon(float32x2_t{ 0, -1 }, c);
    const auto c5 = c_mul_neon(float32x2_t{ -1, 0 }, c);
    const auto c6 = c_mul_neon(float32x2_t{ 0, 1 }, c);

    const auto d0 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, d);
    const auto d1 = c_mul_neon(float32x2_t{ 0, 1 }, d);
    const auto d2 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, d);
    const auto d3 = c_mul_neon(float32x2_t{ -1, 0 }, d);
    const auto d4 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, d);
    const auto d5 = c_mul_neon(float32x2_t{ 0, -1 }, d);
    const auto d6 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, d);

    const auto e0 = c_mul_neon(float32x2_t{ -1, 0 }, e);
    const auto e1 = c_mul_neon(float32x2_t{ 1, 0 }, e);
    const auto e2 = c_mul_neon(float32x2_t{ -1, 0 }, e);
    const auto e3 = c_mul_neon(float32x2_t{ 1, 0 }, e);
    const auto e4 = c_mul_neon(float32x2_t{ -1, 0 }, e);
    const auto e5 = c_mul_neon(float32x2_t{ 1, 0 }, e);
    const auto e6 = c_mul_neon(float32x2_t{ -1, 0 }, e);

    const auto f0 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, f);
    const auto f1 = c_mul_neon(float32x2_t{ 0, -1 }, f);
    const auto f2 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, f);
    const auto f3 = c_mul_neon(float32x2_t{ -1, 0 }, f);
    const auto f4 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, f);
    const auto f5 = c_mul_neon(float32x2_t{ 0, 1 }, f);
    const auto f6 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, f);

    const auto g0 = c_mul_neon(float32x2_t{ 0, 1 }, g);
    const auto g1 = c_mul_neon(float32x2_t{ -1, 0 }, g);
    const auto g2 = c_mul_neon(float32x2_t{ 0, -1 }, g);
    const auto g3 = c_mul_neon(float32x2_t{ 1, 0 }, g);
    const auto g4 = c_mul_neon(float32x2_t{ 0, 1 }, g);
    const auto g5 = c_mul_neon(float32x2_t{ -1, 0 }, g);
    const auto g6 = c_mul_neon(float32x2_t{ 0, -1 }, g);

    const auto h0 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, h);
    const auto h1 = c_mul_neon(float32x2_t{ 0, 1 }, h);
    const auto h2 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, h);
    const auto h3 = c_mul_neon(float32x2_t{ -1, 0 }, h);
    const auto h4 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, h);
    const auto h5 = c_mul_neon(float32x2_t{ 0, -1 }, h);
    const auto h6 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, h);

    x1 = reduce_sum_8(a, b, c, d, e, f, g, h);
    x2 = reduce_sum_8(a, b0, c0, d0, e0, f0, g0, h0);
    x3 = reduce_sum_8(a, b1, c1, d1, e1, f1, g1, h1);
    x4 = reduce_sum_8(a, b2, c2, d2, e2, f2, g2, h2);
    x5 = reduce_sum_8(a, b3, c3, d3, e3, f3, g3, h3);
    x6 = reduce_sum_8(a, b4, c4, d4, e4, f4, g4, h4);
    x7 = reduce_sum_8(a, b5, c5, d5, e5, f5, g5, h5);
    x8 = reduce_sum_8(a, b6, c6, d6, e6, f6, g6, h6);
}

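// Radix stages along axis 0. Data is interleaved real/imaginary floats, hence the
// factor-of-2 indexing: Nx is the butterfly span of the current stage, NxRadix is
// radix * Nx and N is the number of complex elements in the row. When first_stage
// is true the two inner-most complex values are contiguous in memory, so they are
// loaded/stored as a single float32x4_t instead of two float32x2_t accesses.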
template <bool first_stage>
void fft_radix_2_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
    float32x2_t w{ 1.0f, 0.0f };
    for(unsigned int j = 0; j < Nx; j++)
    {
        for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
        {
            auto a = float32x2_t{ 0, 0 };
            auto b = float32x2_t{ 0, 0 };

            // Load inputs
            if(first_stage)
            {
                const auto ab = wrapper::vloadq(in + k);
                a = wrapper::vgetlow(ab);
                b = wrapper::vgethigh(ab);
            }
            else
            {
                a = wrapper::vload(in + k);
                b = wrapper::vload(in + k + 2 * Nx);
            }

            // Base-case prime transform
            fft_2(a, b, w);

            // Write outputs
            if(first_stage)
            {
                wrapper::vstore(out + k, wrapper::vcombine(a, b));
            }
            else
            {
                wrapper::vstore(out + k, a);
                wrapper::vstore(out + k + 2 * Nx, b);
            }
        }

        w = c_mul_neon(w, w_m);
    }
}

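// Radix stages along axis 1. Consecutive elements along axis 1 are one padded row of
// (N + pad_x) complex values apart, so every interleaved index is scaled by that row
// stride; M is the number of complex elements along the processed axis.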
void fft_radix_2_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
    float32x2_t w{ 1.0f, 0.0f };
    for(unsigned int j = 0; j < Nx; j++)
    {
        for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
        {
            // Load inputs
            float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
            float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));

            // Base-case prime transform
            fft_2(a, b, w);

            // Write outputs
            wrapper::vstore(out + (N + out_pad_x) * k, a);
            wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
        }

        w = c_mul_neon(w, w_m);
    }
}

template <bool first_stage>
void fft_radix_3_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
    float32x2_t w{ 1.0f, 0.0f };
    for(unsigned int j = 0; j < Nx; j++)
    {
        const auto w2 = c_mul_neon(w, w);

        for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
        {
            // Load inputs
            float32x2_t a = { 0, 0 };
            float32x2_t b = { 0, 0 };
            float32x2_t c = { 0, 0 };
            if(first_stage)
            {
                const auto ab = wrapper::vloadq(in + k);
                a = wrapper::vgetlow(ab);
                b = wrapper::vgethigh(ab);
            }
            else
            {
                a = wrapper::vload(in + k);
                b = wrapper::vload(in + k + 2 * Nx);
            }
            c = wrapper::vload(in + k + 4 * Nx);

            // Base-case prime transform
            fft_3(a, b, c, w, w2);

            if(first_stage)
            {
                wrapper::vstore(out + k, wrapper::vcombine(a, b));
            }
            else
            {
                wrapper::vstore(out + k, a);
                wrapper::vstore(out + k + 2 * Nx, b);
            }
            wrapper::vstore(out + k + 4 * Nx, c);
        }
        w = c_mul_neon(w, w_m);
    }
}

void fft_radix_3_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
    float32x2_t w{ 1.0f, 0.0f };
    for(unsigned int j = 0; j < Nx; j++)
    {
        const auto w2 = c_mul_neon(w, w);

        for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
        {
            // Load inputs
            float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
            float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
            float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));

            // Base-case prime transform
            fft_3(a, b, c, w, w2);

            // Store the output
            wrapper::vstore(out + (N + out_pad_x) * k, a);
            wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
            wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
        }
        w = c_mul_neon(w, w_m);
    }
}

template <bool first_stage>
void fft_radix_4_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
    float32x2_t w{ 1.0f, 0.0f };
    for(unsigned int j = 0; j < Nx; j++)
    {
        const auto w2 = c_mul_neon(w, w);
        const auto w3 = c_mul_neon(w2, w);

        for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
        {
            float32x2_t a = { 0, 0 };
            float32x2_t b = { 0, 0 };
            float32x2_t c = { 0, 0 };
            float32x2_t d = { 0, 0 };
            if(first_stage)
            {
                const auto ab = wrapper::vloadq(in + k);
                const auto cd = wrapper::vloadq(in + k + 4 * Nx);
                a = wrapper::vgetlow(ab);
                b = wrapper::vgethigh(ab);
                c = wrapper::vgetlow(cd);
                d = wrapper::vgethigh(cd);
            }
            else
            {
                // Load inputs
                a = wrapper::vload(in + k);
                b = wrapper::vload(in + k + 2 * Nx);
                c = wrapper::vload(in + k + 4 * Nx);
                d = wrapper::vload(in + k + 6 * Nx);
            }

            // Base-case prime transform
            fft_4(a, b, c, d, w, w2, w3);

            if(first_stage)
            {
                wrapper::vstore(out + k, wrapper::vcombine(a, b));
                wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
            }
            else
            {
                wrapper::vstore(out + k, a);
                wrapper::vstore(out + k + 2 * Nx, b);
                wrapper::vstore(out + k + 4 * Nx, c);
                wrapper::vstore(out + k + 6 * Nx, d);
            }
        }

        w = c_mul_neon(w, w_m);
    }
}

void fft_radix_4_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
    float32x2_t w{ 1.0f, 0.0f };
    for(unsigned int j = 0; j < Nx; j++)
    {
        const auto w2 = c_mul_neon(w, w);
        const auto w3 = c_mul_neon(w2, w);

        for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
        {
            // Load inputs
            float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
            float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
            float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
            float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));

            // Base-case prime transform
            fft_4(a, b, c, d, w, w2, w3);

            wrapper::vstore(out + (N + out_pad_x) * k, a);
            wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
            wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
            wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
        }

        w = c_mul_neon(w, w_m);
    }
}

template <bool first_stage>
void fft_radix_5_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
    float32x2_t w{ 1.0f, 0.0f };
    for(unsigned int j = 0; j < Nx; j++)
    {
        const float32x2_t w2 = c_mul_neon(w, w);
        const float32x2_t w3 = c_mul_neon(w2, w);
        const float32x2_t w4 = c_mul_neon(w3, w);

        for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
        {
            float32x2_t a = { 0, 0 };
            float32x2_t b = { 0, 0 };
            float32x2_t c = { 0, 0 };
            float32x2_t d = { 0, 0 };
            float32x2_t e = { 0, 0 };

            // Load inputs
            if(first_stage)
            {
                const auto ab = wrapper::vloadq(in + k);
                const auto cd = wrapper::vloadq(in + k + 4 * Nx);

                a = wrapper::vgetlow(ab);
                b = wrapper::vgethigh(ab);
                c = wrapper::vgetlow(cd);
                d = wrapper::vgethigh(cd);
            }
            else
            {
                a = wrapper::vload(in + k);
                b = wrapper::vload(in + k + 2 * Nx);
                c = wrapper::vload(in + k + 4 * Nx);
                d = wrapper::vload(in + k + 6 * Nx);
            }
            e = wrapper::vload(in + k + 8 * Nx);

            // Base-case prime transform
            fft_5(a, b, c, d, e, w, w2, w3, w4);

            // Store outputs
            if(first_stage)
            {
                wrapper::vstore(out + k, wrapper::vcombine(a, b));
                wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
            }
            else
            {
                wrapper::vstore(out + k, a);
                wrapper::vstore(out + k + 2 * Nx, b);
                wrapper::vstore(out + k + 4 * Nx, c);
                wrapper::vstore(out + k + 6 * Nx, d);
            }
            wrapper::vstore(out + k + 8 * Nx, e);
        }

        w = c_mul_neon(w, w_m);
    }
}

void fft_radix_5_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
    float32x2_t w{ 1.0f, 0.0f };
    for(unsigned int j = 0; j < Nx; j++)
    {
        const float32x2_t w2 = c_mul_neon(w, w);
        const float32x2_t w3 = c_mul_neon(w2, w);
        const float32x2_t w4 = c_mul_neon(w3, w);

        for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
        {
            // Load inputs
            float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
            float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
            float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
            float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
            float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx));

            // Base-case prime transform
            fft_5(a, b, c, d, e, w, w2, w3, w4);

            // Store outputs
            wrapper::vstore(out + (N + out_pad_x) * k, a);
            wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
            wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
            wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
            wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e);
        }

        w = c_mul_neon(w, w_m);
    }
}

template <bool first_stage>
void fft_radix_7_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
    float32x2_t w{ 1.0f, 0.0f };
    for(unsigned int j = 0; j < Nx; j++)
    {
        const float32x2_t w2 = c_mul_neon(w, w);
        const float32x2_t w3 = c_mul_neon(w2, w);
        const float32x2_t w4 = c_mul_neon(w3, w);
        const float32x2_t w5 = c_mul_neon(w4, w);
        const float32x2_t w6 = c_mul_neon(w5, w);

        for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
        {
            float32x2_t a = { 0, 0 };
            float32x2_t b = { 0, 0 };
            float32x2_t c = { 0, 0 };
            float32x2_t d = { 0, 0 };
            float32x2_t e = { 0, 0 };
            float32x2_t f = { 0, 0 };
            float32x2_t g = { 0, 0 };

            // Load inputs
            if(first_stage)
            {
                const auto ab = wrapper::vloadq(in + k);
                const auto cd = wrapper::vloadq(in + k + 4 * Nx);
                const auto ef = wrapper::vloadq(in + k + 8 * Nx);

                a = wrapper::vgetlow(ab);
                b = wrapper::vgethigh(ab);
                c = wrapper::vgetlow(cd);
                d = wrapper::vgethigh(cd);
                e = wrapper::vgetlow(ef);
                f = wrapper::vgethigh(ef);
            }
            else
            {
                a = wrapper::vload(in + k);
                b = wrapper::vload(in + k + 2 * Nx);
                c = wrapper::vload(in + k + 4 * Nx);
                d = wrapper::vload(in + k + 6 * Nx);
                e = wrapper::vload(in + k + 8 * Nx);
                f = wrapper::vload(in + k + 10 * Nx);
            }
            g = wrapper::vload(in + k + 12 * Nx);

            // Base-case prime transform
            fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6);

            if(first_stage)
            {
                wrapper::vstore(out + k, wrapper::vcombine(a, b));
                wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
                wrapper::vstore(out + k + 8 * Nx, wrapper::vcombine(e, f));
            }
            else
            {
                wrapper::vstore(out + k, a);
                wrapper::vstore(out + k + 2 * Nx, b);
                wrapper::vstore(out + k + 4 * Nx, c);
                wrapper::vstore(out + k + 6 * Nx, d);
                wrapper::vstore(out + k + 8 * Nx, e);
                wrapper::vstore(out + k + 10 * Nx, f);
            }
            wrapper::vstore(out + k + 12 * Nx, g);
        }

        w = c_mul_neon(w, w_m);
    }
}

void fft_radix_7_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
    float32x2_t w{ 1.0f, 0.0f };
    for(unsigned int j = 0; j < Nx; j++)
    {
        const float32x2_t w2 = c_mul_neon(w, w);
        const float32x2_t w3 = c_mul_neon(w2, w);
        const float32x2_t w4 = c_mul_neon(w3, w);
        const float32x2_t w5 = c_mul_neon(w4, w);
        const float32x2_t w6 = c_mul_neon(w5, w);

        for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
        {
            // Load inputs
            float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
            float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
            float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
            float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
            float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx));
            float32x2_t f = wrapper::vload(in + (N + in_pad_x) * (k + 10 * Nx));
            float32x2_t g = wrapper::vload(in + (N + in_pad_x) * (k + 12 * Nx));

            // Base-case prime transform
            fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6);

            // Store outputs
            wrapper::vstore(out + (N + out_pad_x) * k, a);
            wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
            wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
            wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
            wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e);
            wrapper::vstore(out + (N + out_pad_x) * (k + 10 * Nx), f);
            wrapper::vstore(out + (N + out_pad_x) * (k + 12 * Nx), g);
        }

        w = c_mul_neon(w, w_m);
    }
}

template <bool first_stage>
void fft_radix_8_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
    float32x2_t w{ 1.0f, 0.0f };
    for(unsigned int j = 0; j < Nx; j++)
    {
        const float32x2_t w2 = c_mul_neon(w, w);
        const float32x2_t w3 = c_mul_neon(w2, w);
        const float32x2_t w4 = c_mul_neon(w3, w);
        const float32x2_t w5 = c_mul_neon(w4, w);
        const float32x2_t w6 = c_mul_neon(w5, w);
        const float32x2_t w7 = c_mul_neon(w6, w);

        for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
        {
            // Load inputs
            float32x2_t a = { 0, 0 };
            float32x2_t b = { 0, 0 };
            float32x2_t c = { 0, 0 };
            float32x2_t d = { 0, 0 };
            float32x2_t e = { 0, 0 };
            float32x2_t f = { 0, 0 };
            float32x2_t g = { 0, 0 };
            float32x2_t h = { 0, 0 };

            if(first_stage)
            {
                const auto ab = wrapper::vloadq(in + k);
                const auto cd = wrapper::vloadq(in + k + 4 * Nx);
                const auto ef = wrapper::vloadq(in + k + 8 * Nx);
                const auto gh = wrapper::vloadq(in + k + 12 * Nx);

                a = wrapper::vgetlow(ab);
                b = wrapper::vgethigh(ab);
                c = wrapper::vgetlow(cd);
                d = wrapper::vgethigh(cd);
                e = wrapper::vgetlow(ef);
                f = wrapper::vgethigh(ef);
                g = wrapper::vgetlow(gh);
                h = wrapper::vgethigh(gh);
            }
            else
            {
                a = wrapper::vload(in + k);
                b = wrapper::vload(in + k + 2 * Nx);
                c = wrapper::vload(in + k + 4 * Nx);
                d = wrapper::vload(in + k + 6 * Nx);
                e = wrapper::vload(in + k + 8 * Nx);
                f = wrapper::vload(in + k + 10 * Nx);
                g = wrapper::vload(in + k + 12 * Nx);
                h = wrapper::vload(in + k + 14 * Nx);
            }

            // Base-case prime transform
            fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7);

            // Store outputs
            if(first_stage)
            {
                wrapper::vstore(out + k, wrapper::vcombine(a, b));
                wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
                wrapper::vstore(out + k + 8 * Nx, wrapper::vcombine(e, f));
                wrapper::vstore(out + k + 12 * Nx, wrapper::vcombine(g, h));
            }
            else
            {
                wrapper::vstore(out + k, a);
                wrapper::vstore(out + k + 2 * Nx, b);
                wrapper::vstore(out + k + 4 * Nx, c);
                wrapper::vstore(out + k + 6 * Nx, d);
                wrapper::vstore(out + k + 8 * Nx, e);
                wrapper::vstore(out + k + 10 * Nx, f);
                wrapper::vstore(out + k + 12 * Nx, g);
                wrapper::vstore(out + k + 14 * Nx, h);
            }
        }

        w = c_mul_neon(w, w_m);
    }
}

void fft_radix_8_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
{
    float32x2_t w{ 1.0f, 0.0f };
    for(unsigned int j = 0; j < Nx; j++)
    {
        const float32x2_t w2 = c_mul_neon(w, w);
        const float32x2_t w3 = c_mul_neon(w2, w);
        const float32x2_t w4 = c_mul_neon(w3, w);
        const float32x2_t w5 = c_mul_neon(w4, w);
        const float32x2_t w6 = c_mul_neon(w5, w);
        const float32x2_t w7 = c_mul_neon(w6, w);

        for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
        {
            // Load inputs
            float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
            float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
            float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
            float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
            float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx));
            float32x2_t f = wrapper::vload(in + (N + in_pad_x) * (k + 10 * Nx));
            float32x2_t g = wrapper::vload(in + (N + in_pad_x) * (k + 12 * Nx));
            float32x2_t h = wrapper::vload(in + (N + in_pad_x) * (k + 14 * Nx));

            // Base-case prime transform
            fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7);

            // Store outputs
            wrapper::vstore(out + (N + out_pad_x) * k, a);
            wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
            wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
            wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
            wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e);
            wrapper::vstore(out + (N + out_pad_x) * (k + 10 * Nx), f);
            wrapper::vstore(out + (N + out_pad_x) * (k + 12 * Nx), g);
            wrapper::vstore(out + (N + out_pad_x) * (k + 14 * Nx), h);
        }

        w = c_mul_neon(w, w_m);
    }
}

Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON(config.axis > 1);
    ARM_COMPUTE_RETURN_ERROR_ON(NEFFTRadixStageKernel::supported_radix().count(config.radix) == 0);
    ARM_COMPUTE_UNUSED(config);

    // Checks performed when output is configured
    if((output != nullptr) && (output->total_size() != 0))
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
    }

    return Status{};
}

std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const FFTRadixStageKernelInfo &config)
{
    ARM_COMPUTE_UNUSED(config);

    if(output != nullptr)
    {
        auto_init_if_empty(*output, *input);
    }

    Window win = calculate_max_window(*input, Steps());

    return std::make_pair(Status{}, win);
}
} // namespace

NEFFTRadixStageKernel::NEFFTRadixStageKernel()
    : _input(nullptr), _output(nullptr), _Nx(0), _axis(0), _radix(0), _func_0(), _func_1()
{
}

void NEFFTRadixStageKernel::set_radix_stage_axis0(const FFTRadixStageKernelInfo &config)
{
    // FFT table axis 0: [radix, first_stage]
    static std::map<unsigned int, std::map<bool, FFTFunctionPointerAxis0>> fft_table_axis0;

    if(fft_table_axis0.empty())
    {
        fft_table_axis0[2][false] = &fft_radix_2_axes_0<false>;
        fft_table_axis0[3][false] = &fft_radix_3_axes_0<false>;
        fft_table_axis0[4][false] = &fft_radix_4_axes_0<false>;
        fft_table_axis0[5][false] = &fft_radix_5_axes_0<false>;
        fft_table_axis0[7][false] = &fft_radix_7_axes_0<false>;
        fft_table_axis0[8][false] = &fft_radix_8_axes_0<false>;

        fft_table_axis0[2][true] = &fft_radix_2_axes_0<true>;
        fft_table_axis0[3][true] = &fft_radix_3_axes_0<true>;
        fft_table_axis0[4][true] = &fft_radix_4_axes_0<true>;
        fft_table_axis0[5][true] = &fft_radix_5_axes_0<true>;
        fft_table_axis0[7][true] = &fft_radix_7_axes_0<true>;
        fft_table_axis0[8][true] = &fft_radix_8_axes_0<true>;
    }

    _func_0 = fft_table_axis0[config.radix][config.is_first_stage];
}

void NEFFTRadixStageKernel::set_radix_stage_axis1(const FFTRadixStageKernelInfo &config)
{
    // FFT table axis 1: [radix]
    static std::map<unsigned int, FFTFunctionPointerAxis1> fft_table_axis1;

    if(fft_table_axis1.empty())
    {
        fft_table_axis1[2] = &fft_radix_2_axes_1;
        fft_table_axis1[3] = &fft_radix_3_axes_1;
        fft_table_axis1[4] = &fft_radix_4_axes_1;
        fft_table_axis1[5] = &fft_radix_5_axes_1;
        fft_table_axis1[7] = &fft_radix_7_axes_1;
        fft_table_axis1[8] = &fft_radix_8_axes_1;
    }

    _func_1 = fft_table_axis1[config.radix];
}

void NEFFTRadixStageKernel::configure(ITensor *input, ITensor *output, const FFTRadixStageKernelInfo &config)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input);

    // Output auto-initialization if not yet initialized
    if(output != nullptr)
    {
        auto_init_if_empty(*output->info(), *input->info()->clone());
    }

    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, config));

    _input = input;
    _output = (output == nullptr) ? input : output;
    _Nx = config.Nx;
    _axis = config.axis;
    _radix = config.radix;

    switch(config.axis)
    {
        case 0:
            set_radix_stage_axis0(config);
            break;
        case 1:
            set_radix_stage_axis1(config);
            break;
        default:
            ARM_COMPUTE_ERROR("Axis not supported");
            break;
    }

    // Configure kernel window
    auto win_config = validate_and_configure_window(input->info(), (output != nullptr) ? output->info() : nullptr, config);
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
    INEKernel::configure(win_config.second);
}

Status NEFFTRadixStageKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
{
    const bool run_in_place = (output == nullptr) || (output == input);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, config));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
                                                              (run_in_place) ? nullptr : output->clone().get(),
                                                              config)
                                .first);

    return Status{};
}

std::set<unsigned int> NEFFTRadixStageKernel::supported_radix()
{
    return std::set<unsigned int> { 2, 3, 4, 5, 7, 8 };
}

void NEFFTRadixStageKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_UNUSED(info);

    Window input_window = window;
    input_window.set(_axis, 0);

    Iterator in(_input, input_window);
    Iterator out(_output, input_window);

    // Precompute FFT constants
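    // w_m = exp(-j * 2 * pi / NxRadix) is the principal twiddle factor; the radix
    // kernels advance their running twiddle by w_m once per iteration of the j-loop.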
    const unsigned int NxRadix = _radix * _Nx;
    const float alpha = 2.0f * kPi / float(NxRadix);
    const float32x2_t w_m{ cosf(alpha), -sinf(alpha) };

    if(_axis == 0)
    {
        const unsigned int N = _input->info()->dimension(0);
        execute_window_loop(input_window, [&](const Coordinates &)
        {
            _func_0(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N);
        },
        in, out);
    }
    else
    {
        const unsigned int N = _input->info()->dimension(0);
        const unsigned int M = _input->info()->dimension(1);
        execute_window_loop(input_window, [&](const Coordinates &)
        {
            _func_1(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N, M,
                    _input->info()->padding().right + _input->info()->padding().left,
                    _output->info()->padding().right + _output->info()->padding().left);
        },
        in, out);
    }

    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
}
} // namespace arm_compute