1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
16 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
17
#include <array>
#include <cstdlib>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/platform/types.h"
24
25 // 'tensorflow' namespace is used so that int64 and other types don't require
26 // qualification.
27 namespace tensorflow {
28 namespace xla {
29
30 enum class FftType : int32 {
31 FFT = 0, // Forward FFT; complex in, complex out.
32 IFFT = 1, // Inverse FFT; complex in, complex out.
33 RFFT = 2, // Forward real FFT; real in, fft_length / 2 + 1 complex out
34 IRFFT = 3, // Inverse real FFT; fft_length / 2 + 1 complex in,
35 // fft_length real out
36 };
37 static constexpr int kFftTypeArraySize = 4;
38
39 namespace internal {
40
41 // Computes either a forward or reverse complex-to-complex FFT.
42 template <bool Forward, int FFTRank, typename EigenDevice, typename Complex>
EigenFftC2C(const EigenDevice & device,Complex * out,Complex * operand,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)43 void EigenFftC2C(const EigenDevice& device, Complex* out, Complex* operand,
44 int64 input_batch, int64 fft_length0, int64 fft_length1,
45 int64 fft_length2) {
46 // Create the axes (which are always trailing).
47 const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
48 constexpr auto direction = Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
49
50 const std::array<int64, 3> fft_shape = {
51 {fft_length0, fft_length1, fft_length2}};
52
53 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> dims;
54 dims[0] = input_batch;
55 for (int i = 0; i < FFTRank; i++) {
56 dims[i + 1] = fft_shape[i];
57 }
58 const Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
59 Eigen::Aligned>
60 input(operand, dims);
61 Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
62 Eigen::Aligned>
63 output(out, dims);
64 output.device(device) = input.template fft<Eigen::BothParts, direction>(axes);
65 }
66
67 // Computes a forward real->complex FFT, slicing out redundant negative
68 // frequencies from the innermost dimension.
69 template <int FFTRank, typename EigenDevice, typename Real, typename Complex>
EigenFftR2C(const EigenDevice & device,Complex * out,Real * operand,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)70 void EigenFftR2C(const EigenDevice& device, Complex* out, Real* operand,
71 int64 input_batch, int64 fft_length0, int64 fft_length1,
72 int64 fft_length2) {
73 const std::array<int64, 3> fft_shape = {
74 {fft_length0, fft_length1, fft_length2}};
75
76 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims;
77 in_dims[0] = input_batch;
78 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
79 out_dims[0] = input_batch;
80 for (int i = 0; i < FFTRank; i++) {
81 in_dims[i + 1] = fft_shape[i];
82 out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
83 }
84 const Eigen::TensorMap<Eigen::Tensor<Real, FFTRank + 1, Eigen::RowMajor>,
85 Eigen::Aligned>
86 input(operand, in_dims);
87 Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
88 Eigen::Aligned>
89 output(out, out_dims);
90
91 // Create the axes (which are always trailing).
92 const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
93
94 // Compute the full FFT using a temporary tensor.
95 Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor> full_fft(in_dims);
96
97 const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
98 full_fft.device(device) =
99 input.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
100
101 // Slice away the negative frequency components.
102 output.device(device) = full_fft.slice(zero_start_indices, out_dims);
103 }
104
105 // Computes a reverse complex->real FFT, reconstructing redundant negative
106 // frequencies using reverse conjugate on innermost dimension after doing IFFT
107 // on outer dimensions.
108 template <int FFTRank, typename EigenDevice, typename Complex, typename Real>
EigenFftC2R(const EigenDevice & device,Real * out,Complex * operand,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)109 void EigenFftC2R(const EigenDevice& device, Real* out, Complex* operand,
110 int64 input_batch, int64 fft_length0, int64 fft_length1,
111 int64 fft_length2) {
112 const std::array<int64, 3> fft_shape = {
113 {fft_length0, fft_length1, fft_length2}};
114
115 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims;
116 in_dims[0] = input_batch;
117 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
118 out_dims[0] = input_batch;
119 for (int i = 0; i < FFTRank; i++) {
120 in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
121 out_dims[i + 1] = fft_shape[i];
122 }
123 const Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
124 Eigen::Aligned>
125 input(operand, in_dims);
126 Eigen::TensorMap<Eigen::Tensor<Real, FFTRank + 1, Eigen::RowMajor>,
127 Eigen::Aligned>
128 output(out, out_dims);
129
130 // Calculate the shape of the temporary tensor for the full FFT and the
131 // region we will slice from input given fft_shape. We slice input to
132 // fft_shape on its inner-most dimensions, except the last (which we
133 // slice to fft_shape[-1] / 2 + 1).
134 Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor> full_fft(out_dims);
135
136 // Calculate the starting point and range of the source of
137 // negative frequency part.
138 auto neg_sizes = in_dims;
139 neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - in_dims[FFTRank];
140 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
141 neg_target_indices[FFTRank] = in_dims[FFTRank];
142
143 const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
144 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
145 neg_start_indices[FFTRank] = 1;
146
147 full_fft.slice(zero_start_indices, in_dims).device(device) = input;
148
149 // First, conduct IFFTs on outer dimensions. We save computation (and
150 // avoid touching uninitialized memory) by slicing full_fft to the
151 // subregion we wrote input to.
152 if (FFTRank > 1) {
153 const auto outer_axes =
154 Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
155 full_fft.slice(zero_start_indices, in_dims).device(device) =
156 full_fft.slice(zero_start_indices, in_dims)
157 .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
158 }
159
160 // Reconstruct the full FFT by appending reversed and conjugated
161 // spectrum as the negative frequency part.
162 Eigen::array<bool, FFTRank + 1> reverse_last_axis;
163 for (auto i = 0; i <= FFTRank; i++) {
164 reverse_last_axis[i] = i == FFTRank;
165 }
166
167 if (neg_sizes[FFTRank] != 0) {
168 full_fft.slice(neg_target_indices, neg_sizes).device(device) =
169 full_fft.slice(neg_start_indices, neg_sizes)
170 .reverse(reverse_last_axis)
171 .conjugate();
172 }
173
174 auto inner_axis = Eigen::array<int, 1>{FFTRank};
175 output.device(device) =
176 full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
177 }
178
179 template <int FFTRank, typename EigenDevice>
EigenFftWithRank(const EigenDevice & device,void * out,void * operand,FftType fft_type,bool double_precision,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)180 void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
181 FftType fft_type, bool double_precision,
182 int64 input_batch, int64 fft_length0, int64 fft_length1,
183 int64 fft_length2) {
184 switch (fft_type) {
185 case FftType::FFT:
186 if (double_precision) {
187 EigenFftC2C<true, FFTRank, EigenDevice, complex128>(
188 device, static_cast<complex128*>(out),
189 static_cast<complex128*>(operand), input_batch, fft_length0,
190 fft_length1, fft_length2);
191 } else {
192 EigenFftC2C<true, FFTRank, EigenDevice, complex64>(
193 device, static_cast<complex64*>(out),
194 static_cast<complex64*>(operand), input_batch, fft_length0,
195 fft_length1, fft_length2);
196 }
197 break;
198 case FftType::IFFT:
199 if (double_precision) {
200 EigenFftC2C<false, FFTRank, EigenDevice, complex128>(
201 device, static_cast<complex128*>(out),
202 static_cast<complex128*>(operand), input_batch, fft_length0,
203 fft_length1, fft_length2);
204 } else {
205 EigenFftC2C<false, FFTRank, EigenDevice, complex64>(
206 device, static_cast<complex64*>(out),
207 static_cast<complex64*>(operand), input_batch, fft_length0,
208 fft_length1, fft_length2);
209 }
210 break;
211 case FftType::RFFT:
212 if (double_precision) {
213 EigenFftR2C<FFTRank, EigenDevice, double, complex128>(
214 device, static_cast<complex128*>(out),
215 static_cast<double*>(operand), input_batch, fft_length0,
216 fft_length1, fft_length2);
217 } else {
218 EigenFftR2C<FFTRank, EigenDevice, float, complex64>(
219 device, static_cast<complex64*>(out), static_cast<float*>(operand),
220 input_batch, fft_length0, fft_length1, fft_length2);
221 }
222 break;
223 case FftType::IRFFT:
224 if (double_precision) {
225 EigenFftC2R<FFTRank, EigenDevice, complex128, double>(
226 device, static_cast<double*>(out),
227 static_cast<complex128*>(operand), input_batch, fft_length0,
228 fft_length1, fft_length2);
229 } else {
230 EigenFftC2R<FFTRank, EigenDevice, complex64, float>(
231 device, static_cast<float*>(out), static_cast<complex64*>(operand),
232 input_batch, fft_length0, fft_length1, fft_length2);
233 }
234 break;
235 default:
236 // Unsupported FFT type
237 abort();
238 }
239 }
240
241 } // namespace internal
242
243 template <typename EigenDevice>
EigenFftImpl(const EigenDevice & device,void * out,void * operand,FftType fft_type,bool double_precision,int32 fft_rank,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)244 void EigenFftImpl(const EigenDevice& device, void* out, void* operand,
245 FftType fft_type, bool double_precision, int32 fft_rank,
246 int64 input_batch, int64 fft_length0, int64 fft_length1,
247 int64 fft_length2) {
248 switch (fft_rank) {
249 case 1:
250 internal::EigenFftWithRank<1, EigenDevice>(device, out, operand, fft_type,
251 double_precision, input_batch,
252 fft_length0, 0, 0);
253 break;
254 case 2:
255 internal::EigenFftWithRank<2, EigenDevice>(device, out, operand, fft_type,
256 double_precision, input_batch,
257 fft_length0, fft_length1, 0);
258 break;
259 case 3:
260 internal::EigenFftWithRank<3, EigenDevice>(
261 device, out, operand, fft_type, double_precision, input_batch,
262 fft_length0, fft_length1, fft_length2);
263 break;
264 default:
265 // Unsupported FFT rank
266 abort();
267 }
268 }
269
270 } // namespace xla
271 } // namespace tensorflow
272
273 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
274