• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2 * Copyright (c) 2019 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
/** Multiplies a complex value in place by the twiddle factor w = exp(i * phi).
 *
 * With w = (cos(phi), sin(phi)) the value is replaced by the complex product
 *   input = (w.x * input.x - w.y * input.y, w.x * input.y + w.y * input.x)
 * where .x holds the real part and .y the imaginary part.
 *
 * Hygiene notes:
 * - @p phi is evaluated exactly once.
 * - @p input is read once and written once.
 * - The temporaries carry a twiddle_ prefix, so they cannot shadow caller identifiers
 *   (the previous `w`/`tmp` names silently broke calls whose argument was named `tmp` or `w`).
 * - The do { } while(0) wrapper makes the expansion a single statement.
 *
 * @note native_cos()/native_sin() are used, so the accuracy of w is implementation-defined.
 *
 * @param[in]     phi   The angle in radians (scalar).
 * @param[in,out] input Modifiable float2 lvalue the factor is applied to.
 */
#define TWIDDLE_FACTOR_MULTIPLICATION(phi, input)                                                 \
    do                                                                                            \
    {                                                                                             \
        const float  twiddle_phi = (phi);                                                         \
        const float2 twiddle_w   = (float2)(native_cos(twiddle_phi), native_sin(twiddle_phi));    \
        const float2 twiddle_in  = (input);                                                       \
        (input)                  = (float2)((twiddle_w.x * twiddle_in.x) - (twiddle_w.y * twiddle_in.y), \
                                            (twiddle_w.x * twiddle_in.y) + (twiddle_w.y * twiddle_in.x)); \
    } while(0)
40
/** Computes the radix-2 butterfly unit in place: (c0, c1) <- (c0 + c1, c0 - c1).
 *
 * Usage and hygiene:
 * - Every argument must be a modifiable float2 lvalue, e.g. a variable or a swizzle such as data.s01.
 * - Both arguments are read before either one is written.
 * - The temporaries use a dft2_ prefix, so they cannot shadow caller identifiers
 *   (the previous local `v0` broke calls whose argument was named `v0`).
 * - The do { } while(0) wrapper lets the macro be used as a single statement,
 *   e.g. in an unbraced if/else.
 *
 * @param[in,out] c0 Complex input 0, replaced by c0 + c1.
 * @param[in,out] c1 Complex input 1, replaced by c0 - c1.
 */
#define DFT_2(c0, c1)                             \
    do                                            \
    {                                             \
        const float2 dft2_x0 = (c0);              \
        const float2 dft2_x1 = (c1);              \
        (c0)                 = dft2_x0 + dft2_x1; \
        (c1)                 = dft2_x0 - dft2_x1; \
    } while(0)
53
// Radix-3 butterfly constant: sin(2*pi/3) = sqrt(3)/2
#define SQRT3DIV2 0.86602540378443f

/** Computes the forward radix-3 butterfly unit in place.
 *
 * Replaces (c0, c1, c2) with X_k = sum_{n=0}^{2} c_n * exp(-2*pi*i*n*k/3), k = 0..2.
 *
 * Usage and hygiene:
 * - Every argument must be a modifiable float2 lvalue (real part in .x, imaginary part in .y).
 * - All arguments are read once, before any of them is written.
 * - The temporaries carry a dft3_ prefix, so they cannot capture caller identifiers.
 * - The do { } while(0) wrapper makes the expansion a single statement.
 *
 * @param[in,out] c0 Complex input 0, replaced by X_0.
 * @param[in,out] c1 Complex input 1, replaced by X_1.
 * @param[in,out] c2 Complex input 2, replaced by X_2.
 */
#define DFT_3(c0, c1, c2)                                                      \
    do                                                                         \
    {                                                                          \
        const float2 dft3_x0  = (c0);                                          \
        const float2 dft3_x1  = (c1);                                          \
        const float2 dft3_x2  = (c2);                                          \
        const float2 dft3_sum = dft3_x1 + dft3_x2;                             \
        const float2 dft3_dif = dft3_x1 - dft3_x2;                             \
        /* Term shared by X_1 and X_2: c0 - (c1 + c2) / 2 */                   \
        const float2 dft3_mid = dft3_x0 - 0.5f * dft3_sum;                     \
        /* -i * (c1 - c2) * sqrt(3)/2 */                                       \
        const float2 dft3_rot = (float2)(dft3_dif.y, -dft3_dif.x) * SQRT3DIV2; \
        (c0)                  = dft3_x0 + dft3_sum;                            \
        (c1)                  = dft3_mid + dft3_rot;                           \
        (c2)                  = dft3_mid - dft3_rot;                           \
    } while(0)
73
/** Computes the forward radix-4 butterfly unit in place.
 *
 * Replaces (c0, c1, c2, c3) with X_k = sum_{n=0}^{3} c_n * (-i)^(n*k), k = 0..3.
 *
 * Usage and hygiene:
 * - Every argument must be a modifiable float2 lvalue (real part in .x, imaginary part in .y).
 * - All arguments are read once, before any of them is written.
 * - The temporaries carry a dft4_ prefix, so they cannot capture caller identifiers.
 * - The do { } while(0) wrapper makes the expansion a single statement.
 *
 * @param[in,out] c0 Complex input 0, replaced by X_0.
 * @param[in,out] c1 Complex input 1, replaced by X_1.
 * @param[in,out] c2 Complex input 2, replaced by X_2.
 * @param[in,out] c3 Complex input 3, replaced by X_3.
 */
#define DFT_4(c0, c1, c2, c3)                                                            \
    do                                                                                   \
    {                                                                                    \
        const float2 dft4_x0  = (c0);                                                    \
        const float2 dft4_x1  = (c1);                                                    \
        const float2 dft4_x2  = (c2);                                                    \
        const float2 dft4_x3  = (c3);                                                    \
        const float2 dft4_s02 = dft4_x0 + dft4_x2;                                       \
        const float2 dft4_s13 = dft4_x1 + dft4_x3;                                       \
        const float2 dft4_d02 = dft4_x0 - dft4_x2;                                       \
        /* -i * (c1 - c3) */                                                             \
        const float2 dft4_r13 = (float2)(dft4_x1.y - dft4_x3.y, dft4_x3.x - dft4_x1.x);  \
        (c0)                  = dft4_s02 + dft4_s13;                                     \
        (c1)                  = dft4_d02 + dft4_r13;                                     \
        (c2)                  = dft4_s02 - dft4_s13;                                     \
        (c3)                  = dft4_d02 - dft4_r13;                                     \
    } while(0)
94
// Radix-5 butterfly constants
#define W5_A 0.30901699437494f // cos(2*pi/5)
#define W5_B 0.95105651629515f // sin(2*pi/5)
#define W5_C 0.80901699437494f // -cos(4*pi/5)
#define W5_D 0.58778525229247f // sin(4*pi/5)

/** Computes the forward radix-5 butterfly unit in place.
 *
 * Replaces (c0, ..., c4) with X_k = sum_{n=0}^{4} c_n * exp(-2*pi*i*n*k/5), k = 0..4.
 * It uses the symmetric pairs (c1, c4) and (c2, c3): their sums carry the cosine
 * terms and their differences carry the sine terms.
 *
 * Usage and hygiene:
 * - Every argument must be a modifiable float2 lvalue (real part in .x, imaginary part in .y).
 * - All arguments are read once, before any of them is written.
 * - The temporaries carry a dft5_ prefix, so they cannot capture caller identifiers.
 * - The do { } while(0) wrapper makes the expansion a single statement.
 *
 * @param[in,out] c0 Complex input 0, replaced by X_0.
 * @param[in,out] c1 Complex input 1, replaced by X_1.
 * @param[in,out] c2 Complex input 2, replaced by X_2.
 * @param[in,out] c3 Complex input 3, replaced by X_3.
 * @param[in,out] c4 Complex input 4, replaced by X_4.
 */
#define DFT_5(c0, c1, c2, c3, c4)                                                             \
    do                                                                                        \
    {                                                                                         \
        const float2 dft5_x0 = (c0);                                                          \
        const float2 dft5_x1 = (c1);                                                          \
        const float2 dft5_x2 = (c2);                                                          \
        const float2 dft5_x3 = (c3);                                                          \
        const float2 dft5_x4 = (c4);                                                          \
        const float2 dft5_v1 = W5_A * (dft5_x1 + dft5_x4) - W5_C * (dft5_x2 + dft5_x3);       \
        const float2 dft5_v2 = W5_C * (dft5_x1 + dft5_x4) - W5_A * (dft5_x2 + dft5_x3);       \
        const float2 dft5_v3 = W5_D * (dft5_x1 - dft5_x4) - W5_B * (dft5_x2 - dft5_x3);       \
        const float2 dft5_v4 = W5_B * (dft5_x1 - dft5_x4) + W5_D * (dft5_x2 - dft5_x3);       \
        (c0)                 = dft5_x0 + dft5_x1 + dft5_x2 + dft5_x3 + dft5_x4;               \
        /* (v.y, -v.x) is -i * v and (-v.y, v.x) is +i * v */                                 \
        (c1)                 = dft5_x0 + dft5_v1 + (float2)(dft5_v4.y, -dft5_v4.x);           \
        (c2)                 = dft5_x0 - dft5_v2 + (float2)(dft5_v3.y, -dft5_v3.x);           \
        (c3)                 = dft5_x0 - dft5_v2 + (float2)(-dft5_v3.y, dft5_v3.x);           \
        (c4)                 = dft5_x0 + dft5_v1 + (float2)(-dft5_v4.y, dft5_v4.x);           \
    } while(0)
