• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
16 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
17 
18 #include <array>
19 
20 #include "third_party/eigen3/Eigen/Core"
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/core/framework/numeric_types.h"
23 #include "tensorflow/core/platform/types.h"
24 
25 // 'tensorflow' namespace is used so that int64 and other types don't require
26 // qualification.
27 namespace tensorflow {
28 namespace xla {
29 
30 enum class FftType : int32 {
31   FFT = 0,    // Forward FFT; complex in, complex out.
32   IFFT = 1,   // Inverse FFT; complex in, complex out.
33   RFFT = 2,   // Forward real FFT; real in, fft_length / 2 + 1 complex out
34   IRFFT = 3,  // Inverse real FFT; fft_length / 2 + 1 complex in,
35               //                   fft_length real out
36 };
37 static constexpr int kFftTypeArraySize = 4;
38 
39 namespace internal {
40 
41 // Computes either a forward or reverse complex-to-complex FFT.
42 template <bool Forward, int FFTRank, typename EigenDevice, typename Complex>
EigenFftC2C(const EigenDevice & device,Complex * out,Complex * operand,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)43 void EigenFftC2C(const EigenDevice& device, Complex* out, Complex* operand,
44                  int64 input_batch, int64 fft_length0, int64 fft_length1,
45                  int64 fft_length2) {
46   // Create the axes (which are always trailing).
47   const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
48   constexpr auto direction = Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
49 
50   const std::array<int64, 3> fft_shape = {
51       {fft_length0, fft_length1, fft_length2}};
52 
53   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> dims;
54   dims[0] = input_batch;
55   for (int i = 0; i < FFTRank; i++) {
56     dims[i + 1] = fft_shape[i];
57   }
58   const Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
59                          Eigen::Aligned>
60       input(operand, dims);
61   Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
62                    Eigen::Aligned>
63       output(out, dims);
64   output.device(device) = input.template fft<Eigen::BothParts, direction>(axes);
65 }
66 
67 // Computes a forward real->complex FFT, slicing out redundant negative
68 // frequencies from the innermost dimension.
69 template <int FFTRank, typename EigenDevice, typename Real, typename Complex>
EigenFftR2C(const EigenDevice & device,Complex * out,Real * operand,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)70 void EigenFftR2C(const EigenDevice& device, Complex* out, Real* operand,
71                  int64 input_batch, int64 fft_length0, int64 fft_length1,
72                  int64 fft_length2) {
73   const std::array<int64, 3> fft_shape = {
74       {fft_length0, fft_length1, fft_length2}};
75 
76   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims;
77   in_dims[0] = input_batch;
78   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
79   out_dims[0] = input_batch;
80   for (int i = 0; i < FFTRank; i++) {
81     in_dims[i + 1] = fft_shape[i];
82     out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
83   }
84   const Eigen::TensorMap<Eigen::Tensor<Real, FFTRank + 1, Eigen::RowMajor>,
85                          Eigen::Aligned>
86       input(operand, in_dims);
87   Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
88                    Eigen::Aligned>
89       output(out, out_dims);
90 
91   // Create the axes (which are always trailing).
92   const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
93 
94   // Compute the full FFT using a temporary tensor.
95   Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor> full_fft(in_dims);
96 
97   const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
98   full_fft.device(device) =
99       input.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
100 
101   // Slice away the negative frequency components.
102   output.device(device) = full_fft.slice(zero_start_indices, out_dims);
103 }
104 
105 // Computes a reverse complex->real FFT, reconstructing redundant negative
106 // frequencies using reverse conjugate on innermost dimension after doing IFFT
107 // on outer dimensions.
108 template <int FFTRank, typename EigenDevice, typename Complex, typename Real>
EigenFftC2R(const EigenDevice & device,Real * out,Complex * operand,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)109 void EigenFftC2R(const EigenDevice& device, Real* out, Complex* operand,
110                  int64 input_batch, int64 fft_length0, int64 fft_length1,
111                  int64 fft_length2) {
112   const std::array<int64, 3> fft_shape = {
113       {fft_length0, fft_length1, fft_length2}};
114 
115   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims;
116   in_dims[0] = input_batch;
117   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
118   out_dims[0] = input_batch;
119   for (int i = 0; i < FFTRank; i++) {
120     in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
121     out_dims[i + 1] = fft_shape[i];
122   }
123   const Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
124                          Eigen::Aligned>
125       input(operand, in_dims);
126   Eigen::TensorMap<Eigen::Tensor<Real, FFTRank + 1, Eigen::RowMajor>,
127                    Eigen::Aligned>
128       output(out, out_dims);
129 
130   // Calculate the shape of the temporary tensor for the full FFT and the
131   // region we will slice from input given fft_shape. We slice input to
132   // fft_shape on its inner-most dimensions, except the last (which we
133   // slice to fft_shape[-1] / 2 + 1).
134   Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor> full_fft(out_dims);
135 
136   // Calculate the starting point and range of the source of
137   // negative frequency part.
138   auto neg_sizes = in_dims;
139   neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - in_dims[FFTRank];
140   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
141   neg_target_indices[FFTRank] = in_dims[FFTRank];
142 
143   const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
144   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
145   neg_start_indices[FFTRank] = 1;
146 
147   full_fft.slice(zero_start_indices, in_dims).device(device) = input;
148 
149   // First, conduct IFFTs on outer dimensions. We save computation (and
150   // avoid touching uninitialized memory) by slicing full_fft to the
151   // subregion we wrote input to.
152   if (FFTRank > 1) {
153     const auto outer_axes =
154         Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
155     full_fft.slice(zero_start_indices, in_dims).device(device) =
156         full_fft.slice(zero_start_indices, in_dims)
157             .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
158   }
159 
160   // Reconstruct the full FFT by appending reversed and conjugated
161   // spectrum as the negative frequency part.
162   Eigen::array<bool, FFTRank + 1> reverse_last_axis;
163   for (auto i = 0; i <= FFTRank; i++) {
164     reverse_last_axis[i] = i == FFTRank;
165   }
166 
167   if (neg_sizes[FFTRank] != 0) {
168     full_fft.slice(neg_target_indices, neg_sizes).device(device) =
169         full_fft.slice(neg_start_indices, neg_sizes)
170             .reverse(reverse_last_axis)
171             .conjugate();
172   }
173 
174   auto inner_axis = Eigen::array<int, 1>{FFTRank};
175   output.device(device) =
176       full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
177 }
178 
179 template <int FFTRank, typename EigenDevice>
EigenFftWithRank(const EigenDevice & device,void * out,void * operand,FftType fft_type,bool double_precision,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)180 void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
181                       FftType fft_type, bool double_precision,
182                       int64 input_batch, int64 fft_length0, int64 fft_length1,
183                       int64 fft_length2) {
184   switch (fft_type) {
185     case FftType::FFT:
186       if (double_precision) {
187         EigenFftC2C<true, FFTRank, EigenDevice, complex128>(
188             device, static_cast<complex128*>(out),
189             static_cast<complex128*>(operand), input_batch, fft_length0,
190             fft_length1, fft_length2);
191       } else {
192         EigenFftC2C<true, FFTRank, EigenDevice, complex64>(
193             device, static_cast<complex64*>(out),
194             static_cast<complex64*>(operand), input_batch, fft_length0,
195             fft_length1, fft_length2);
196       }
197       break;
198     case FftType::IFFT:
199       if (double_precision) {
200         EigenFftC2C<false, FFTRank, EigenDevice, complex128>(
201             device, static_cast<complex128*>(out),
202             static_cast<complex128*>(operand), input_batch, fft_length0,
203             fft_length1, fft_length2);
204       } else {
205         EigenFftC2C<false, FFTRank, EigenDevice, complex64>(
206             device, static_cast<complex64*>(out),
207             static_cast<complex64*>(operand), input_batch, fft_length0,
208             fft_length1, fft_length2);
209       }
210       break;
211     case FftType::RFFT:
212       if (double_precision) {
213         EigenFftR2C<FFTRank, EigenDevice, double, complex128>(
214             device, static_cast<complex128*>(out),
215             static_cast<double*>(operand), input_batch, fft_length0,
216             fft_length1, fft_length2);
217       } else {
218         EigenFftR2C<FFTRank, EigenDevice, float, complex64>(
219             device, static_cast<complex64*>(out), static_cast<float*>(operand),
220             input_batch, fft_length0, fft_length1, fft_length2);
221       }
222       break;
223     case FftType::IRFFT:
224       if (double_precision) {
225         EigenFftC2R<FFTRank, EigenDevice, complex128, double>(
226             device, static_cast<double*>(out),
227             static_cast<complex128*>(operand), input_batch, fft_length0,
228             fft_length1, fft_length2);
229       } else {
230         EigenFftC2R<FFTRank, EigenDevice, complex64, float>(
231             device, static_cast<float*>(out), static_cast<complex64*>(operand),
232             input_batch, fft_length0, fft_length1, fft_length2);
233       }
234       break;
235     default:
236       // Unsupported FFT type
237       abort();
238   }
239 }
240 
241 }  // namespace internal
242 
243 template <typename EigenDevice>
EigenFftImpl(const EigenDevice & device,void * out,void * operand,FftType fft_type,bool double_precision,int32 fft_rank,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)244 void EigenFftImpl(const EigenDevice& device, void* out, void* operand,
245                   FftType fft_type, bool double_precision, int32 fft_rank,
246                   int64 input_batch, int64 fft_length0, int64 fft_length1,
247                   int64 fft_length2) {
248   switch (fft_rank) {
249     case 1:
250       internal::EigenFftWithRank<1, EigenDevice>(device, out, operand, fft_type,
251                                                  double_precision, input_batch,
252                                                  fft_length0, 0, 0);
253       break;
254     case 2:
255       internal::EigenFftWithRank<2, EigenDevice>(device, out, operand, fft_type,
256                                                  double_precision, input_batch,
257                                                  fft_length0, fft_length1, 0);
258       break;
259     case 3:
260       internal::EigenFftWithRank<3, EigenDevice>(
261           device, out, operand, fft_type, double_precision, input_batch,
262           fft_length0, fft_length1, fft_length2);
263       break;
264     default:
265       // Unsupported FFT rank
266       abort();
267   }
268 }
269 
270 }  // namespace xla
271 }  // namespace tensorflow
272 
273 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
274