1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
16 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
17
18 #include <array>
19
20 #include "third_party/eigen3/Eigen/Core"
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/compiler/xla/types.h"
23
24 namespace xla {
25
26 namespace internal {
27
// Identifies which of the four supported transforms should be run. The
// numeric values are assigned explicitly and the storage type is int32_t.
enum class FftType : int32_t {
  // Forward FFT; complex in, complex out.
  FFT = 0,
  // Inverse FFT; complex in, complex out.
  IFFT = 1,
  // Forward real FFT; real in, fft_length / 2 + 1 complex out.
  RFFT = 2,
  // Inverse real FFT; fft_length / 2 + 1 complex in, fft_length real out.
  IRFFT = 3,
};
// Returns 4, which equals the number of FftType enumerators (FFT, IFFT, RFFT,
// IRFFT). Presumably meant for sizing arrays indexed by FftType -- confirm
// against callers.
inline constexpr int FftTypeArraySize() {
  return 4;
}
36
37 // Computes either a forward or reverse complex-to-complex FFT.
38 template <bool Forward, int FFTRank, typename EigenDevice, typename Complex>
EigenFftC2C(const EigenDevice & device,Complex * out,Complex * operand,int64_t input_batch,int64_t fft_length0,int64_t fft_length1,int64_t fft_length2)39 void EigenFftC2C(const EigenDevice& device, Complex* out, Complex* operand,
40 int64_t input_batch, int64_t fft_length0, int64_t fft_length1,
41 int64_t fft_length2) {
42 // Create the axes (which are always trailing).
43 const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
44 constexpr auto direction = Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
45
46 const std::array<int64_t, 3> fft_shape = {
47 {fft_length0, fft_length1, fft_length2}};
48
49 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> dims;
50 dims[0] = input_batch;
51 for (int i = 0; i < FFTRank; i++) {
52 dims[i + 1] = fft_shape[i];
53 }
54 const Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
55 Eigen::Aligned>
56 input(operand, dims);
57 Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
58 Eigen::Aligned>
59 output(out, dims);
60 output.device(device) = input.template fft<Eigen::BothParts, direction>(axes);
61 }
62
63 // Computes a forward real->complex FFT, slicing out redundant negative
64 // frequencies from the innermost dimension.
65 template <int FFTRank, typename EigenDevice, typename Real, typename Complex>
EigenFftR2C(const EigenDevice & device,Complex * out,Real * operand,int64_t input_batch,int64_t fft_length0,int64_t fft_length1,int64_t fft_length2)66 void EigenFftR2C(const EigenDevice& device, Complex* out, Real* operand,
67 int64_t input_batch, int64_t fft_length0, int64_t fft_length1,
68 int64_t fft_length2) {
69 const std::array<int64_t, 3> fft_shape = {
70 {fft_length0, fft_length1, fft_length2}};
71
72 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims;
73 in_dims[0] = input_batch;
74 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
75 out_dims[0] = input_batch;
76 for (int i = 0; i < FFTRank; i++) {
77 in_dims[i + 1] = fft_shape[i];
78 out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
79 }
80 const Eigen::TensorMap<Eigen::Tensor<Real, FFTRank + 1, Eigen::RowMajor>,
81 Eigen::Aligned>
82 input(operand, in_dims);
83 Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
84 Eigen::Aligned>
85 output(out, out_dims);
86
87 // Create the axes (which are always trailing).
88 const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
89
90 // Compute the full FFT using a temporary tensor.
91 Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor> full_fft(in_dims);
92
93 const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
94 full_fft.device(device) =
95 input.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
96
97 // Slice away the negative frequency components.
98 output.device(device) = full_fft.slice(zero_start_indices, out_dims);
99 }
100
101 // Computes a reverse complex->real FFT, reconstructing redundant negative
102 // frequencies using reverse conjugate on innermost dimension after doing IFFT
103 // on outer dimensions.
104 template <int FFTRank, typename EigenDevice, typename Complex, typename Real>
EigenFftC2R(const EigenDevice & device,Real * out,Complex * operand,int64_t input_batch,int64_t fft_length0,int64_t fft_length1,int64_t fft_length2)105 void EigenFftC2R(const EigenDevice& device, Real* out, Complex* operand,
106 int64_t input_batch, int64_t fft_length0, int64_t fft_length1,
107 int64_t fft_length2) {
108 const std::array<int64_t, 3> fft_shape = {
109 {fft_length0, fft_length1, fft_length2}};
110
111 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims;
112 in_dims[0] = input_batch;
113 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
114 out_dims[0] = input_batch;
115 for (int i = 0; i < FFTRank; i++) {
116 in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
117 out_dims[i + 1] = fft_shape[i];
118 }
119 const Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
120 Eigen::Aligned>
121 input(operand, in_dims);
122 Eigen::TensorMap<Eigen::Tensor<Real, FFTRank + 1, Eigen::RowMajor>,
123 Eigen::Aligned>
124 output(out, out_dims);
125
126 // Calculate the shape of the temporary tensor for the full FFT and the
127 // region we will slice from input given fft_shape. We slice input to
128 // fft_shape on its inner-most dimensions, except the last (which we
129 // slice to fft_shape[-1] / 2 + 1).
130 Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor> full_fft(out_dims);
131
132 // Calculate the starting point and range of the source of
133 // negative frequency part.
134 auto neg_sizes = in_dims;
135 neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - in_dims[FFTRank];
136 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
137 neg_target_indices[FFTRank] = in_dims[FFTRank];
138
139 const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
140 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
141 neg_start_indices[FFTRank] = 1;
142
143 full_fft.slice(zero_start_indices, in_dims).device(device) = input;
144
145 // First, conduct IFFTs on outer dimensions. We save computation (and
146 // avoid touching uninitialized memory) by slicing full_fft to the
147 // subregion we wrote input to.
148 if (FFTRank > 1) {
149 const auto outer_axes =
150 Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
151 full_fft.slice(zero_start_indices, in_dims).device(device) =
152 full_fft.slice(zero_start_indices, in_dims)
153 .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
154 }
155
156 // Reconstruct the full FFT by appending reversed and conjugated
157 // spectrum as the negative frequency part.
158 Eigen::array<bool, FFTRank + 1> reverse_last_axis;
159 for (auto i = 0; i <= FFTRank; i++) {
160 reverse_last_axis[i] = i == FFTRank;
161 }
162
163 if (neg_sizes[FFTRank] != 0) {
164 full_fft.slice(neg_target_indices, neg_sizes).device(device) =
165 full_fft.slice(neg_start_indices, neg_sizes)
166 .reverse(reverse_last_axis)
167 .conjugate();
168 }
169
170 auto inner_axis = Eigen::array<int, 1>{FFTRank};
171 output.device(device) =
172 full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
173 }
174
175 template <int FFTRank, typename EigenDevice>
EigenFftWithRank(const EigenDevice & device,void * out,void * operand,FftType fft_type,bool double_precision,int64_t input_batch,int64_t fft_length0,int64_t fft_length1,int64_t fft_length2)176 void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
177 FftType fft_type, bool double_precision,
178 int64_t input_batch, int64_t fft_length0,
179 int64_t fft_length1, int64_t fft_length2) {
180 switch (fft_type) {
181 case FftType::FFT:
182 if (double_precision) {
183 EigenFftC2C<true, FFTRank, EigenDevice, complex128>(
184 device, static_cast<complex128*>(out),
185 static_cast<complex128*>(operand), input_batch, fft_length0,
186 fft_length1, fft_length2);
187 } else {
188 EigenFftC2C<true, FFTRank, EigenDevice, complex64>(
189 device, static_cast<complex64*>(out),
190 static_cast<complex64*>(operand), input_batch, fft_length0,
191 fft_length1, fft_length2);
192 }
193 break;
194 case FftType::IFFT:
195 if (double_precision) {
196 EigenFftC2C<false, FFTRank, EigenDevice, complex128>(
197 device, static_cast<complex128*>(out),
198 static_cast<complex128*>(operand), input_batch, fft_length0,
199 fft_length1, fft_length2);
200 } else {
201 EigenFftC2C<false, FFTRank, EigenDevice, complex64>(
202 device, static_cast<complex64*>(out),
203 static_cast<complex64*>(operand), input_batch, fft_length0,
204 fft_length1, fft_length2);
205 }
206 break;
207 case FftType::RFFT:
208 if (double_precision) {
209 EigenFftR2C<FFTRank, EigenDevice, double, complex128>(
210 device, static_cast<complex128*>(out),
211 static_cast<double*>(operand), input_batch, fft_length0,
212 fft_length1, fft_length2);
213 } else {
214 EigenFftR2C<FFTRank, EigenDevice, float, complex64>(
215 device, static_cast<complex64*>(out), static_cast<float*>(operand),
216 input_batch, fft_length0, fft_length1, fft_length2);
217 }
218 break;
219 case FftType::IRFFT:
220 if (double_precision) {
221 EigenFftC2R<FFTRank, EigenDevice, complex128, double>(
222 device, static_cast<double*>(out),
223 static_cast<complex128*>(operand), input_batch, fft_length0,
224 fft_length1, fft_length2);
225 } else {
226 EigenFftC2R<FFTRank, EigenDevice, complex64, float>(
227 device, static_cast<float*>(out), static_cast<complex64*>(operand),
228 input_batch, fft_length0, fft_length1, fft_length2);
229 }
230 break;
231 default:
232 // Unsupported FFT type
233 abort();
234 }
235 }
236
237 } // namespace internal
238
239 template <typename EigenDevice>
EigenFftImpl(const EigenDevice & device,void * out,void * operand,internal::FftType fft_type,bool double_precision,int32_t fft_rank,int64_t input_batch,int64_t fft_length0,int64_t fft_length1,int64_t fft_length2)240 void EigenFftImpl(const EigenDevice& device, void* out, void* operand,
241 internal::FftType fft_type, bool double_precision,
242 int32_t fft_rank, int64_t input_batch, int64_t fft_length0,
243 int64_t fft_length1, int64_t fft_length2) {
244 switch (fft_rank) {
245 case 1:
246 internal::EigenFftWithRank<1, EigenDevice>(device, out, operand, fft_type,
247 double_precision, input_batch,
248 fft_length0, 0, 0);
249 break;
250 case 2:
251 internal::EigenFftWithRank<2, EigenDevice>(device, out, operand, fft_type,
252 double_precision, input_batch,
253 fft_length0, fft_length1, 0);
254 break;
255 case 3:
256 internal::EigenFftWithRank<3, EigenDevice>(
257 device, out, operand, fft_type, double_precision, input_batch,
258 fft_length0, fft_length1, fft_length2);
259 break;
260 default:
261 // Unsupported FFT rank
262 abort();
263 }
264 }
265
266 } // namespace xla
267
268 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
269