123
// Radix-7 butterfly constants
#define W7_A 0.62348980185873f // cos(2*pi/7)
#define W7_B 0.78183148246802f // sin(2*pi/7)
#define W7_C 0.22252093395631f // -cos(4*pi/7)
#define W7_D 0.97492791218182f // sin(4*pi/7)
#define W7_E 0.90096886790241f // -cos(6*pi/7)
#define W7_F 0.43388373911755f // sin(6*pi/7)

/** Computes the forward radix-7 butterfly unit in place.
 *
 * Replaces (c0, ..., c6) with X_k = sum_{n=0}^{6} c_n * exp(-2*pi*i*n*k/7), k = 0..6.
 * It uses the symmetric pairs (c1, c6), (c2, c5) and (c3, c4): their sums carry the
 * cosine terms (v1..v3) and their differences carry the sine terms (v4..v6).
 *
 * Usage and hygiene:
 * - Every argument must be a modifiable float2 lvalue (real part in .x, imaginary part in .y).
 * - All arguments are read once, before any of them is written.
 * - The temporaries carry a dft7_ prefix, so they cannot capture caller identifiers.
 * - The do { } while(0) wrapper makes the expansion a single statement.
 *
 * @param[in,out] c0 Complex input 0, replaced by X_0.
 * @param[in,out] c1 Complex input 1, replaced by X_1.
 * @param[in,out] c2 Complex input 2, replaced by X_2.
 * @param[in,out] c3 Complex input 3, replaced by X_3.
 * @param[in,out] c4 Complex input 4, replaced by X_4.
 * @param[in,out] c5 Complex input 5, replaced by X_5.
 * @param[in,out] c6 Complex input 6, replaced by X_6.
 */
#define DFT_7(c0, c1, c2, c3, c4, c5, c6)                                                                \
    do                                                                                                   \
    {                                                                                                    \
        const float2 dft7_x0 = (c0);                                                                     \
        const float2 dft7_x1 = (c1);                                                                     \
        const float2 dft7_x2 = (c2);                                                                     \
        const float2 dft7_x3 = (c3);                                                                     \
        const float2 dft7_x4 = (c4);                                                                     \
        const float2 dft7_x5 = (c5);                                                                     \
        const float2 dft7_x6 = (c6);                                                                     \
        const float2 dft7_v1 = W7_A * (dft7_x1 + dft7_x6) - W7_C * (dft7_x2 + dft7_x5) - W7_E * (dft7_x3 + dft7_x4); \
        const float2 dft7_v2 = W7_C * (dft7_x1 + dft7_x6) + W7_E * (dft7_x2 + dft7_x5) - W7_A * (dft7_x3 + dft7_x4); \
        const float2 dft7_v3 = W7_E * (dft7_x1 + dft7_x6) - W7_A * (dft7_x2 + dft7_x5) + W7_C * (dft7_x3 + dft7_x4); \
        const float2 dft7_v4 = W7_B * (dft7_x1 - dft7_x6) + W7_D * (dft7_x2 - dft7_x5) + W7_F * (dft7_x3 - dft7_x4); \
        const float2 dft7_v5 = W7_D * (dft7_x1 - dft7_x6) - W7_F * (dft7_x2 - dft7_x5) - W7_B * (dft7_x3 - dft7_x4); \
        const float2 dft7_v6 = W7_F * (dft7_x1 - dft7_x6) - W7_B * (dft7_x2 - dft7_x5) + W7_D * (dft7_x3 - dft7_x4); \
        (c0)                 = dft7_x0 + dft7_x1 + dft7_x2 + dft7_x3 + dft7_x4 + dft7_x5 + dft7_x6;      \
        /* (v.y, -v.x) is -i * v and (-v.y, v.x) is +i * v */                                            \
        (c1)                 = dft7_x0 + dft7_v1 + (float2)(dft7_v4.y, -dft7_v4.x);                      \
        (c2)                 = dft7_x0 - dft7_v2 + (float2)(dft7_v5.y, -dft7_v5.x);                      \
        (c3)                 = dft7_x0 - dft7_v3 + (float2)(dft7_v6.y, -dft7_v6.x);                      \
        (c4)                 = dft7_x0 - dft7_v3 + (float2)(-dft7_v6.y, dft7_v6.x);                      \
        (c5)                 = dft7_x0 - dft7_v2 + (float2)(-dft7_v5.y, dft7_v5.x);                      \
        (c6)                 = dft7_x0 + dft7_v1 + (float2)(-dft7_v4.y, dft7_v4.x);                      \
    } while(0)
160
/** Computes the forward radix-8 butterfly unit in place.
 *
 * Replaces (c0, ..., c7) with X_k = sum_{n=0}^{7} c_n * W^(n*k), k = 0..7, where
 * W = exp(-i*pi/4).
 *
 * Algorithm:
 * - The first stage forms the sums c_n + c_{n+4} (v0..v3) and the differences
 *   c_n - c_{n+4} (v4..v7).
 * - The sums give the even outputs.
 * - The differences, rotated by powers of W, give the odd outputs.
 *
 * Usage and hygiene:
 * - Every argument must be a modifiable float2 lvalue (real part in .x, imaginary part in .y).
 * - All arguments are read once, before any of them is written.
 * - The temporaries carry a dft8_ prefix, so they cannot capture caller identifiers.
 * - The do { } while(0) wrapper makes the expansion a single statement.
 *
 * @param[in,out] c0 Complex input 0, replaced by X_0.
 * @param[in,out] c1 Complex input 1, replaced by X_1.
 * @param[in,out] c2 Complex input 2, replaced by X_2.
 * @param[in,out] c3 Complex input 3, replaced by X_3.
 * @param[in,out] c4 Complex input 4, replaced by X_4.
 * @param[in,out] c5 Complex input 5, replaced by X_5.
 * @param[in,out] c6 Complex input 6, replaced by X_6.
 * @param[in,out] c7 Complex input 7, replaced by X_7.
 */
#define DFT_8(c0, c1, c2, c3, c4, c5, c6, c7)                                            \
    do                                                                                   \
    {                                                                                    \
        const float2 dft8_x0 = (c0);                                                     \
        const float2 dft8_x1 = (c1);                                                     \
        const float2 dft8_x2 = (c2);                                                     \
        const float2 dft8_x3 = (c3);                                                     \
        const float2 dft8_x4 = (c4);                                                     \
        const float2 dft8_x5 = (c5);                                                     \
        const float2 dft8_x6 = (c6);                                                     \
        const float2 dft8_x7 = (c7);                                                     \
        const float2 dft8_v0 = dft8_x0 + dft8_x4;                                        \
        const float2 dft8_v1 = dft8_x1 + dft8_x5;                                        \
        const float2 dft8_v2 = dft8_x2 + dft8_x6;                                        \
        const float2 dft8_v3 = dft8_x3 + dft8_x7;                                        \
        const float2 dft8_v4 = dft8_x0 - dft8_x4;                                        \
        const float2 dft8_v5 = dft8_x1 - dft8_x5;                                        \
        const float2 dft8_v6 = dft8_x2 - dft8_x6;                                        \
        const float2 dft8_v7 = dft8_x3 - dft8_x7;                                        \
        const float2 dft8_s0 = dft8_v0 + dft8_v2;                                        \
        const float2 dft8_s1 = dft8_v1 + dft8_v3;                                        \
        const float2 dft8_s2 = dft8_v0 - dft8_v2;                                        \
        const float2 dft8_s3 = dft8_v1 - dft8_v3;                                        \
        /* s4 = v4 + i*v6, s5 = v5 + i*v7, s6 = v4 - i*v6, s7 = v5 - i*v7 */             \
        const float2 dft8_s4 = (float2)(dft8_v4.x - dft8_v6.y, dft8_v4.y + dft8_v6.x);   \
        const float2 dft8_s5 = (float2)(dft8_v5.x - dft8_v7.y, dft8_v5.y + dft8_v7.x);   \
        const float2 dft8_s6 = (float2)(dft8_v4.x + dft8_v6.y, dft8_v4.y - dft8_v6.x);   \
        const float2 dft8_s7 = (float2)(dft8_v5.x + dft8_v7.y, dft8_v5.y - dft8_v7.x);   \
        /* t0 = i*s3, t1 = -W^3 * s5, t2 = -W * s7 */                                    \
        const float2 dft8_t0 = (float2)(-dft8_s3.y, dft8_s3.x);                          \
        const float2 dft8_t1 = (float2)(M_SQRT1_2_F * (dft8_s5.x - dft8_s5.y),           \
                                        M_SQRT1_2_F * (dft8_s5.x + dft8_s5.y));          \
        const float2 dft8_t2 = (float2)(-M_SQRT1_2_F * (dft8_s7.x + dft8_s7.y),          \
                                        M_SQRT1_2_F * (dft8_s7.x - dft8_s7.y));          \
        (c0)                 = dft8_s0 + dft8_s1;                                        \
        (c1)                 = dft8_s6 - dft8_t2;                                        \
        (c2)                 = dft8_s2 - dft8_t0;                                        \
        (c3)                 = dft8_s4 - dft8_t1;                                        \
        (c4)                 = dft8_s0 - dft8_s1;                                        \
        (c5)                 = dft8_s6 + dft8_t2;                                        \
        (c6)                 = dft8_s2 + dft8_t0;                                        \
        (c7)                 = dft8_s4 + dft8_t1;                                        \
    } while(0)
