1 /*
2  * Copyright (c) 2019-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/core/NEON/kernels/NEFFTRadixStageKernel.h"
25 
26 #include "arm_compute/core/ITensor.h"
27 #include "arm_compute/core/TensorInfo.h"
28 #include "arm_compute/core/Types.h"
29 #include "arm_compute/core/Utils.h"
30 #include "arm_compute/core/Window.h"
31 #include "src/core/NEON/wrapper/traits.h"
32 #include "src/core/NEON/wrapper/wrapper.h"
33 #include "src/core/helpers/AutoConfiguration.h"
34 #include "src/core/helpers/WindowHelpers.h"
35 #include "support/ToolchainSupport.h"
36 
37 #include <arm_neon.h>
38 #include <cmath>
39 #include <complex>
40 #include <map>
41 
42 namespace arm_compute
43 {
44 namespace
45 {
// PI constant (from cmath)
constexpr float kPi = static_cast<float>(M_PI);

// Constant used in the fft_3 kernel: sqrt(3) / 2 == sin(pi / 3)
constexpr float kSqrt3Div2 = 0.866025403784438f;

// Constants used in the fft_5 kernel (all positive magnitudes; signs are applied at the use site)
constexpr float kW5_0 = 0.30901699437494f; //  cos(2 * pi / 5)
constexpr float kW5_1 = 0.95105651629515f; //  sin(2 * pi / 5)
constexpr float kW5_2 = 0.80901699437494f; // -cos(4 * pi / 5)
constexpr float kW5_3 = 0.58778525229247f; //  sin(4 * pi / 5)

// Constants used in the fft_7 kernel (all positive magnitudes; signs are applied at the use site)
constexpr float kW7_0 = 0.62348980185873f; //  cos(2 * pi / 7)
constexpr float kW7_1 = 0.78183148246802f; //  sin(2 * pi / 7)
constexpr float kW7_2 = 0.22252093395631f; // -cos(4 * pi / 7)
constexpr float kW7_3 = 0.97492791218182f; //  sin(4 * pi / 7)
constexpr float kW7_4 = 0.90096886790241f; // -cos(6 * pi / 7)
constexpr float kW7_5 = 0.43388373911755f; //  sin(6 * pi / 7)

// Constant used in the fft_8 kernel: sqrt(2) / 2 == cos(pi / 4) == sin(pi / 4)
constexpr float kSqrt2Div2 = 0.707106781186548f;
68 
c_mul_neon(float32x2_t a,float32x2_t b)69 float32x2_t c_mul_neon(float32x2_t a, float32x2_t b)
70 {
71     using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;
72 
73     const float32x2_t mask = { -1.0, 1.0 };
74     const float32x2_t tmp0 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
75     const float32x2_t tmp1 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
76 
77     float32x2_t res = wrapper::vmul(tmp0, b);
78 
79     b   = wrapper::vrev64(b);
80     b   = wrapper::vmul(b, mask);
81     res = wrapper::vmla(res, tmp1, b);
82 
83     return res;
84 }
85 
c_mul_neon_img(float32x2_t a,float img_constant)86 float32x2_t c_mul_neon_img(float32x2_t a, float img_constant)
87 {
88     const float a_r = wrapper::vgetlane(a, 0);
89     const float a_i = wrapper::vgetlane(a, 1);
90 
91     const auto out = wrapper::vmul(float32x2_t{ -a_i, a_r }, float32x2_t{ img_constant, img_constant });
92     return out;
93 }
94 
reduce_sum_5(float32x2_t a,float32x2_t b,float32x2_t c,float32x2_t d,float32x2_t e)95 float32x2_t reduce_sum_5(float32x2_t a, float32x2_t b, float32x2_t c, float32x2_t d, float32x2_t e)
96 {
97     const auto t0 = wrapper::vadd(a, b);
98     const auto t1 = wrapper::vadd(c, d);
99     const auto t2 = wrapper::vadd(t0, t1);
100     return wrapper::vadd(t2, e);
101 }
102 
reduce_sum_7(float32x2_t x1,float32x2_t x2,float32x2_t x3,float32x2_t x4,float32x2_t x5,float32x2_t x6,float32x2_t x7)103 float32x2_t reduce_sum_7(float32x2_t x1, float32x2_t x2, float32x2_t x3, float32x2_t x4, float32x2_t x5, float32x2_t x6, float32x2_t x7)
104 {
105     const auto t0  = wrapper::vadd(x1, x2);
106     const auto t1  = wrapper::vadd(x3, x4);
107     const auto t2  = wrapper::vadd(x5, x6);
108     const auto t00 = wrapper::vadd(t0, t1);
109     const auto t01 = wrapper::vadd(t2, x7);
110 
111     return wrapper::vadd(t00, t01);
112 }
113 
reduce_sum_8(float32x2_t x1,float32x2_t x2,float32x2_t x3,float32x2_t x4,float32x2_t x5,float32x2_t x6,float32x2_t x7,float32x2_t x8)114 float32x2_t reduce_sum_8(float32x2_t x1, float32x2_t x2, float32x2_t x3, float32x2_t x4, float32x2_t x5, float32x2_t x6, float32x2_t x7, float32x2_t x8)
115 {
116     const auto t0  = wrapper::vadd(x1, x2);
117     const auto t1  = wrapper::vadd(x3, x4);
118     const auto t2  = wrapper::vadd(x5, x6);
119     const auto t3  = wrapper::vadd(x7, x8);
120     const auto t00 = wrapper::vadd(t0, t1);
121     const auto t01 = wrapper::vadd(t2, t3);
122 
123     return wrapper::vadd(t00, t01);
124 }
125 
fft_2(float32x2_t & x,float32x2_t & y,float32x2_t & w)126 void fft_2(float32x2_t &x, float32x2_t &y, float32x2_t &w)
127 {
128     float32x2_t a = x;
129     float32x2_t b = c_mul_neon(w, y);
130 
131     x = wrapper::vadd(a, b);
132     y = wrapper::vsub(a, b);
133 }
134 
fft_3(float32x2_t & x,float32x2_t & y,float32x2_t & z,const float32x2_t & w,const float32x2_t & w2)135 void fft_3(float32x2_t &x, float32x2_t &y, float32x2_t &z, const float32x2_t &w, const float32x2_t &w2)
136 {
137     float32x2_t a = x;
138     float32x2_t b = c_mul_neon(w, y);
139     float32x2_t c = c_mul_neon(w2, z);
140 
141     x = wrapper::vadd(a, b);
142     x = wrapper::vadd(x, c);
143 
144     const auto v1 = wrapper::vmul(float32x2_t{ 0.5f, 0.5 }, wrapper::vadd(b, c));
145     const auto v2 = c_mul_neon(float32x2_t{ 0.f, -kSqrt3Div2 }, wrapper::vsub(b, c));
146 
147     y = z = wrapper::vsub(a, v1);
148     y     = wrapper::vadd(y, v2);
149     z     = wrapper::vsub(z, v2);
150 }
151 
fft_4(float32x2_t & x1,float32x2_t & x2,float32x2_t & x3,float32x2_t & x4,const float32x2_t & w,const float32x2_t & w2,const float32x2_t & w3)152 void fft_4(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3)
153 {
154     float32x2_t a = x1;
155     float32x2_t b = c_mul_neon(w, x2);
156     float32x2_t c = c_mul_neon(w2, x3);
157     float32x2_t d = c_mul_neon(w3, x4);
158 
159     const auto x11 = wrapper::vadd(a, b);
160     const auto x12 = wrapper::vadd(c, d);
161     x1             = wrapper::vadd(x11, x12);
162 
163     const auto x21 = wrapper::vadd(a, c_mul_neon_img(b, -1));
164     const auto x22 = wrapper::vadd(wrapper::vneg(c), c_mul_neon_img(d, 1.f));
165     x2             = wrapper::vadd(x21, x22);
166 
167     const auto x31 = wrapper::vadd(a, wrapper::vneg(b));
168     const auto x32 = wrapper::vadd(c, wrapper::vneg(d));
169     x3             = wrapper::vadd(x31, x32);
170 
171     const auto x41 = wrapper::vadd(a, c_mul_neon_img(b, 1));
172     const auto x42 = wrapper::vadd(wrapper::vneg(c), c_mul_neon_img(d, -1));
173     x4             = wrapper::vadd(x41, x42);
174 }
175 
fft_5(float32x2_t & x1,float32x2_t & x2,float32x2_t & x3,float32x2_t & x4,float32x2_t & x5,const float32x2_t & w,const float32x2_t & w2,const float32x2_t & w3,const float32x2_t & w4)176 void fft_5(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3, const float32x2_t &w4)
177 {
178     const auto a = x1;
179     const auto b = c_mul_neon(w, x2);
180     const auto c = c_mul_neon(w2, x3);
181     const auto d = c_mul_neon(w3, x4);
182     const auto e = c_mul_neon(w4, x5);
183 
184     const auto b0 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, b);
185     const auto b1 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, b);
186     const auto b2 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, b);
187     const auto b3 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, b);
188 
189     const auto c0 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, c);
190     const auto c1 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, c);
191     const auto c2 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, c);
192     const auto c3 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, c);
193 
194     const auto d0 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, d);
195     const auto d1 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, d);
196     const auto d2 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, d);
197     const auto d3 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, d);
198 
199     const auto e0 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, e);
200     const auto e1 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, e);
201     const auto e2 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, e);
202     const auto e3 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, e);
203 
204     x1 = reduce_sum_5(a, b, c, d, e);
205     x2 = reduce_sum_5(a, b0, c0, d0, e0);
206     x3 = reduce_sum_5(a, b1, c1, d1, e1);
207     x4 = reduce_sum_5(a, b2, c2, d2, e2);
208     x5 = reduce_sum_5(a, b3, c3, d3, e3);
209 }
210 
fft_7(float32x2_t & x1,float32x2_t & x2,float32x2_t & x3,float32x2_t & x4,float32x2_t & x5,float32x2_t & x6,float32x2_t & x7,const float32x2_t & w,const float32x2_t & w2,const float32x2_t & w3,const float32x2_t & w4,const float32x2_t & w5,const float32x2_t & w6)211 void fft_7(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3,
212            const float32x2_t &w4,
213            const float32x2_t &w5, const float32x2_t &w6)
214 {
215     const auto a = x1;
216     const auto b = c_mul_neon(w, x2);
217     const auto c = c_mul_neon(w2, x3);
218     const auto d = c_mul_neon(w3, x4);
219     const auto e = c_mul_neon(w4, x5);
220     const auto f = c_mul_neon(w5, x6);
221     const auto g = c_mul_neon(w6, x7);
222 
223     const auto b0 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, b);
224     const auto b1 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, b);
225     const auto b2 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, b);
226     const auto b3 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, b);
227     const auto b4 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, b);
228     const auto b5 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, b);
229 
230     const auto c0 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, c);
231     const auto c1 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, c);
232     const auto c2 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, c);
233     const auto c3 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, c);
234     const auto c4 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, c);
235     const auto c5 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, c);
236 
237     const auto d0 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, d);
238     const auto d1 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, d);
239     const auto d2 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, d);
240     const auto d3 = c_mul_neon(float32x2_t{ -kW7_2, +kW7_3 }, d);
241     const auto d4 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, d);
242     const auto d5 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, d);
243 
244     const auto e0 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, e);
245     const auto e1 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, e);
246     const auto e2 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, e);
247     const auto e3 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, e);
248     const auto e4 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, e);
249     const auto e5 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, e);
250 
251     const auto f0 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, f);
252     const auto f1 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, f);
253     const auto f2 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, f);
254     const auto f3 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, f);
255     const auto f4 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, f);
256     const auto f5 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, f);
257 
258     const auto g0 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, g);
259     const auto g1 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, g);
260     const auto g2 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, g);
261     const auto g3 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, g);
262     const auto g4 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, g);
263     const auto g5 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, g);
264 
265     x1 = reduce_sum_7(a, b, c, d, e, f, g);
266     x2 = reduce_sum_7(a, b0, c0, d0, e0, f0, g0);
267     x3 = reduce_sum_7(a, b1, c1, d1, e1, f1, g1);
268     x4 = reduce_sum_7(a, b2, c2, d2, e2, f2, g2);
269     x5 = reduce_sum_7(a, b3, c3, d3, e3, f3, g3);
270     x6 = reduce_sum_7(a, b4, c4, d4, e4, f4, g4);
271     x7 = reduce_sum_7(a, b5, c5, d5, e5, f5, g5);
272 }
273 
fft_8(float32x2_t & x1,float32x2_t & x2,float32x2_t & x3,float32x2_t & x4,float32x2_t & x5,float32x2_t & x6,float32x2_t & x7,float32x2_t & x8,const float32x2_t & w,const float32x2_t & w2,const float32x2_t & w3,const float32x2_t & w4,const float32x2_t & w5,const float32x2_t & w6,const float32x2_t & w7)274 void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, float32x2_t &x8, const float32x2_t &w, const float32x2_t &w2,
275            const float32x2_t &w3,
276            const float32x2_t &w4, const float32x2_t &w5, const float32x2_t &w6,
277            const float32x2_t &w7)
278 {
279     const auto a = x1;
280     const auto b = c_mul_neon(w, x2);
281     const auto c = c_mul_neon(w2, x3);
282     const auto d = c_mul_neon(w3, x4);
283     const auto e = c_mul_neon(w4, x5);
284     const auto f = c_mul_neon(w5, x6);
285     const auto g = c_mul_neon(w6, x7);
286     const auto h = c_mul_neon(w7, x8);
287 
288     const auto b0 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, b);
289     const auto b1 = c_mul_neon(float32x2_t{ 0, -1 }, b);
290     const auto b2 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, b);
291     const auto b3 = c_mul_neon(float32x2_t{ -1, 0 }, b);
292     const auto b4 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, b);
293     const auto b5 = c_mul_neon(float32x2_t{ 0, 1 }, b);
294     const auto b6 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, b);
295 
296     const auto c0 = c_mul_neon(float32x2_t{ 0, -1 }, c);
297     const auto c1 = c_mul_neon(float32x2_t{ -1, 0 }, c);
298     const auto c2 = c_mul_neon(float32x2_t{ 0, 1 }, c);
299     const auto c3 = c_mul_neon(float32x2_t{ 1, 0 }, c);
300     const auto c4 = c_mul_neon(float32x2_t{ 0, -1 }, c);
301     const auto c5 = c_mul_neon(float32x2_t{ -1, 0 }, c);
302     const auto c6 = c_mul_neon(float32x2_t{ 0, 1 }, c);
303 
304     const auto d0 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, d);
305     const auto d1 = c_mul_neon(float32x2_t{ 0, 1 }, d);
306     const auto d2 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, d);
307     const auto d3 = c_mul_neon(float32x2_t{ -1, 0 }, d);
308     const auto d4 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, d);
309     const auto d5 = c_mul_neon(float32x2_t{ 0, -1 }, d);
310     const auto d6 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, d);
311 
312     const auto e0 = c_mul_neon(float32x2_t{ -1, 0 }, e);
313     const auto e1 = c_mul_neon(float32x2_t{ 1, 0 }, e);
314     const auto e2 = c_mul_neon(float32x2_t{ -1, 0 }, e);
315     const auto e3 = c_mul_neon(float32x2_t{ 1, 0 }, e);
316     const auto e4 = c_mul_neon(float32x2_t{ -1, 0 }, e);
317     const auto e5 = c_mul_neon(float32x2_t{ 1, 0 }, e);
318     const auto e6 = c_mul_neon(float32x2_t{ -1, 0 }, e);
319 
320     const auto f0 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, f);
321     const auto f1 = c_mul_neon(float32x2_t{ 0, -1 }, f);
322     const auto f2 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, f);
323     const auto f3 = c_mul_neon(float32x2_t{ -1, 0 }, f);
324     const auto f4 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, f);
325     const auto f5 = c_mul_neon(float32x2_t{ 0, 1 }, f);
326     const auto f6 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, f);
327 
328     const auto g0 = c_mul_neon(float32x2_t{ 0, 1 }, g);
329     const auto g1 = c_mul_neon(float32x2_t{ -1, 0 }, g);
330     const auto g2 = c_mul_neon(float32x2_t{ 0, -1 }, g);
331     const auto g3 = c_mul_neon(float32x2_t{ 1, 0 }, g);
332     const auto g4 = c_mul_neon(float32x2_t{ 0, 1 }, g);
333     const auto g5 = c_mul_neon(float32x2_t{ -1, 0 }, g);
334     const auto g6 = c_mul_neon(float32x2_t{ 0, -1 }, g);
335 
336     const auto h0 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, h);
337     const auto h1 = c_mul_neon(float32x2_t{ 0, 1 }, h);
338     const auto h2 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, h);
339     const auto h3 = c_mul_neon(float32x2_t{ -1, 0 }, h);
340     const auto h4 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, h);
341     const auto h5 = c_mul_neon(float32x2_t{ 0, -1 }, h);
342     const auto h6 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, h);
343 
344     x1 = reduce_sum_8(a, b, c, d, e, f, g, h);
345     x2 = reduce_sum_8(a, b0, c0, d0, e0, f0, g0, h0);
346     x3 = reduce_sum_8(a, b1, c1, d1, e1, f1, g1, h1);
347     x4 = reduce_sum_8(a, b2, c2, d2, e2, f2, g2, h2);
348     x5 = reduce_sum_8(a, b3, c3, d3, e3, f3, g3, h3);
349     x6 = reduce_sum_8(a, b4, c4, d4, e4, f4, g4, h4);
350     x7 = reduce_sum_8(a, b5, c5, d5, e5, f5, g5, h5);
351     x8 = reduce_sum_8(a, b6, c6, d6, e6, f6, g6, h6);
352 }
353 
354 template <bool first_stage>
fft_radix_2_axes_0(float * out,float * in,unsigned int Nx,unsigned int NxRadix,const float32x2_t & w_m,unsigned int N)355 void fft_radix_2_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
356 {
357     float32x2_t w{ 1.0f, 0.0f };
358     for(unsigned int j = 0; j < Nx; j++)
359     {
360         for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
361         {
362             auto a = float32x2_t{ 0, 0 };
363             auto b = float32x2_t{ 0, 0 };
364 
365             // Load inputs
366             if(first_stage)
367             {
368                 const auto ab = wrapper::vloadq(in + k);
369                 a             = wrapper::vgetlow(ab);
370                 b             = wrapper::vgethigh(ab);
371             }
372             else
373             {
374                 a = wrapper::vload(in + k);
375                 b = wrapper::vload(in + k + 2 * Nx);
376             }
377 
378             // Base-case prime transform
379             fft_2(a, b, w);
380 
381             // Write outputs
382             if(first_stage)
383             {
384                 wrapper::vstore(out + k, wrapper::vcombine(a, b));
385             }
386             else
387             {
388                 wrapper::vstore(out + k, a);
389                 wrapper::vstore(out + k + 2 * Nx, b);
390             }
391         }
392 
393         w = c_mul_neon(w, w_m);
394     }
395 }
396 
fft_radix_2_axes_1(float * out,float * in,unsigned int Nx,unsigned int NxRadix,const float32x2_t & w_m,unsigned int N,unsigned int M,unsigned int in_pad_x,unsigned int out_pad_x)397 void fft_radix_2_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
398 {
399     float32x2_t w{ 1.0f, 0.0f };
400     for(unsigned int j = 0; j < Nx; j++)
401     {
402         for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
403         {
404             // Load inputs
405             float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
406             float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
407 
408             // Base-case prime transform
409             fft_2(a, b, w);
410 
411             // Write outputs
412             wrapper::vstore(out + (N + out_pad_x) * k, a);
413             wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
414         }
415 
416         w = c_mul_neon(w, w_m);
417     }
418 }
419 
420 template <bool first_stage>
fft_radix_3_axes_0(float * out,float * in,unsigned int Nx,unsigned int NxRadix,const float32x2_t & w_m,unsigned int N)421 void fft_radix_3_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
422 {
423     float32x2_t w{ 1.0f, 0.0f };
424     for(unsigned int j = 0; j < Nx; j++)
425     {
426         const auto w2 = c_mul_neon(w, w);
427 
428         for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
429         {
430             // Load inputs
431             float32x2_t a = { 0, 0 };
432             float32x2_t b = { 0, 0 };
433             float32x2_t c = { 0, 0 };
434             if(first_stage)
435             {
436                 const auto ab = wrapper::vloadq(in + k);
437                 a             = wrapper::vgetlow(ab);
438                 b             = wrapper::vgethigh(ab);
439             }
440             else
441             {
442                 a = wrapper::vload(in + k);
443                 b = wrapper::vload(in + k + 2 * Nx);
444             }
445             c = wrapper::vload(in + k + 4 * Nx);
446 
447             // Base-case prime transform
448             fft_3(a, b, c, w, w2);
449 
450             if(first_stage)
451             {
452                 wrapper::vstore(out + k, wrapper::vcombine(a, b));
453             }
454             else
455             {
456                 wrapper::vstore(out + k, a);
457                 wrapper::vstore(out + k + 2 * Nx, b);
458             }
459             wrapper::vstore(out + k + 4 * Nx, c);
460         }
461         w = c_mul_neon(w, w_m);
462     }
463 }
464 
fft_radix_3_axes_1(float * out,float * in,unsigned int Nx,unsigned int NxRadix,const float32x2_t & w_m,unsigned int N,unsigned int M,unsigned int in_pad_x,unsigned int out_pad_x)465 void fft_radix_3_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
466 {
467     float32x2_t w{ 1.0f, 0.0f };
468     for(unsigned int j = 0; j < Nx; j++)
469     {
470         const auto w2 = c_mul_neon(w, w);
471 
472         for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
473         {
474             // Load inputs
475             float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
476             float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
477             float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
478 
479             // Base-case prime transform
480             fft_3(a, b, c, w, w2);
481 
482             // Store the output
483             wrapper::vstore(out + (N + out_pad_x) * k, a);
484             wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
485             wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
486         }
487         w = c_mul_neon(w, w_m);
488     }
489 }
490 
491 template <bool first_stage>
fft_radix_4_axes_0(float * out,float * in,unsigned int Nx,unsigned int NxRadix,const float32x2_t & w_m,unsigned int N)492 void fft_radix_4_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
493 {
494     float32x2_t w{ 1.0f, 0.0f };
495     for(unsigned int j = 0; j < Nx; j++)
496     {
497         const auto w2 = c_mul_neon(w, w);
498         const auto w3 = c_mul_neon(w2, w);
499 
500         for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
501         {
502             float32x2_t a = { 0, 0 };
503             float32x2_t b = { 0, 0 };
504             float32x2_t c = { 0, 0 };
505             float32x2_t d = { 0, 0 };
506             if(first_stage)
507             {
508                 const auto ab = wrapper::vloadq(in + k);
509                 const auto cd = wrapper::vloadq(in + k + 4 * Nx);
510                 a             = wrapper::vgetlow(ab);
511                 b             = wrapper::vgethigh(ab);
512                 c             = wrapper::vgetlow(cd);
513                 d             = wrapper::vgethigh(cd);
514             }
515             else
516             {
517                 // Load inputs
518                 a = wrapper::vload(in + k);
519                 b = wrapper::vload(in + k + 2 * Nx);
520                 c = wrapper::vload(in + k + 4 * Nx);
521                 d = wrapper::vload(in + k + 6 * Nx);
522             }
523 
524             // Base-case prime transform
525             fft_4(a, b, c, d, w, w2, w3);
526 
527             if(first_stage)
528             {
529                 wrapper::vstore(out + k, wrapper::vcombine(a, b));
530                 wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
531             }
532             else
533             {
534                 wrapper::vstore(out + k, a);
535                 wrapper::vstore(out + k + 2 * Nx, b);
536                 wrapper::vstore(out + k + 4 * Nx, c);
537                 wrapper::vstore(out + k + 6 * Nx, d);
538             }
539         }
540 
541         w = c_mul_neon(w, w_m);
542     }
543 }
544 
fft_radix_4_axes_1(float * out,float * in,unsigned int Nx,unsigned int NxRadix,const float32x2_t & w_m,unsigned int N,unsigned int M,unsigned int in_pad_x,unsigned int out_pad_x)545 void fft_radix_4_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
546 {
547     float32x2_t w{ 1.0f, 0.0f };
548     for(unsigned int j = 0; j < Nx; j++)
549     {
550         const auto w2 = c_mul_neon(w, w);
551         const auto w3 = c_mul_neon(w2, w);
552 
553         for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
554         {
555             // Load inputs
556             float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
557             float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
558             float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
559             float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
560 
561             // Base-case prime transform
562             fft_4(a, b, c, d, w, w2, w3);
563 
564             wrapper::vstore(out + (N + out_pad_x) * k, a);
565             wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
566             wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
567             wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
568         }
569 
570         w = c_mul_neon(w, w_m);
571     }
572 }
573 
574 template <bool first_stage>
fft_radix_5_axes_0(float * out,float * in,unsigned int Nx,unsigned int NxRadix,const float32x2_t & w_m,unsigned int N)575 void fft_radix_5_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
576 {
577     float32x2_t w{ 1.0f, 0.0f };
578     for(unsigned int j = 0; j < Nx; j++)
579     {
580         const float32x2_t w2 = c_mul_neon(w, w);
581         const float32x2_t w3 = c_mul_neon(w2, w);
582         const float32x2_t w4 = c_mul_neon(w3, w);
583 
584         for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
585         {
586             float32x2_t a = { 0, 0 };
587             float32x2_t b = { 0, 0 };
588             float32x2_t c = { 0, 0 };
589             float32x2_t d = { 0, 0 };
590             float32x2_t e = { 0, 0 };
591 
592             // Load inputs
593             if(first_stage)
594             {
595                 const auto ab = wrapper::vloadq(in + k);
596                 const auto cd = wrapper::vloadq(in + k + 4 * Nx);
597 
598                 a = wrapper::vgetlow(ab);
599                 b = wrapper::vgethigh(ab);
600                 c = wrapper::vgetlow(cd);
601                 d = wrapper::vgethigh(cd);
602             }
603             else
604             {
605                 a = wrapper::vload(in + k);
606                 b = wrapper::vload(in + k + 2 * Nx);
607                 c = wrapper::vload(in + k + 4 * Nx);
608                 d = wrapper::vload(in + k + 6 * Nx);
609             }
610             e = wrapper::vload(in + k + 8 * Nx);
611 
612             // Base-case prime transform
613             fft_5(a, b, c, d, e, w, w2, w3, w4);
614 
615             // Store outputs
616             if(first_stage)
617             {
618                 wrapper::vstore(out + k, wrapper::vcombine(a, b));
619                 wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
620             }
621             else
622             {
623                 wrapper::vstore(out + k, a);
624                 wrapper::vstore(out + k + 2 * Nx, b);
625                 wrapper::vstore(out + k + 4 * Nx, c);
626                 wrapper::vstore(out + k + 6 * Nx, d);
627             }
628             wrapper::vstore(out + k + 8 * Nx, e);
629         }
630 
631         w = c_mul_neon(w, w_m);
632     }
633 }
634 
fft_radix_5_axes_1(float * out,float * in,unsigned int Nx,unsigned int NxRadix,const float32x2_t & w_m,unsigned int N,unsigned int M,unsigned int in_pad_x,unsigned int out_pad_x)635 void fft_radix_5_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
636 {
637     float32x2_t w{ 1.0f, 0.0f };
638     for(unsigned int j = 0; j < Nx; j++)
639     {
640         const float32x2_t w2 = c_mul_neon(w, w);
641         const float32x2_t w3 = c_mul_neon(w2, w);
642         const float32x2_t w4 = c_mul_neon(w3, w);
643 
644         for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
645         {
646             // Load inputs
647             float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
648             float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
649             float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
650             float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
651             float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx));
652 
653             // Base-case prime transform
654             fft_5(a, b, c, d, e, w, w2, w3, w4);
655 
656             // Store outputs
657             wrapper::vstore(out + (N + out_pad_x) * k, a);
658             wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
659             wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
660             wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
661             wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e);
662         }
663 
664         w = c_mul_neon(w, w_m);
665     }
666 }
667 
/** Radix-7 butterfly stage along axis 0 (the row dimension).
 *
 * Each element is held as a pair of floats in a float32x2_t and combined with c_mul_neon, i.e. treated
 * as a complex value. For every twiddle group j in [0, Nx) the powers w..w^6 of the running twiddle
 * factor are formed. Every butterfly of that group then gathers 7 inputs spaced Nx elements apart,
 * transforms them in registers with fft_7 and writes the results back to the same positions in @p out.
 * The butterflies of a group start at element j and are spaced NxRadix elements apart.
 *
 * @tparam first_stage When true, inputs/outputs are moved two at a time with one 128-bit access
 *                     (e.g. a and b both come from in + k). That is only equivalent to the generic
 *                     strided path (b at in + k + 2 * Nx) when Nx == 1. Presumably the caller
 *                     guarantees this for the first stage; TODO confirm.
 *
 * @param[out] out     Destination row. May be the same buffer as @p in, because configure() runs in place
 *                     when no output tensor is given.
 * @param[in]  in      Source row.
 * @param[in]  Nx      Number of twiddle groups in this stage, and the element distance between the legs
 *                     of one butterfly.
 * @param[in]  NxRadix Element distance between successive butterflies of the same group (radix * Nx as
 *                     computed in run()).
 * @param[in]  w_m     Twiddle step. The running factor w starts at 1 + 0i and is multiplied by w_m after
 *                     each group.
 * @param[in]  N       Number of elements in the row. k walks 2 * N floats.
 */