212
/** First stage of a radix-2 FFT along axis 0 (X).
 *
 * Each work-item reads two complex values that lie next to each other along X, as one
 * float4 holding two (real, imaginary) pairs. It applies the DFT_2 butterfly to them and
 * writes both results to the matching location of the destination tensor.
 *
 * @note Compile with -DIN_PLACE to store the results into the source tensor. The
 *       output_* arguments are then not part of the kernel signature.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor (also the destination when IN_PLACE is defined). Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Only without IN_PLACE) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Only without IN_PLACE) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Only without IN_PLACE) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Only without IN_PLACE) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Only without IN_PLACE) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Only without IN_PLACE) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Only without IN_PLACE) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Only without IN_PLACE) The offset of the first element in the destination tensor
 */
kernel void fft_radix_2_first_stage_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // Both samples are contiguous along X, so one float4 covers them
    float4 both = vload4(0, (__global float *)src.ptr);

    DFT_2(both.s01, both.s23);

    vstore4(both, 0, (__global float *)dst.ptr);
}
259
/** First stage of a radix-2 FFT along axis 1 (Y).
 *
 * Each work-item reads two complex values, the one at its position and the one a single
 * Y-stride further. It applies the DFT_2 butterfly and writes both results to the same two
 * Y offsets of the destination tensor. Both values are loaded before anything is stored,
 * which keeps the IN_PLACE variant correct.
 *
 * @note Compile with -DIN_PLACE to store the results into the source tensor. The
 *       output_* arguments are then not part of the kernel signature.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor (also the destination when IN_PLACE is defined). Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Only without IN_PLACE) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Only without IN_PLACE) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Only without IN_PLACE) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Only without IN_PLACE) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Only without IN_PLACE) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Only without IN_PLACE) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Only without IN_PLACE) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Only without IN_PLACE) The offset of the first element in the destination tensor
 */
kernel void fft_radix_2_first_stage_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // Addresses of the two samples, one Y-stride apart
    __global float *src0 = (__global float *)src.ptr;
    __global float *src1 = (__global float *)tensor3D_offset(&src, 0, 1, 0);
    __global float *dst0 = (__global float *)dst.ptr;
    __global float *dst1 = (__global float *)tensor3D_offset(&dst, 0, 1, 0);

    // Read everything first: with IN_PLACE the destination aliases the source
    float2 x0 = vload2(0, src0);
    float2 x1 = vload2(0, src1);

    DFT_2(x0, x1);

    vstore2(x0, 0, dst0);
    vstore2(x1, 0, dst1);
}
308
/** First stage of a radix-3 FFT along axis 0 (X).
 *
 * Each work-item reads three complex values along X:
 * - the first two as one float4 at its position;
 * - the third two X-strides further, via tensor3D_offset.
 * It applies the DFT_3 butterfly and writes the results back with the same layout into
 * the destination tensor.
 *
 * @note Compile with -DIN_PLACE to store the results into the source tensor. The
 *       output_* arguments are then not part of the kernel signature.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor (also the destination when IN_PLACE is defined). Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Only without IN_PLACE) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Only without IN_PLACE) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Only without IN_PLACE) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Only without IN_PLACE) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Only without IN_PLACE) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Only without IN_PLACE) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Only without IN_PLACE) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Only without IN_PLACE) The offset of the first element in the destination tensor
 */
kernel void fft_radix_3_first_stage_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // Samples 0 and 1 come as one float4; sample 2 sits two X-strides away
    float4 head = vload4(0, (__global float *)src.ptr);
    float2 tail = vload2(0, (__global float *)tensor3D_offset(&src, 2, 0, 0));

    DFT_3(head.s01, head.s23, tail);

    vstore4(head, 0, (__global float *)dst.ptr);
    vstore2(tail, 0, (__global float *)tensor3D_offset(&dst, 2, 0, 0));
}
357
/** First stage of a radix-3 FFT along axis 1 (Y).
 *
 * Each work-item reads three complex values at Y offsets 0, 1 and 2 from its position.
 * It applies the DFT_3 butterfly and writes the results to the same Y offsets of the
 * destination tensor. All values are loaded before anything is stored, which keeps the
 * IN_PLACE variant correct.
 *
 * @note Compile with -DIN_PLACE to store the results into the source tensor. The
 *       output_* arguments are then not part of the kernel signature.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor (also the destination when IN_PLACE is defined). Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Only without IN_PLACE) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Only without IN_PLACE) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Only without IN_PLACE) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Only without IN_PLACE) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Only without IN_PLACE) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Only without IN_PLACE) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Only without IN_PLACE) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Only without IN_PLACE) The offset of the first element in the destination tensor
 */
kernel void fft_radix_3_first_stage_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // Addresses of the three samples, one Y-stride apart
    __global float *src0 = (__global float *)src.ptr;
    __global float *src1 = (__global float *)tensor3D_offset(&src, 0, 1, 0);
    __global float *src2 = (__global float *)tensor3D_offset(&src, 0, 2, 0);
    __global float *dst0 = (__global float *)dst.ptr;
    __global float *dst1 = (__global float *)tensor3D_offset(&dst, 0, 1, 0);
    __global float *dst2 = (__global float *)tensor3D_offset(&dst, 0, 2, 0);

    // Read everything first: with IN_PLACE the destination aliases the source
    float2 x0 = vload2(0, src0);
    float2 x1 = vload2(0, src1);
    float2 x2 = vload2(0, src2);

    DFT_3(x0, x1, x2);

    vstore2(x0, 0, dst0);
    vstore2(x1, 0, dst1);
    vstore2(x2, 0, dst2);
}
408
/** First stage of a radix-4 FFT along axis 0 (X).
 *
 * Each work-item reads four complex values that lie next to each other along X, as one
 * float8 holding four (real, imaginary) pairs. It applies the DFT_4 butterfly and writes
 * the float8 back to the matching location of the destination tensor.
 *
 * @note Compile with -DIN_PLACE to store the results into the source tensor. The
 *       output_* arguments are then not part of the kernel signature.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor (also the destination when IN_PLACE is defined). Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Only without IN_PLACE) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Only without IN_PLACE) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Only without IN_PLACE) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Only without IN_PLACE) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Only without IN_PLACE) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Only without IN_PLACE) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Only without IN_PLACE) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Only without IN_PLACE) The offset of the first element in the destination tensor
 */
kernel void fft_radix_4_first_stage_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // All four samples are contiguous along X, so one float8 covers them
    float8 quad = vload8(0, (__global float *)src.ptr);

    DFT_4(quad.s01, quad.s23, quad.s45, quad.s67);

    vstore8(quad, 0, (__global float *)dst.ptr);
}
455
/** First stage of a radix-4 FFT along axis 1 (Y).
 *
 * Each work-item reads four complex values at Y offsets 0..3 from its position. It applies
 * the DFT_4 butterfly and writes the results to the same Y offsets of the destination
 * tensor. All values are loaded before anything is stored, which keeps the IN_PLACE
 * variant correct.
 *
 * @note Compile with -DIN_PLACE to store the results into the source tensor. The
 *       output_* arguments are then not part of the kernel signature.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor (also the destination when IN_PLACE is defined). Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Only without IN_PLACE) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Only without IN_PLACE) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Only without IN_PLACE) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Only without IN_PLACE) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Only without IN_PLACE) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Only without IN_PLACE) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Only without IN_PLACE) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Only without IN_PLACE) The offset of the first element in the destination tensor
 */
kernel void fft_radix_4_first_stage_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // Addresses of the four samples, one Y-stride apart
    __global float *src0 = (__global float *)src.ptr;
    __global float *src1 = (__global float *)tensor3D_offset(&src, 0, 1, 0);
    __global float *src2 = (__global float *)tensor3D_offset(&src, 0, 2, 0);
    __global float *src3 = (__global float *)tensor3D_offset(&src, 0, 3, 0);
    __global float *dst0 = (__global float *)dst.ptr;
    __global float *dst1 = (__global float *)tensor3D_offset(&dst, 0, 1, 0);
    __global float *dst2 = (__global float *)tensor3D_offset(&dst, 0, 2, 0);
    __global float *dst3 = (__global float *)tensor3D_offset(&dst, 0, 3, 0);

    // Read everything first: with IN_PLACE the destination aliases the source
    float2 x0 = vload2(0, src0);
    float2 x1 = vload2(0, src1);
    float2 x2 = vload2(0, src2);
    float2 x3 = vload2(0, src3);

    DFT_4(x0, x1, x2, x3);

    vstore2(x0, 0, dst0);
    vstore2(x1, 0, dst1);
    vstore2(x2, 0, dst2);
    vstore2(x3, 0, dst3);
}
508
/** First stage of a radix-5 FFT along axis 0 (X).
 *
 * Each work-item reads five complex values along X:
 * - the first four as one float8 at its position;
 * - the fifth four X-strides further, via tensor3D_offset.
 * It applies the DFT_5 butterfly and writes the results back with the same layout into
 * the destination tensor.
 *
 * @note Compile with -DIN_PLACE to store the results into the source tensor. The
 *       output_* arguments are then not part of the kernel signature.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor (also the destination when IN_PLACE is defined). Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Only without IN_PLACE) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Only without IN_PLACE) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Only without IN_PLACE) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Only without IN_PLACE) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Only without IN_PLACE) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Only without IN_PLACE) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Only without IN_PLACE) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Only without IN_PLACE) The offset of the first element in the destination tensor
 */