template <bool first_stage>
void fft_radix_7_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
    float32x2_t w{ 1.0f, 0.0f };
    for(unsigned int j = 0; j < Nx; j++)
    {
        // Twiddle powers for legs 2..6 of the butterfly
        const float32x2_t w2 = c_mul_neon(w, w);
        const float32x2_t w3 = c_mul_neon(w2, w);
        const float32x2_t w4 = c_mul_neon(w3, w);
        const float32x2_t w5 = c_mul_neon(w4, w);
        const float32x2_t w6 = c_mul_neon(w5, w);

        // k is a float offset: element index * 2
        for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
        {
            float32x2_t a = { 0, 0 };
            float32x2_t b = { 0, 0 };
            float32x2_t c = { 0, 0 };
            float32x2_t d = { 0, 0 };
            float32x2_t e = { 0, 0 };
            float32x2_t f = { 0, 0 };
            float32x2_t g = { 0, 0 };

            // Load inputs
            if(first_stage)
            {
                // Paired 128-bit loads: assumes adjacent legs are adjacent in memory (Nx == 1)
                const auto ab = wrapper::vloadq(in + k);
                const auto cd = wrapper::vloadq(in + k + 4 * Nx);
                const auto ef = wrapper::vloadq(in + k + 8 * Nx);

                a = wrapper::vgetlow(ab);
                b = wrapper::vgethigh(ab);
                c = wrapper::vgetlow(cd);
                d = wrapper::vgethigh(cd);
                e = wrapper::vgetlow(ef);
                f = wrapper::vgethigh(ef);
            }
            else
            {
                a = wrapper::vload(in + k);
                b = wrapper::vload(in + k + 2 * Nx);
                c = wrapper::vload(in + k + 4 * Nx);
                d = wrapper::vload(in + k + 6 * Nx);
                e = wrapper::vload(in + k + 8 * Nx);
                f = wrapper::vload(in + k + 10 * Nx);
            }
            // The seventh (odd) leg is always loaded on its own
            g = wrapper::vload(in + k + 12 * Nx);

            // Base-case prime transform (radix 7) using twiddle powers w..w6
            fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6);

            // Store outputs, mirroring the load pattern above
            if(first_stage)
            {
                wrapper::vstore(out + k, wrapper::vcombine(a, b));
                wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
                wrapper::vstore(out + k + 8 * Nx, wrapper::vcombine(e, f));
            }
            else
            {
                wrapper::vstore(out + k, a);
                wrapper::vstore(out + k + 2 * Nx, b);
                wrapper::vstore(out + k + 4 * Nx, c);
                wrapper::vstore(out + k + 6 * Nx, d);
                wrapper::vstore(out + k + 8 * Nx, e);
                wrapper::vstore(out + k + 10 * Nx, f);
            }
            wrapper::vstore(out + k + 12 * Nx, g);
        }

        // Advance the running twiddle factor to the next group
        w = c_mul_neon(w, w_m);
    }
}
739 
fft_radix_7_axes_1(float * out,float * in,unsigned int Nx,unsigned int NxRadix,const float32x2_t & w_m,unsigned int N,unsigned int M,unsigned int in_pad_x,unsigned int out_pad_x)740 void fft_radix_7_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
741 {
742     float32x2_t w{ 1.0f, 0.0f };
743     for(unsigned int j = 0; j < Nx; j++)
744     {
745         const float32x2_t w2 = c_mul_neon(w, w);
746         const float32x2_t w3 = c_mul_neon(w2, w);
747         const float32x2_t w4 = c_mul_neon(w3, w);
748         const float32x2_t w5 = c_mul_neon(w4, w);
749         const float32x2_t w6 = c_mul_neon(w5, w);
750 
751         for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
752         {
753             // Load inputs
754             float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
755             float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
756             float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
757             float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
758             float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx));
759             float32x2_t f = wrapper::vload(in + (N + in_pad_x) * (k + 10 * Nx));
760             float32x2_t g = wrapper::vload(in + (N + in_pad_x) * (k + 12 * Nx));
761 
762             // Base-case prime transform
763             fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6);
764 
765             // Store outputs
766             wrapper::vstore(out + (N + out_pad_x) * k, a);
767             wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
768             wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
769             wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
770             wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e);
771             wrapper::vstore(out + (N + out_pad_x) * (k + 10 * Nx), f);
772             wrapper::vstore(out + (N + out_pad_x) * (k + 12 * Nx), g);
773         }
774 
775         w = c_mul_neon(w, w_m);
776     }
777 }
778 
/** Radix-8 butterfly stage along axis 0 (the row dimension).
 *
 * Each element is held as a pair of floats in a float32x2_t and combined with c_mul_neon, i.e. treated
 * as a complex value. For every twiddle group j in [0, Nx) the powers w..w^7 of the running twiddle
 * factor are formed. Every butterfly of that group then gathers 8 inputs spaced Nx elements apart,
 * transforms them in registers with fft_8 and writes the results back to the same positions in @p out.
 * The butterflies of a group start at element j and are spaced NxRadix elements apart.
 *
 * @tparam first_stage When true, inputs/outputs are moved two at a time with one 128-bit access
 *                     (e.g. a and b both come from in + k). That is only equivalent to the generic
 *                     strided path (b at in + k + 2 * Nx) when Nx == 1. Presumably the caller
 *                     guarantees this for the first stage; TODO confirm.
 *
 * @param[out] out     Destination row. May be the same buffer as @p in, because configure() runs in place
 *                     when no output tensor is given.
 * @param[in]  in      Source row.
 * @param[in]  Nx      Number of twiddle groups in this stage, and the element distance between the legs
 *                     of one butterfly.
 * @param[in]  NxRadix Element distance between successive butterflies of the same group (radix * Nx as
 *                     computed in run()).
 * @param[in]  w_m     Twiddle step. The running factor w starts at 1 + 0i and is multiplied by w_m after
 *                     each group.
 * @param[in]  N       Number of elements in the row. k walks 2 * N floats.
 */