kernel void fft_radix_5_first_stage_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // Samples 0..3 come as one float8; sample 4 sits four X-strides away
    float8 head = vload8(0, (__global float *)src.ptr);
    float2 tail = vload2(0, (__global float *)tensor3D_offset(&src, 4, 0, 0));

    DFT_5(head.s01, head.s23, head.s45, head.s67, tail);

    vstore8(head, 0, (__global float *)dst.ptr);
    vstore2(tail, 0, (__global float *)tensor3D_offset(&dst, 4, 0, 0));
}
557
/** Computes the first stage of a radix-5 DFT along axis 1 (Y).
 *
 * Each work-item loads five complex values from five consecutive Y positions, transforms
 * them with DFT_5 and writes the five results to the same Y positions of the output.
 *
 * @note Compile with -DIN_PLACE to transform in place: the output tensor arguments are then
 *       omitted from the signature and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
kernel void fft_radix_5_first_stage_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D output = input; // results overwrite the source
#else  /* IN_PLACE */
    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // Gather: complex value k lives k positions further along Y
    float2 elem0 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 0, 0));
    float2 elem1 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 1, 0));
    float2 elem2 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 2, 0));
    float2 elem3 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 3, 0));
    float2 elem4 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 4, 0));

    // 5-point DFT, computed in place on the gathered values
    DFT_5(elem0, elem1, elem2, elem3, elem4);

    // Scatter the transformed values back to the same Y positions
    vstore2(elem0, 0, (__global float *)tensor3D_offset(&output, 0, 0, 0));
    vstore2(elem1, 0, (__global float *)tensor3D_offset(&output, 0, 1, 0));
    vstore2(elem2, 0, (__global float *)tensor3D_offset(&output, 0, 2, 0));
    vstore2(elem3, 0, (__global float *)tensor3D_offset(&output, 0, 3, 0));
    vstore2(elem4, 0, (__global float *)tensor3D_offset(&output, 0, 4, 0));
}
612
/** Computes the first stage of a radix-7 DFT along axis 0 (X).
 *
 * Each work-item reads seven consecutive complex values along X, transforms them with
 * DFT_7 and stores the seven results at the same X positions of the output.
 *
 * @note Compile with -DIN_PLACE to transform in place: the output tensor arguments are then
 *       omitted from the signature and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
kernel void fft_radix_7_first_stage_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D output = input; // results overwrite the source
#else  /* IN_PLACE */
    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // Seven complex values split 4 + 2 + 1 so that each piece maps onto a native
    // vector load: values 0..3 (float8), 4..5 (float4) and 6 (float2).
    float8 head = vload8(0, (__global float *)input.ptr);
    float4 mid  = vload4(0, (__global float *)tensor3D_offset(&input, 4, 0, 0));
    float2 tail = vload2(0, (__global float *)tensor3D_offset(&input, 6, 0, 0));

    // 7-point DFT; each re/im pair is updated in place
    DFT_7(head.s01, head.s23, head.s45, head.s67, mid.s01, mid.s23, tail);

    // Store with the same 4 + 2 + 1 split
    vstore8(head, 0, (__global float *)output.ptr);
    vstore4(mid, 0, (__global float *)tensor3D_offset(&output, 4, 0, 0));
    vstore2(tail, 0, (__global float *)tensor3D_offset(&output, 6, 0, 0));
}
663
/** Computes the first stage of a radix-7 DFT along axis 1 (Y).
 *
 * Each work-item loads seven complex values from seven consecutive Y positions, transforms
 * them with DFT_7 and writes the seven results to the same Y positions of the output.
 *
 * @note Compile with -DIN_PLACE to transform in place: the output tensor arguments are then
 *       omitted from the signature and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
kernel void fft_radix_7_first_stage_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D output = input; // results overwrite the source
#else  /* IN_PLACE */
    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // Gather: complex value k lives k positions further along Y
    float2 elem0 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 0, 0));
    float2 elem1 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 1, 0));
    float2 elem2 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 2, 0));
    float2 elem3 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 3, 0));
    float2 elem4 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 4, 0));
    float2 elem5 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 5, 0));
    float2 elem6 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 6, 0));

    // 7-point DFT, computed in place on the gathered values
    DFT_7(elem0, elem1, elem2, elem3, elem4, elem5, elem6);

    // Scatter the transformed values back to the same Y positions
    vstore2(elem0, 0, (__global float *)tensor3D_offset(&output, 0, 0, 0));
    vstore2(elem1, 0, (__global float *)tensor3D_offset(&output, 0, 1, 0));
    vstore2(elem2, 0, (__global float *)tensor3D_offset(&output, 0, 2, 0));
    vstore2(elem3, 0, (__global float *)tensor3D_offset(&output, 0, 3, 0));
    vstore2(elem4, 0, (__global float *)tensor3D_offset(&output, 0, 4, 0));
    vstore2(elem5, 0, (__global float *)tensor3D_offset(&output, 0, 5, 0));
    vstore2(elem6, 0, (__global float *)tensor3D_offset(&output, 0, 6, 0));
}
722
/** Computes the first stage of a radix-8 DFT along axis 0 (X).
 *
 * Each work-item reads eight consecutive complex values along X with a single 16-float
 * vector load, transforms them with DFT_8 and stores them back with a single vector store.
 *
 * @note Compile with -DIN_PLACE to transform in place: the output tensor arguments are then
 *       omitted from the signature and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
kernel void fft_radix_8_first_stage_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D output = input; // results overwrite the source
#else  /* IN_PLACE */
    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // All eight complex values fit one float16: value k is the re/im pair at lanes 2k, 2k+1
    float16 packed = vload16(0, (__global float *)input.ptr);

    // 8-point DFT applied lane-pair by lane-pair, in place
    DFT_8(packed.s01, packed.s23, packed.s45, packed.s67, packed.s89, packed.sAB, packed.sCD, packed.sEF);

    vstore16(packed, 0, (__global float *)output.ptr);
}
769
/** Computes the first stage of a radix-8 DFT along axis 1 (Y).
 *
 * Each work-item loads eight complex values from eight consecutive Y positions, transforms
 * them with DFT_8 and writes the eight results to the same Y positions of the output.
 *
 * @note Compile with -DIN_PLACE to transform in place: the output tensor arguments are then
 *       omitted from the signature and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
kernel void fft_radix_8_first_stage_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D output = input; // results overwrite the source
#else  /* IN_PLACE */
    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    // Gather: complex value k lives k positions further along Y
    float2 elem0 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 0, 0));
    float2 elem1 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 1, 0));
    float2 elem2 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 2, 0));
    float2 elem3 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 3, 0));
    float2 elem4 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 4, 0));
    float2 elem5 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 5, 0));
    float2 elem6 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 6, 0));
    float2 elem7 = vload2(0, (__global float *)tensor3D_offset(&input, 0, 7, 0));

    // 8-point DFT, computed in place on the gathered values
    DFT_8(elem0, elem1, elem2, elem3, elem4, elem5, elem6, elem7);

    // Scatter the transformed values back to the same Y positions
    vstore2(elem0, 0, (__global float *)tensor3D_offset(&output, 0, 0, 0));
    vstore2(elem1, 0, (__global float *)tensor3D_offset(&output, 0, 1, 0));
    vstore2(elem2, 0, (__global float *)tensor3D_offset(&output, 0, 2, 0));
    vstore2(elem3, 0, (__global float *)tensor3D_offset(&output, 0, 3, 0));
    vstore2(elem4, 0, (__global float *)tensor3D_offset(&output, 0, 4, 0));
    vstore2(elem5, 0, (__global float *)tensor3D_offset(&output, 0, 5, 0));
    vstore2(elem6, 0, (__global float *)tensor3D_offset(&output, 0, 6, 0));
    vstore2(elem7, 0, (__global float *)tensor3D_offset(&output, 0, 7, 0));
}
830
/** Computes one stage of a radix-2 FFT along axis 0 (X).
 *
 * Work-item k (global ID 0) owns one butterfly. With nx = k % Nx its position inside the
 * span, the two legs sit at X indices n and n + Nx, where n = nx + (k / Nx) * Ni. The second
 * leg is multiplied by exp(i * nx * exp_const) before the 2-point DFT. Global IDs 1 and 2
 * select the Y and Z coordinates.
 *
 * @note Compile with -DIN_PLACE to transform in place: the output tensor arguments are then
 *       omitted from the signature and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 * @param[in]     Nx                                   The butterfly span: product of the radix orders of the previous stages
 * @param[in]     Ni                                   Nx * Ny; distance (in elements) between the starts of consecutive butterfly groups
 * @param[in]     exp_const                            Exponent constant; the twiddle angle at span position nx is nx * exp_const
 */
kernel void fft_radix_2_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // Butterfly handled by this work-item and its position inside the span
    const uint butterfly = get_global_id(0);
    const uint span_pos  = butterfly % Nx;
    // X index of the first leg: start of the butterfly group plus the span position
    const uint base      = (butterfly / Nx) * Ni + span_pos;

    // Pointers are positioned manually from the global IDs (no per-work-item step)
    Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    input.ptr += base * input.stride_x + get_global_id(1) * input.stride_y + get_global_id(2) * input.stride_z;