template <bool first_stage>
void fft_radix_8_axes_0(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
{
    float32x2_t w{ 1.0f, 0.0f };
    for(unsigned int j = 0; j < Nx; j++)
    {
        // Twiddle powers for legs 2..7 of the butterfly
        const float32x2_t w2 = c_mul_neon(w, w);
        const float32x2_t w3 = c_mul_neon(w2, w);
        const float32x2_t w4 = c_mul_neon(w3, w);
        const float32x2_t w5 = c_mul_neon(w4, w);
        const float32x2_t w6 = c_mul_neon(w5, w);
        const float32x2_t w7 = c_mul_neon(w6, w);

        // k is a float offset: element index * 2
        for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
        {
            // Butterfly legs (zero-initialised, overwritten by the loads below)
            float32x2_t a = { 0, 0 };
            float32x2_t b = { 0, 0 };
            float32x2_t c = { 0, 0 };
            float32x2_t d = { 0, 0 };
            float32x2_t e = { 0, 0 };
            float32x2_t f = { 0, 0 };
            float32x2_t g = { 0, 0 };
            float32x2_t h = { 0, 0 };

            // Load inputs
            if(first_stage)
            {
                // Paired 128-bit loads: assumes adjacent legs are adjacent in memory (Nx == 1)
                const auto ab = wrapper::vloadq(in + k);
                const auto cd = wrapper::vloadq(in + k + 4 * Nx);
                const auto ef = wrapper::vloadq(in + k + 8 * Nx);
                const auto gh = wrapper::vloadq(in + k + 12 * Nx);

                a = wrapper::vgetlow(ab);
                b = wrapper::vgethigh(ab);
                c = wrapper::vgetlow(cd);
                d = wrapper::vgethigh(cd);
                e = wrapper::vgetlow(ef);
                f = wrapper::vgethigh(ef);
                g = wrapper::vgetlow(gh);
                h = wrapper::vgethigh(gh);
            }
            else
            {
                a = wrapper::vload(in + k);
                b = wrapper::vload(in + k + 2 * Nx);
                c = wrapper::vload(in + k + 4 * Nx);
                d = wrapper::vload(in + k + 6 * Nx);
                e = wrapper::vload(in + k + 8 * Nx);
                f = wrapper::vload(in + k + 10 * Nx);
                g = wrapper::vload(in + k + 12 * Nx);
                h = wrapper::vload(in + k + 14 * Nx);
            }

            // Radix-8 butterfly (fft_8), passed the twiddle powers w..w7
            fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7);

            // Store outputs, mirroring the load pattern above
            if(first_stage)
            {
                wrapper::vstore(out + k, wrapper::vcombine(a, b));
                wrapper::vstore(out + k + 4 * Nx, wrapper::vcombine(c, d));
                wrapper::vstore(out + k + 8 * Nx, wrapper::vcombine(e, f));
                wrapper::vstore(out + k + 12 * Nx, wrapper::vcombine(g, h));
            }
            else
            {
                wrapper::vstore(out + k, a);
                wrapper::vstore(out + k + 2 * Nx, b);
                wrapper::vstore(out + k + 4 * Nx, c);
                wrapper::vstore(out + k + 6 * Nx, d);
                wrapper::vstore(out + k + 8 * Nx, e);
                wrapper::vstore(out + k + 10 * Nx, f);
                wrapper::vstore(out + k + 12 * Nx, g);
                wrapper::vstore(out + k + 14 * Nx, h);
            }
        }

        // Advance the running twiddle factor to the next group
        w = c_mul_neon(w, w_m);
    }
}
860 
fft_radix_8_axes_1(float * out,float * in,unsigned int Nx,unsigned int NxRadix,const float32x2_t & w_m,unsigned int N,unsigned int M,unsigned int in_pad_x,unsigned int out_pad_x)861 void fft_radix_8_axes_1(float *out, float *in, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N, unsigned int M, unsigned int in_pad_x, unsigned int out_pad_x)
862 {
863     float32x2_t w{ 1.0f, 0.0f };
864     for(unsigned int j = 0; j < Nx; j++)
865     {
866         const float32x2_t w2 = c_mul_neon(w, w);
867         const float32x2_t w3 = c_mul_neon(w2, w);
868         const float32x2_t w4 = c_mul_neon(w3, w);
869         const float32x2_t w5 = c_mul_neon(w4, w);
870         const float32x2_t w6 = c_mul_neon(w5, w);
871         const float32x2_t w7 = c_mul_neon(w6, w);
872 
873         for(unsigned int k = 2 * j; k < 2 * M; k += 2 * NxRadix)
874         {
875             // Load inputs
876             float32x2_t a = wrapper::vload(in + (N + in_pad_x) * k);
877             float32x2_t b = wrapper::vload(in + (N + in_pad_x) * (k + 2 * Nx));
878             float32x2_t c = wrapper::vload(in + (N + in_pad_x) * (k + 4 * Nx));
879             float32x2_t d = wrapper::vload(in + (N + in_pad_x) * (k + 6 * Nx));
880             float32x2_t e = wrapper::vload(in + (N + in_pad_x) * (k + 8 * Nx));
881             float32x2_t f = wrapper::vload(in + (N + in_pad_x) * (k + 10 * Nx));
882             float32x2_t g = wrapper::vload(in + (N + in_pad_x) * (k + 12 * Nx));
883             float32x2_t h = wrapper::vload(in + (N + in_pad_x) * (k + 14 * Nx));
884 
885             // Base-case prime transform
886             fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7);
887 
888             // Store outputs
889             wrapper::vstore(out + (N + out_pad_x) * k, a);
890             wrapper::vstore(out + (N + out_pad_x) * (k + 2 * Nx), b);
891             wrapper::vstore(out + (N + out_pad_x) * (k + 4 * Nx), c);
892             wrapper::vstore(out + (N + out_pad_x) * (k + 6 * Nx), d);
893             wrapper::vstore(out + (N + out_pad_x) * (k + 8 * Nx), e);
894             wrapper::vstore(out + (N + out_pad_x) * (k + 10 * Nx), f);
895             wrapper::vstore(out + (N + out_pad_x) * (k + 12 * Nx), g);
896             wrapper::vstore(out + (N + out_pad_x) * (k + 14 * Nx), h);
897         }
898 
899         w = c_mul_neon(w, w_m);
900     }
901 }
902 
validate_arguments(const ITensorInfo * input,const ITensorInfo * output,const FFTRadixStageKernelInfo & config)903 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
904 {
905     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
906     ARM_COMPUTE_RETURN_ERROR_ON(config.axis > 1);
907     ARM_COMPUTE_RETURN_ERROR_ON(NEFFTRadixStageKernel::supported_radix().count(config.radix) == 0);
908     ARM_COMPUTE_UNUSED(config);
909 
910     // Checks performed when output is configured
911     if((output != nullptr) && (output->total_size() != 0))
912     {
913         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
914         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
915     }
916 
917     return Status{};
918 }
919 
validate_and_configure_window(ITensorInfo * input,ITensorInfo * output,const FFTRadixStageKernelInfo & config)920 std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const FFTRadixStageKernelInfo &config)
921 {
922     ARM_COMPUTE_UNUSED(config);
923 
924     if(output != nullptr)
925     {
926         auto_init_if_empty(*output, *input);
927     }
928 
929     Window win = calculate_max_window(*input, Steps());
930 
931     return std::make_pair(Status{}, win);
932 }
933 } // namespace
934 
NEFFTRadixStageKernel()935 NEFFTRadixStageKernel::NEFFTRadixStageKernel()
936     : _input(nullptr), _output(nullptr), _Nx(0), _axis(0), _radix(0), _func_0(), _func_1()
937 {
938 }
939 
set_radix_stage_axis0(const FFTRadixStageKernelInfo & config)940 void NEFFTRadixStageKernel::set_radix_stage_axis0(const FFTRadixStageKernelInfo &config)
941 {
942     // FFT table axis 0: [radix, first_stage]
943     static std::map<unsigned int, std::map<bool, FFTFunctionPointerAxis0>> fft_table_axis0;
944 
945     if(fft_table_axis0.empty())
946     {
947         fft_table_axis0[2][false] = &fft_radix_2_axes_0<false>;
948         fft_table_axis0[3][false] = &fft_radix_3_axes_0<false>;
949         fft_table_axis0[4][false] = &fft_radix_4_axes_0<false>;
950         fft_table_axis0[5][false] = &fft_radix_5_axes_0<false>;
951         fft_table_axis0[7][false] = &fft_radix_7_axes_0<false>;
952         fft_table_axis0[8][false] = &fft_radix_8_axes_0<false>;
953 
954         fft_table_axis0[2][true] = &fft_radix_2_axes_0<true>;
955         fft_table_axis0[3][true] = &fft_radix_3_axes_0<true>;
956         fft_table_axis0[4][true] = &fft_radix_4_axes_0<true>;
957         fft_table_axis0[5][true] = &fft_radix_5_axes_0<true>;
958         fft_table_axis0[7][true] = &fft_radix_7_axes_0<true>;
959         fft_table_axis0[8][true] = &fft_radix_8_axes_0<true>;
960     }
961 
962     _func_0 = fft_table_axis0[config.radix][config.is_first_stage];
963 }
964 
set_radix_stage_axis1(const FFTRadixStageKernelInfo & config)965 void NEFFTRadixStageKernel::set_radix_stage_axis1(const FFTRadixStageKernelInfo &config)
966 {
967     // FFT table axis 1: [radix, first_stage]
968     static std::map<unsigned int, FFTFunctionPointerAxis1> fft_table_axis1;
969 
970     if(fft_table_axis1.empty())
971     {
972         fft_table_axis1[2] = &fft_radix_2_axes_1;
973         fft_table_axis1[3] = &fft_radix_3_axes_1;
974         fft_table_axis1[4] = &fft_radix_4_axes_1;
975         fft_table_axis1[5] = &fft_radix_5_axes_1;
976         fft_table_axis1[7] = &fft_radix_7_axes_1;
977         fft_table_axis1[8] = &fft_radix_8_axes_1;
978     }
979 
980     _func_1 = fft_table_axis1[config.radix];
981 }
982 
configure(ITensor * input,ITensor * output,const FFTRadixStageKernelInfo & config)983 void NEFFTRadixStageKernel::configure(ITensor *input, ITensor *output, const FFTRadixStageKernelInfo &config)
984 {
985     ARM_COMPUTE_ERROR_ON_NULLPTR(input);
986 
987     // Output auto inizialitation if not yet initialized
988     if(output != nullptr)
989     {
990         auto_init_if_empty(*output->info(), *input->info()->clone());
991     }
992 
993     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, config));
994 
995     _input  = input;
996     _output = (output == nullptr) ? input : output;
997     _Nx     = config.Nx;
998     _axis   = config.axis;
999     _radix  = config.radix;
1000 
1001     switch(config.axis)
1002     {
1003         case 0:
1004             set_radix_stage_axis0(config);
1005             break;
1006         case 1:
1007             set_radix_stage_axis1(config);
1008             break;
1009         default:
1010             ARM_COMPUTE_ERROR("Axis not supported");
1011             break;
1012     }
1013 
1014     // Configure kernel window
1015     auto win_config = validate_and_configure_window(input->info(), (output != nullptr) ? output->info() : nullptr, config);
1016     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
1017     INEKernel::configure(win_config.second);
1018 }
1019 
validate(const ITensorInfo * input,const ITensorInfo * output,const FFTRadixStageKernelInfo & config)1020 Status NEFFTRadixStageKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
1021 {
1022     const bool run_in_place = (output == nullptr) || (output == input);
1023     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, config));
1024     ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
1025                                                               (run_in_place) ? nullptr : output->clone().get(),
1026                                                               config)
1027                                 .first);
1028 
1029     return Status{};
1030 }
1031 
supported_radix()1032 std::set<unsigned int> NEFFTRadixStageKernel::supported_radix()
1033 {
1034     return std::set<unsigned int> { 2, 3, 4, 5, 7, 8 };
1035 }
1036 
run(const Window & window,const ThreadInfo & info)1037 void NEFFTRadixStageKernel::run(const Window &window, const ThreadInfo &info)
1038 {
1039     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1040     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1041     ARM_COMPUTE_UNUSED(info);
1042 
1043     Window input_window = window;
1044     input_window.set(_axis, 0);
1045 
1046     Iterator in(_input, input_window);
1047     Iterator out(_output, input_window);
1048 
1049     // Precompute FFT constants
1050     const unsigned int NxRadix = _radix * _Nx;
1051     const float        alpha   = 2.0f * kPi / float(NxRadix);
1052     const float32x2_t  w_m{ cosf(alpha), -sinf(alpha) };
1053 
1054     if(_axis == 0)
1055     {
1056         const unsigned int N = _input->info()->dimension(0);
1057         execute_window_loop(input_window, [&](const Coordinates &)
1058         {
1059             _func_0(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N);
1060         },
1061         in, out);
1062     }
1063     else
1064     {
1065         const unsigned int N = _input->info()->dimension(0);
1066         const unsigned int M = _input->info()->dimension(1);
1067         execute_window_loop(input_window, [&](const Coordinates &)
1068         {
1069             _func_1(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N, M,
1070                     _input->info()->padding().right + _input->info()->padding().left,
1071                     _output->info()->padding().right + _output->info()->padding().left);
1072         },
1073         in, out);
1074     }
1075 
1076     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1077     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1078 }
1079 } // namespace arm_compute
1080