#ifdef IN_PLACE
    Tensor3D output = input;
#else  /* IN_PLACE */
    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    output.ptr += base * output.stride_x + get_global_id(1) * output.stride_y + get_global_id(2) * output.stride_z;
#endif /* IN_PLACE */

    // The two legs are Nx elements apart along X
    float2 elem0 = vload2(0, (__global float *)input.ptr);
    float2 elem1 = vload2(0, (__global float *)tensor3D_offset(&input, Nx, 0, 0));

    // Rotate the second leg by exp(i * angle)
    const float angle = (float)span_pos * exp_const;
    TWIDDLE_FACTOR_MULTIPLICATION(angle, elem1);

    // 2-point DFT (sum and difference), in place
    DFT_2(elem0, elem1);

    vstore2(elem0, 0, (__global float *)output.ptr);
    vstore2(elem1, 0, (__global float *)tensor3D_offset(&output, Nx, 0, 0));
}
900
/** Computes one stage of a radix-2 FFT along axis 1 (Y).
 *
 * Work-item k (global ID 1) owns one butterfly. With ny = k % Nx its position inside the
 * span, the two legs sit at Y indices n and n + Nx, where n = ny + (k / Nx) * Ni. The second
 * leg is multiplied by exp(i * ny * exp_const) before the 2-point DFT. Global IDs 0 and 2
 * select the X and Z coordinates.
 *
 * @note Compile with -DIN_PLACE to transform in place: the output tensor arguments are then
 *       omitted from the signature and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 * @param[in]     Nx                                   The butterfly span: product of the radix orders of the previous stages
 * @param[in]     Ni                                   Nx * Ny; distance (in elements) between the starts of consecutive butterfly groups
 * @param[in]     exp_const                            Exponent constant; the twiddle angle at span position ny is ny * exp_const
 */
kernel void fft_radix_2_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // Butterfly handled by this work-item and its position inside the span
    const uint butterfly = get_global_id(1);
    const uint span_pos  = butterfly % Nx;
    // Y index of the first leg: start of the butterfly group plus the span position
    const uint base      = (butterfly / Nx) * Ni + span_pos;

    // Pointers are positioned manually from the global IDs (no per-work-item step)
    Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    input.ptr += get_global_id(0) * input.stride_x + base * input.stride_y + get_global_id(2) * input.stride_z;
#ifdef IN_PLACE
    Tensor3D output = input;
#else  /* IN_PLACE */
    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    output.ptr += get_global_id(0) * output.stride_x + base * output.stride_y + get_global_id(2) * output.stride_z;
#endif /* IN_PLACE */

    // The two legs are Nx elements apart along Y
    float2 elem0 = vload2(0, (__global float *)input.ptr);
    float2 elem1 = vload2(0, (__global float *)tensor3D_offset(&input, 0, Nx, 0));

    // Rotate the second leg by exp(i * angle)
    const float angle = (float)span_pos * exp_const;
    TWIDDLE_FACTOR_MULTIPLICATION(angle, elem1);

    // 2-point DFT (sum and difference), in place
    DFT_2(elem0, elem1);

    vstore2(elem0, 0, (__global float *)output.ptr);
    vstore2(elem1, 0, (__global float *)tensor3D_offset(&output, 0, Nx, 0));
}
970
/** Computes one stage of a radix-3 FFT along axis 0 (X).
 *
 * Work-item k (global ID 0) owns one butterfly. With nx = k % Nx its position inside the
 * span, the three legs sit at X indices n, n + Nx and n + 2 * Nx, where
 * n = nx + (k / Nx) * Ni. Leg j (j = 1, 2) is multiplied by exp(i * j * nx * exp_const)
 * before the 3-point DFT. Global IDs 1 and 2 select the Y and Z coordinates.
 *
 * @note Compile with -DIN_PLACE to transform in place: the output tensor arguments are then
 *       omitted from the signature and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 * @param[in]     Nx                                   The butterfly span: product of the radix orders of the previous stages
 * @param[in]     Ni                                   Nx * Ny; distance (in elements) between the starts of consecutive butterfly groups
 * @param[in]     exp_const                            Exponent constant; the base twiddle angle at span position nx is nx * exp_const
 */
kernel void fft_radix_3_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // Butterfly handled by this work-item and its position inside the span
    const uint butterfly = get_global_id(0);
    const uint span_pos  = butterfly % Nx;
    // X index of the first leg: start of the butterfly group plus the span position
    const uint base      = (butterfly / Nx) * Ni + span_pos;

    // Pointers are positioned manually from the global IDs (no per-work-item step)
    Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    input.ptr += base * input.stride_x + get_global_id(1) * input.stride_y + get_global_id(2) * input.stride_z;
#ifdef IN_PLACE
    Tensor3D output = input;
#else  /* IN_PLACE */
    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    output.ptr += base * output.stride_x + get_global_id(1) * output.stride_y + get_global_id(2) * output.stride_z;
#endif /* IN_PLACE */

    // The three legs are Nx elements apart along X
    float2 elem0 = vload2(0, (__global float *)input.ptr);
    float2 elem1 = vload2(0, (__global float *)tensor3D_offset(&input, Nx, 0, 0));
    float2 elem2 = vload2(0, (__global float *)tensor3D_offset(&input, 2 * Nx, 0, 0));

    // Leg j is rotated by exp(i * j * angle)
    const float angle = (float)span_pos * exp_const;
    TWIDDLE_FACTOR_MULTIPLICATION(angle, elem1);
    TWIDDLE_FACTOR_MULTIPLICATION(2 * angle, elem2);

    // 3-point DFT, in place
    DFT_3(elem0, elem1, elem2);

    vstore2(elem0, 0, (__global float *)output.ptr);
    vstore2(elem1, 0, (__global float *)tensor3D_offset(&output, Nx, 0, 0));
    vstore2(elem2, 0, (__global float *)tensor3D_offset(&output, 2 * Nx, 0, 0));
}
1043
1044/** Computes a stage of a radix-3 FFT on axis 1.
1045 *
1046 * @note In order to perform the FFT function "in-place", the pre-processor -DIN_PLACE must be passed at compile time
1047 *
1048 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F32
1049 * @param[in,out] input_stride_x                       Stride of the source tensor in X dimension (in bytes)
1050 * @param[in,out] input_step_x                         input_stride_x * number of elements along X processed per workitem(in bytes)
1051 * @param[in,out] input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
1052 * @param[in,out] input_step_y                         input_stride_y * number of elements along Y processed per workitem(in bytes)
1053 * @param[in,out] input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
1054 * @param[in,out] input_step_z                         input_stride_z * number of elements along Z processed per workitem(in bytes)
1055 * @param[in,out] input_offset_first_element_in_bytes  The offset of the first element in the source tensor
1056 * @param[out]    output_ptr                           (Optional) Pointer to the destination image. Supported data types: same as @p input_ptr
1057 * @param[in]     output_stride_x                      (Optional) Stride of the destination image in X dimension (in bytes)
1058 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per workitem(in bytes)
1059 * @param[in]     output_stride_y                      (Optional) Stride of the destination image in Y dimension (in bytes)
1060 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per workitem(in bytes)
1061 * @param[in]     output_stride_z                      (Optional) Stride of the source tensor in Z dimension (in bytes)
1062 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per workitem(in bytes)
1063 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination image
1064 * @param[in]     Nx                                   The butterfly span. Products of radix order of previous radix's stage
1065 * @param[in]     Ni                                   Nx * Ny.
1066 * @param[in]     exp_const                            Exponent constant
1067 */
kernel void fft_radix_3_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // Butterflies are enumerated along Y: global id 1 selects the butterfly handled here.
    const uint butterfly = get_global_id(1);

    // k: position of the butterfly within its span, group: which span group it belongs to.
    const uint k     = butterfly % Nx;
    const uint group = butterfly / Nx;

    // Y index of leg 0; legs 1 and 2 follow at multiples of Nx.
    const uint first = group * Ni + k;

    // Base twiddle angle; leg j is rotated by j * angle.
    const float angle = (float)k * exp_const;

    // Place the source pointer on leg 0 of this butterfly (X and Z come straight from the work-item id).
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    src.ptr += get_global_id(0) * src.stride_x + get_global_id(2) * src.stride_z;
    src.ptr += first * src.stride_y;
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    dst.ptr += get_global_id(0) * dst.stride_x + get_global_id(2) * dst.stride_z;
    dst.ptr += first * dst.stride_y;
#endif /* IN_PLACE */

    // Load the three legs, twiddling legs 1 and 2 straight after each load.
    float2 c0 = vload2(0, (__global float *)src.ptr);
    float2 c1 = vload2(0, (__global float *)tensor3D_offset(&src, 0, Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(angle, c1);
    float2 c2 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 2 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(2 * angle, c2);

    // 3-point DFT on the twiddled legs.
    DFT_3(c0, c1, c2);

    // Store the results at the leg positions they were read from.
    vstore2(c0, 0, (__global float *)dst.ptr);
    vstore2(c1, 0, (__global float *)tensor3D_offset(&dst, 0, Nx, 0));
    vstore2(c2, 0, (__global float *)tensor3D_offset(&dst, 0, 2 * Nx, 0));
}
1116
/** Computes one stage of a radix-4 FFT along axis 0 (the X dimension).
 *
 * Every work-item owns a single radix-4 butterfly. It loads four complex F32 values that lie Nx
 * elements apart along X, rotates legs 1, 2 and 3 by their twiddle factors, runs a 4-point DFT on
 * them and stores the four results at the same positions of the destination.
 *
 * @note Define IN_PLACE at compile time (-DIN_PLACE) to run the FFT "in-place": the output tensor
 *       arguments are then dropped from the signature and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 * @param[in]     Nx                                   Butterfly span: the product of the radix orders of the previous stages.
 *                                                     It is also the element distance between two legs of one butterfly.
 * @param[in]     Ni                                   Nx * Ny. Element distance between the first legs of two consecutive groups of Nx butterflies.
 * @param[in]     exp_const                            Exponent constant. Leg j of the butterfly at position k = (butterfly index % Nx)
 *                                                     is rotated by an angle of j * k * exp_const radians.
 */
kernel void fft_radix_4_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // Butterflies are enumerated along X: global id 0 selects the butterfly handled here.
    const uint butterfly = get_global_id(0);

    // k: position of the butterfly within its span, group: which span group it belongs to.
    const uint k     = butterfly % Nx;
    const uint group = butterfly / Nx;

    // X index of leg 0; legs 1..3 follow at multiples of Nx.
    const uint first = group * Ni + k;

    // Base twiddle angle; leg j is rotated by j * angle.
    const float angle = (float)k * exp_const;

    // Place the source pointer on leg 0 of this butterfly (Y and Z come straight from the work-item id).
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    src.ptr += get_global_id(1) * src.stride_y + get_global_id(2) * src.stride_z;
    src.ptr += first * src.stride_x;
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    dst.ptr += get_global_id(1) * dst.stride_y + get_global_id(2) * dst.stride_z;
    dst.ptr += first * dst.stride_x;
#endif /* IN_PLACE */

    // Load the four legs, twiddling legs 1..3 straight after each load.
    float2 c0 = vload2(0, (__global float *)src.ptr);
    float2 c1 = vload2(0, (__global float *)tensor3D_offset(&src, Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(angle, c1);
    float2 c2 = vload2(0, (__global float *)tensor3D_offset(&src, 2 * Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(2 * angle, c2);
    float2 c3 = vload2(0, (__global float *)tensor3D_offset(&src, 3 * Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(3 * angle, c3);

    // 4-point DFT on the twiddled legs.
    DFT_4(c0, c1, c2, c3);

    // Store the results at the leg positions they were read from.
    vstore2(c0, 0, (__global float *)dst.ptr);
    vstore2(c1, 0, (__global float *)tensor3D_offset(&dst, Nx, 0, 0));
    vstore2(c2, 0, (__global float *)tensor3D_offset(&dst, 2 * Nx, 0, 0));
    vstore2(c3, 0, (__global float *)tensor3D_offset(&dst, 3 * Nx, 0, 0));
}
1192
/** Computes one stage of a radix-4 FFT along axis 1 (the Y dimension).
 *
 * Every work-item owns a single radix-4 butterfly. It loads four complex F32 values that lie Nx
 * elements apart along Y, rotates legs 1, 2 and 3 by their twiddle factors, runs a 4-point DFT on
 * them and stores the four results at the same positions of the destination.
 *
 * @note Define IN_PLACE at compile time (-DIN_PLACE) to run the FFT "in-place": the output tensor
 *       arguments are then dropped from the signature and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 * @param[in]     Nx                                   Butterfly span: the product of the radix orders of the previous stages.
 *                                                     It is also the element distance between two legs of one butterfly.
 * @param[in]     Ni                                   Nx * Ny. Element distance between the first legs of two consecutive groups of Nx butterflies.
 * @param[in]     exp_const                            Exponent constant. Leg j of the butterfly at position k = (butterfly index % Nx)
 *                                                     is rotated by an angle of j * k * exp_const radians.
 */
kernel void fft_radix_4_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // Butterflies are enumerated along Y: global id 1 selects the butterfly handled here.
    const uint butterfly = get_global_id(1);

    // k: position of the butterfly within its span, group: which span group it belongs to.
    const uint k     = butterfly % Nx;
    const uint group = butterfly / Nx;

    // Y index of leg 0; legs 1..3 follow at multiples of Nx.
    const uint first = group * Ni + k;

    // Base twiddle angle; leg j is rotated by j * angle.
    const float angle = (float)k * exp_const;

    // Place the source pointer on leg 0 of this butterfly (X and Z come straight from the work-item id).
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    src.ptr += get_global_id(0) * src.stride_x + get_global_id(2) * src.stride_z;
    src.ptr += first * src.stride_y;
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    dst.ptr += get_global_id(0) * dst.stride_x + get_global_id(2) * dst.stride_z;
    dst.ptr += first * dst.stride_y;
#endif /* IN_PLACE */

    // Load the four legs, twiddling legs 1..3 straight after each load.
    float2 c0 = vload2(0, (__global float *)src.ptr);
    float2 c1 = vload2(0, (__global float *)tensor3D_offset(&src, 0, Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(angle, c1);
    float2 c2 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 2 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(2 * angle, c2);
    float2 c3 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 3 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(3 * angle, c3);

    // 4-point DFT on the twiddled legs.
    DFT_4(c0, c1, c2, c3);

    // Store the results at the leg positions they were read from.
    vstore2(c0, 0, (__global float *)dst.ptr);
    vstore2(c1, 0, (__global float *)tensor3D_offset(&dst, 0, Nx, 0));
    vstore2(c2, 0, (__global float *)tensor3D_offset(&dst, 0, 2 * Nx, 0));
    vstore2(c3, 0, (__global float *)tensor3D_offset(&dst, 0, 3 * Nx, 0));
}
1268
/** Computes one stage of a radix-5 FFT along axis 0 (the X dimension).
 *
 * Every work-item owns a single radix-5 butterfly. It loads five complex F32 values that lie Nx
 * elements apart along X, rotates legs 1 to 4 by their twiddle factors, runs a 5-point DFT on
 * them and stores the five results at the same positions of the destination.
 *
 * @note Define IN_PLACE at compile time (-DIN_PLACE) to run the FFT "in-place": the output tensor
 *       arguments are then dropped from the signature and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 * @param[in]     Nx                                   Butterfly span: the product of the radix orders of the previous stages.
 *                                                     It is also the element distance between two legs of one butterfly.
 * @param[in]     Ni                                   Nx * Ny. Element distance between the first legs of two consecutive groups of Nx butterflies.
 * @param[in]     exp_const                            Exponent constant. Leg j of the butterfly at position k = (butterfly index % Nx)
 *                                                     is rotated by an angle of j * k * exp_const radians.
 */
kernel void fft_radix_5_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // Butterflies are enumerated along X: global id 0 selects the butterfly handled here.
    const uint butterfly = get_global_id(0);

    // k: position of the butterfly within its span, group: which span group it belongs to.
    const uint k     = butterfly % Nx;
    const uint group = butterfly / Nx;

    // X index of leg 0; legs 1..4 follow at multiples of Nx.
    const uint first = group * Ni + k;

    // Base twiddle angle; leg j is rotated by j * angle.
    const float angle = (float)k * exp_const;

    // Place the source pointer on leg 0 of this butterfly (Y and Z come straight from the work-item id).
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    src.ptr += get_global_id(1) * src.stride_y + get_global_id(2) * src.stride_z;
    src.ptr += first * src.stride_x;
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    dst.ptr += get_global_id(1) * dst.stride_y + get_global_id(2) * dst.stride_z;
    dst.ptr += first * dst.stride_x;
#endif /* IN_PLACE */

    // Load the five legs, twiddling legs 1..4 straight after each load.
    float2 c0 = vload2(0, (__global float *)src.ptr);
    float2 c1 = vload2(0, (__global float *)tensor3D_offset(&src, Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(angle, c1);
    float2 c2 = vload2(0, (__global float *)tensor3D_offset(&src, 2 * Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(2 * angle, c2);
    float2 c3 = vload2(0, (__global float *)tensor3D_offset(&src, 3 * Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(3 * angle, c3);
    float2 c4 = vload2(0, (__global float *)tensor3D_offset(&src, 4 * Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(4 * angle, c4);

    // 5-point DFT on the twiddled legs.
    DFT_5(c0, c1, c2, c3, c4);

    // Store the results at the leg positions they were read from.
    vstore2(c0, 0, (__global float *)dst.ptr);
    vstore2(c1, 0, (__global float *)tensor3D_offset(&dst, Nx, 0, 0));
    vstore2(c2, 0, (__global float *)tensor3D_offset(&dst, 2 * Nx, 0, 0));
    vstore2(c3, 0, (__global float *)tensor3D_offset(&dst, 3 * Nx, 0, 0));
    vstore2(c4, 0, (__global float *)tensor3D_offset(&dst, 4 * Nx, 0, 0));
}
1347
/** Computes one stage of a radix-5 FFT along axis 1 (the Y dimension).
 *
 * Every work-item owns a single radix-5 butterfly. It loads five complex F32 values that lie Nx
 * elements apart along Y, rotates legs 1 to 4 by their twiddle factors, runs a 5-point DFT on
 * them and stores the five results at the same positions of the destination.
 *
 * @note Define IN_PLACE at compile time (-DIN_PLACE) to run the FFT "in-place": the output tensor
 *       arguments are then dropped from the signature and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 * @param[in]     Nx                                   Butterfly span: the product of the radix orders of the previous stages.
 *                                                     It is also the element distance between two legs of one butterfly.
 * @param[in]     Ni                                   Nx * Ny. Element distance between the first legs of two consecutive groups of Nx butterflies.
 * @param[in]     exp_const                            Exponent constant. Leg j of the butterfly at position k = (butterfly index % Nx)
 *                                                     is rotated by an angle of j * k * exp_const radians.
 */
kernel void fft_radix_5_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // Butterflies are enumerated along Y: global id 1 selects the butterfly handled here.
    const uint butterfly = get_global_id(1);

    // k: position of the butterfly within its span, group: which span group it belongs to.
    const uint k     = butterfly % Nx;
    const uint group = butterfly / Nx;

    // Y index of leg 0; legs 1..4 follow at multiples of Nx.
    const uint first = group * Ni + k;

    // Base twiddle angle; leg j is rotated by j * angle.
    const float angle = (float)k * exp_const;

    // Place the source pointer on leg 0 of this butterfly (X and Z come straight from the work-item id).
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    src.ptr += get_global_id(0) * src.stride_x + get_global_id(2) * src.stride_z;
    src.ptr += first * src.stride_y;
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    dst.ptr += get_global_id(0) * dst.stride_x + get_global_id(2) * dst.stride_z;
    dst.ptr += first * dst.stride_y;
#endif /* IN_PLACE */

    // Load the five legs, twiddling legs 1..4 straight after each load.
    float2 c0 = vload2(0, (__global float *)src.ptr);
    float2 c1 = vload2(0, (__global float *)tensor3D_offset(&src, 0, Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(angle, c1);
    float2 c2 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 2 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(2 * angle, c2);
    float2 c3 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 3 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(3 * angle, c3);
    float2 c4 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 4 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(4 * angle, c4);

    // 5-point DFT on the twiddled legs.
    DFT_5(c0, c1, c2, c3, c4);

    // Store the results at the leg positions they were read from.
    vstore2(c0, 0, (__global float *)dst.ptr);
    vstore2(c1, 0, (__global float *)tensor3D_offset(&dst, 0, Nx, 0));
    vstore2(c2, 0, (__global float *)tensor3D_offset(&dst, 0, 2 * Nx, 0));
    vstore2(c3, 0, (__global float *)tensor3D_offset(&dst, 0, 3 * Nx, 0));
    vstore2(c4, 0, (__global float *)tensor3D_offset(&dst, 0, 4 * Nx, 0));
}
1426
/** Computes one stage of a radix-7 FFT along axis 0 (the X dimension).
 *
 * Every work-item owns a single radix-7 butterfly. It loads seven complex F32 values that lie Nx
 * elements apart along X, rotates legs 1 to 6 by their twiddle factors, runs a 7-point DFT on
 * them and stores the seven results at the same positions of the destination.
 *
 * @note Define IN_PLACE at compile time (-DIN_PLACE) to run the FFT "in-place": the output tensor
 *       arguments are then dropped from the signature and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 * @param[in]     Nx                                   Butterfly span: the product of the radix orders of the previous stages.
 *                                                     It is also the element distance between two legs of one butterfly.
 * @param[in]     Ni                                   Nx * Ny. Element distance between the first legs of two consecutive groups of Nx butterflies.
 * @param[in]     exp_const                            Exponent constant. Leg j of the butterfly at position k = (butterfly index % Nx)
 *                                                     is rotated by an angle of j * k * exp_const radians.
 */
kernel void fft_radix_7_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // Butterflies are enumerated along X: global id 0 selects the butterfly handled here.
    const uint butterfly = get_global_id(0);

    // k: position of the butterfly within its span, group: which span group it belongs to.
    const uint k     = butterfly % Nx;
    const uint group = butterfly / Nx;

    // X index of leg 0; legs 1..6 follow at multiples of Nx.
    const uint first = group * Ni + k;

    // Base twiddle angle; leg j is rotated by j * angle.
    const float angle = (float)k * exp_const;

    // Place the source pointer on leg 0 of this butterfly (Y and Z come straight from the work-item id).
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    src.ptr += get_global_id(1) * src.stride_y + get_global_id(2) * src.stride_z;
    src.ptr += first * src.stride_x;
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    dst.ptr += get_global_id(1) * dst.stride_y + get_global_id(2) * dst.stride_z;
    dst.ptr += first * dst.stride_x;
#endif /* IN_PLACE */

    // Load the seven legs, twiddling legs 1..6 straight after each load.
    float2 c0 = vload2(0, (__global float *)src.ptr);
    float2 c1 = vload2(0, (__global float *)tensor3D_offset(&src, Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(angle, c1);
    float2 c2 = vload2(0, (__global float *)tensor3D_offset(&src, 2 * Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(2 * angle, c2);
    float2 c3 = vload2(0, (__global float *)tensor3D_offset(&src, 3 * Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(3 * angle, c3);
    float2 c4 = vload2(0, (__global float *)tensor3D_offset(&src, 4 * Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(4 * angle, c4);
    float2 c5 = vload2(0, (__global float *)tensor3D_offset(&src, 5 * Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(5 * angle, c5);
    float2 c6 = vload2(0, (__global float *)tensor3D_offset(&src, 6 * Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(6 * angle, c6);

    // 7-point DFT on the twiddled legs.
    DFT_7(c0, c1, c2, c3, c4, c5, c6);

    // Store the results at the leg positions they were read from.
    vstore2(c0, 0, (__global float *)dst.ptr);
    vstore2(c1, 0, (__global float *)tensor3D_offset(&dst, Nx, 0, 0));
    vstore2(c2, 0, (__global float *)tensor3D_offset(&dst, 2 * Nx, 0, 0));
    vstore2(c3, 0, (__global float *)tensor3D_offset(&dst, 3 * Nx, 0, 0));
    vstore2(c4, 0, (__global float *)tensor3D_offset(&dst, 4 * Nx, 0, 0));
    vstore2(c5, 0, (__global float *)tensor3D_offset(&dst, 5 * Nx, 0, 0));
    vstore2(c6, 0, (__global float *)tensor3D_offset(&dst, 6 * Nx, 0, 0));
}
1511
/** Computes one stage of a radix-7 FFT along axis 1 (the Y dimension).
 *
 * Every work-item owns a single radix-7 butterfly. It loads seven complex F32 values that lie Nx
 * elements apart along Y, rotates legs 1 to 6 by their twiddle factors, runs a 7-point DFT on
 * them and stores the seven results at the same positions of the destination.
 *
 * @note Define IN_PLACE at compile time (-DIN_PLACE) to run the FFT "in-place": the output tensor
 *       arguments are then dropped from the signature and the results overwrite the input.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 * @param[in]     Nx                                   Butterfly span: the product of the radix orders of the previous stages.
 *                                                     It is also the element distance between two legs of one butterfly.
 * @param[in]     Ni                                   Nx * Ny. Element distance between the first legs of two consecutive groups of Nx butterflies.
 * @param[in]     exp_const                            Exponent constant. Leg j of the butterfly at position k = (butterfly index % Nx)
 *                                                     is rotated by an angle of j * k * exp_const radians.
 */
kernel void fft_radix_7_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // Butterflies are enumerated along Y: global id 1 selects the butterfly handled here.
    const uint butterfly = get_global_id(1);

    // k: position of the butterfly within its span, group: which span group it belongs to.
    const uint k     = butterfly % Nx;
    const uint group = butterfly / Nx;

    // Y index of leg 0; legs 1..6 follow at multiples of Nx.
    const uint first = group * Ni + k;

    // Base twiddle angle; leg j is rotated by j * angle.
    const float angle = (float)k * exp_const;

    // Place the source pointer on leg 0 of this butterfly (X and Z come straight from the work-item id).
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    src.ptr += get_global_id(0) * src.stride_x + get_global_id(2) * src.stride_z;
    src.ptr += first * src.stride_y;
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    dst.ptr += get_global_id(0) * dst.stride_x + get_global_id(2) * dst.stride_z;
    dst.ptr += first * dst.stride_y;
#endif /* IN_PLACE */

    // Load the seven legs, twiddling legs 1..6 straight after each load.
    float2 c0 = vload2(0, (__global float *)src.ptr);
    float2 c1 = vload2(0, (__global float *)tensor3D_offset(&src, 0, Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(angle, c1);
    float2 c2 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 2 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(2 * angle, c2);
    float2 c3 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 3 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(3 * angle, c3);
    float2 c4 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 4 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(4 * angle, c4);
    float2 c5 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 5 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(5 * angle, c5);
    float2 c6 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 6 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(6 * angle, c6);

    // 7-point DFT on the twiddled legs.
    DFT_7(c0, c1, c2, c3, c4, c5, c6);

    // Store the results at the leg positions they were read from.
    vstore2(c0, 0, (__global float *)dst.ptr);
    vstore2(c1, 0, (__global float *)tensor3D_offset(&dst, 0, Nx, 0));
    vstore2(c2, 0, (__global float *)tensor3D_offset(&dst, 0, 2 * Nx, 0));
    vstore2(c3, 0, (__global float *)tensor3D_offset(&dst, 0, 3 * Nx, 0));
    vstore2(c4, 0, (__global float *)tensor3D_offset(&dst, 0, 4 * Nx, 0));
    vstore2(c5, 0, (__global float *)tensor3D_offset(&dst, 0, 5 * Nx, 0));
    vstore2(c6, 0, (__global float *)tensor3D_offset(&dst, 0, 6 * Nx, 0));
}
1596
/** Runs one radix-8 stage of an FFT along axis 0 (X) of a 3D tensor of interleaved complex F32 values.
 *
 * Every work-item evaluates a single radix-8 butterfly:
 *  - it gathers eight complex values (real in .x, imaginary in .y) spaced Nx elements apart along X;
 *  - it rotates input k (k = 1..7) by the twiddle factor (cos(k * phi), sin(k * phi)), with
 *    phi = nx * exp_const and nx the butterfly's position inside its group of Nx butterflies;
 *  - it applies DFT_8 to the eight values and stores them back to the same eight positions of the destination.
 *
 * @note Pass -DIN_PLACE at compile time to run the transform in place. The output arguments are then
 *       omitted from the signature and the results overwrite the input tensor.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor (also the destination when IN_PLACE is defined). Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional, absent with IN_PLACE) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 * @param[in]     Nx                                   The butterfly span: distance, in elements along X, between consecutive inputs of one butterfly
 *                                                     (the product of the radix orders of the previous stages)
 * @param[in]     Ni                                   Nx * Ny: distance, in elements along X, between the first inputs of consecutive groups of Nx butterflies
 * @param[in]     exp_const                            Exponent constant: twiddle angle per unit of nx (phi = nx * exp_const)
 */
kernel void fft_radix_8_axis_0(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // One radix-8 butterfly per work-item, enumerated along dimension 0
    const uint butterfly = get_global_id(0);

    // Split the butterfly id into its slot inside a group of Nx butterflies and the group number
    const uint nx        = butterfly % Nx;
    const uint group_idx = butterfly / Nx;

    // X index of the first of the eight inputs; groups start Ni elements apart
    const uint n = nx + group_idx * Ni;

    // Point both tensors at the first input of this butterfly (Y/Z taken from the work-item id)
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    src.ptr += n * src.stride_x + get_global_id(1) * src.stride_y + get_global_id(2) * src.stride_z;
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    dst.ptr += n * dst.stride_x + get_global_id(1) * dst.stride_y + get_global_id(2) * dst.stride_z;
#endif /* IN_PLACE */

    // Base twiddle angle: input k is rotated by k * phi
    const float phi = (float)nx * exp_const;

    // Input 0 keeps a unit twiddle, so it is used as loaded
    float2 c0 = vload2(0, (__global float *)src.ptr);

    // Inputs 1..7 lie at multiples of Nx along X; each is rotated as soon as it is loaded
    float2 c1 = vload2(0, (__global float *)tensor3D_offset(&src, Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(phi, c1);
    float2 c2 = vload2(0, (__global float *)tensor3D_offset(&src, 2 * Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(2 * phi, c2);
    float2 c3 = vload2(0, (__global float *)tensor3D_offset(&src, 3 * Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(3 * phi, c3);
    float2 c4 = vload2(0, (__global float *)tensor3D_offset(&src, 4 * Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(4 * phi, c4);
    float2 c5 = vload2(0, (__global float *)tensor3D_offset(&src, 5 * Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(5 * phi, c5);
    float2 c6 = vload2(0, (__global float *)tensor3D_offset(&src, 6 * Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(6 * phi, c6);
    float2 c7 = vload2(0, (__global float *)tensor3D_offset(&src, 7 * Nx, 0, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(7 * phi, c7);

    // 8-point DFT on the rotated values (DFT_8 is defined elsewhere in this file)
    DFT_8(c0, c1, c2, c3, c4, c5, c6, c7);

    // Write the results back to the same eight X positions of the destination
    vstore2(c0, 0, (__global float *)dst.ptr);
    vstore2(c1, 0, (__global float *)tensor3D_offset(&dst, Nx, 0, 0));
    vstore2(c2, 0, (__global float *)tensor3D_offset(&dst, 2 * Nx, 0, 0));
    vstore2(c3, 0, (__global float *)tensor3D_offset(&dst, 3 * Nx, 0, 0));
    vstore2(c4, 0, (__global float *)tensor3D_offset(&dst, 4 * Nx, 0, 0));
    vstore2(c5, 0, (__global float *)tensor3D_offset(&dst, 5 * Nx, 0, 0));
    vstore2(c6, 0, (__global float *)tensor3D_offset(&dst, 6 * Nx, 0, 0));
    vstore2(c7, 0, (__global float *)tensor3D_offset(&dst, 7 * Nx, 0, 0));
}
1684
/** Runs one radix-8 stage of an FFT along axis 1 (Y) of a 3D tensor of interleaved complex F32 values.
 *
 * Every work-item evaluates a single radix-8 butterfly:
 *  - it gathers eight complex values (real in .x, imaginary in .y) spaced Nx rows apart along Y;
 *  - it rotates input k (k = 1..7) by the twiddle factor (cos(k * phi), sin(k * phi)), with
 *    phi = nx * exp_const and nx the butterfly's position inside its group of Nx butterflies;
 *  - it applies DFT_8 to the eight values and stores them back to the same eight positions of the destination.
 *
 * @note Pass -DIN_PLACE at compile time to run the transform in place. The output arguments are then
 *       omitted from the signature and the results overwrite the input tensor.
 *
 * @param[in,out] input_ptr                            Pointer to the source tensor (also the destination when IN_PLACE is defined). Supported data types: F32
 * @param[in]     input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]     input_step_x                         input_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]     input_step_y                         input_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     input_step_z                         input_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]     input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out]    output_ptr                           (Optional, absent with IN_PLACE) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]     output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]     output_step_x                        (Optional) output_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]     output_step_y                        (Optional) output_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     output_step_z                        (Optional) output_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]     output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 * @param[in]     Nx                                   The butterfly span: distance, in rows along Y, between consecutive inputs of one butterfly
 *                                                     (the product of the radix orders of the previous stages)
 * @param[in]     Ni                                   Nx * Ny: distance, in rows along Y, between the first inputs of consecutive groups of Nx butterflies
 * @param[in]     exp_const                            Exponent constant: twiddle angle per unit of nx (phi = nx * exp_const)
 */
kernel void fft_radix_8_axis_1(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
    ,
    uint Nx, uint Ni, float exp_const)
{
    // One radix-8 butterfly per work-item, enumerated along dimension 1
    const uint butterfly = get_global_id(1);

    // Split the butterfly id into its slot inside a group of Nx butterflies and the group number
    const uint nx        = butterfly % Nx;
    const uint group_idx = butterfly / Nx;

    // Y index of the first of the eight inputs; groups start Ni rows apart
    const uint n = nx + group_idx * Ni;

    // Point both tensors at the first input of this butterfly (X/Z taken from the work-item id)
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(input);
    src.ptr += get_global_id(0) * src.stride_x + n * src.stride_y + get_global_id(2) * src.stride_z;
#ifdef IN_PLACE
    Tensor3D dst = src;
#else  /* IN_PLACE */
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(output);
    dst.ptr += get_global_id(0) * dst.stride_x + n * dst.stride_y + get_global_id(2) * dst.stride_z;
#endif /* IN_PLACE */

    // Base twiddle angle: input k is rotated by k * phi
    const float phi = (float)nx * exp_const;

    // Input 0 keeps a unit twiddle, so it is used as loaded
    float2 c0 = vload2(0, (__global float *)src.ptr);

    // Inputs 1..7 lie at multiples of Nx along Y; each is rotated as soon as it is loaded
    float2 c1 = vload2(0, (__global float *)tensor3D_offset(&src, 0, Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(phi, c1);
    float2 c2 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 2 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(2 * phi, c2);
    float2 c3 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 3 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(3 * phi, c3);
    float2 c4 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 4 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(4 * phi, c4);
    float2 c5 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 5 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(5 * phi, c5);
    float2 c6 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 6 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(6 * phi, c6);
    float2 c7 = vload2(0, (__global float *)tensor3D_offset(&src, 0, 7 * Nx, 0));
    TWIDDLE_FACTOR_MULTIPLICATION(7 * phi, c7);

    // 8-point DFT on the rotated values (DFT_8 is defined elsewhere in this file)
    DFT_8(c0, c1, c2, c3, c4, c5, c6, c7);

    // Write the results back to the same eight Y positions of the destination
    vstore2(c0, 0, (__global float *)dst.ptr);
    vstore2(c1, 0, (__global float *)tensor3D_offset(&dst, 0, Nx, 0));
    vstore2(c2, 0, (__global float *)tensor3D_offset(&dst, 0, 2 * Nx, 0));
    vstore2(c3, 0, (__global float *)tensor3D_offset(&dst, 0, 3 * Nx, 0));
    vstore2(c4, 0, (__global float *)tensor3D_offset(&dst, 0, 4 * Nx, 0));
    vstore2(c5, 0, (__global float *)tensor3D_offset(&dst, 0, 5 * Nx, 0));
    vstore2(c6, 0, (__global float *)tensor3D_offset(&dst, 0, 6 * Nx, 0));
    vstore2(c7, 0, (__global float *)tensor3D_offset(&dst, 0, 7 * Nx, 0));
}