/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
16 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
17 
#include <array>
#include <cstdint>
#include <cstdlib>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/types.h"
23 
24 namespace xla {
25 
26 namespace internal {
27 
// Selects which transform the FFT runtime performs. The numeric values are
// spelled out explicitly and the underlying type is fixed to int32_t.
enum class FftType : int32_t {
  // Forward FFT; complex in, complex out.
  FFT = 0,
  // Inverse FFT; complex in, complex out.
  IFFT = 1,
  // Forward real FFT; real in, fft_length / 2 + 1 complex out.
  RFFT = 2,
  // Inverse real FFT; fft_length / 2 + 1 complex in, fft_length real out.
  IRFFT = 3,
};
// Returns 4, which equals the number of FftType enumerators (FFT, IFFT, RFFT,
// IRFFT).
inline constexpr int FftTypeArraySize() { return 4; }
36 
37 // Computes either a forward or reverse complex-to-complex FFT.
38 template <bool Forward, int FFTRank, typename EigenDevice, typename Complex>
EigenFftC2C(const EigenDevice & device,Complex * out,Complex * operand,int64_t input_batch,int64_t fft_length0,int64_t fft_length1,int64_t fft_length2)39 void EigenFftC2C(const EigenDevice& device, Complex* out, Complex* operand,
40                  int64_t input_batch, int64_t fft_length0, int64_t fft_length1,
41                  int64_t fft_length2) {
42   // Create the axes (which are always trailing).
43   const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
44   constexpr auto direction = Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
45 
46   const std::array<int64_t, 3> fft_shape = {
47       {fft_length0, fft_length1, fft_length2}};
48 
49   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> dims;
50   dims[0] = input_batch;
51   for (int i = 0; i < FFTRank; i++) {
52     dims[i + 1] = fft_shape[i];
53   }
54   const Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
55                          Eigen::Aligned>
56       input(operand, dims);
57   Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
58                    Eigen::Aligned>
59       output(out, dims);
60   output.device(device) = input.template fft<Eigen::BothParts, direction>(axes);
61 }
62 
63 // Computes a forward real->complex FFT, slicing out redundant negative
64 // frequencies from the innermost dimension.
65 template <int FFTRank, typename EigenDevice, typename Real, typename Complex>
EigenFftR2C(const EigenDevice & device,Complex * out,Real * operand,int64_t input_batch,int64_t fft_length0,int64_t fft_length1,int64_t fft_length2)66 void EigenFftR2C(const EigenDevice& device, Complex* out, Real* operand,
67                  int64_t input_batch, int64_t fft_length0, int64_t fft_length1,
68                  int64_t fft_length2) {
69   const std::array<int64_t, 3> fft_shape = {
70       {fft_length0, fft_length1, fft_length2}};
71 
72   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims;
73   in_dims[0] = input_batch;
74   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
75   out_dims[0] = input_batch;
76   for (int i = 0; i < FFTRank; i++) {
77     in_dims[i + 1] = fft_shape[i];
78     out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
79   }
80   const Eigen::TensorMap<Eigen::Tensor<Real, FFTRank + 1, Eigen::RowMajor>,
81                          Eigen::Aligned>
82       input(operand, in_dims);
83   Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
84                    Eigen::Aligned>
85       output(out, out_dims);
86 
87   // Create the axes (which are always trailing).
88   const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
89 
90   // Compute the full FFT using a temporary tensor.
91   Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor> full_fft(in_dims);
92 
93   const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
94   full_fft.device(device) =
95       input.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
96 
97   // Slice away the negative frequency components.
98   output.device(device) = full_fft.slice(zero_start_indices, out_dims);
99 }
100 
101 // Computes a reverse complex->real FFT, reconstructing redundant negative
102 // frequencies using reverse conjugate on innermost dimension after doing IFFT
103 // on outer dimensions.
104 template <int FFTRank, typename EigenDevice, typename Complex, typename Real>
EigenFftC2R(const EigenDevice & device,Real * out,Complex * operand,int64_t input_batch,int64_t fft_length0,int64_t fft_length1,int64_t fft_length2)105 void EigenFftC2R(const EigenDevice& device, Real* out, Complex* operand,
106                  int64_t input_batch, int64_t fft_length0, int64_t fft_length1,
107                  int64_t fft_length2) {
108   const std::array<int64_t, 3> fft_shape = {
109       {fft_length0, fft_length1, fft_length2}};
110 
111   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims;
112   in_dims[0] = input_batch;
113   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
114   out_dims[0] = input_batch;
115   for (int i = 0; i < FFTRank; i++) {
116     in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
117     out_dims[i + 1] = fft_shape[i];
118   }
119   const Eigen::TensorMap<Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor>,
120                          Eigen::Aligned>
121       input(operand, in_dims);
122   Eigen::TensorMap<Eigen::Tensor<Real, FFTRank + 1, Eigen::RowMajor>,
123                    Eigen::Aligned>
124       output(out, out_dims);
125 
126   // Calculate the shape of the temporary tensor for the full FFT and the
127   // region we will slice from input given fft_shape. We slice input to
128   // fft_shape on its inner-most dimensions, except the last (which we
129   // slice to fft_shape[-1] / 2 + 1).
130   Eigen::Tensor<Complex, FFTRank + 1, Eigen::RowMajor> full_fft(out_dims);
131 
132   // Calculate the starting point and range of the source of
133   // negative frequency part.
134   auto neg_sizes = in_dims;
135   neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - in_dims[FFTRank];
136   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
137   neg_target_indices[FFTRank] = in_dims[FFTRank];
138 
139   const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
140   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
141   neg_start_indices[FFTRank] = 1;
142 
143   full_fft.slice(zero_start_indices, in_dims).device(device) = input;
144 
145   // First, conduct IFFTs on outer dimensions. We save computation (and
146   // avoid touching uninitialized memory) by slicing full_fft to the
147   // subregion we wrote input to.
148   if (FFTRank > 1) {
149     const auto outer_axes =
150         Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
151     full_fft.slice(zero_start_indices, in_dims).device(device) =
152         full_fft.slice(zero_start_indices, in_dims)
153             .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
154   }
155 
156   // Reconstruct the full FFT by appending reversed and conjugated
157   // spectrum as the negative frequency part.
158   Eigen::array<bool, FFTRank + 1> reverse_last_axis;
159   for (auto i = 0; i <= FFTRank; i++) {
160     reverse_last_axis[i] = i == FFTRank;
161   }
162 
163   if (neg_sizes[FFTRank] != 0) {
164     full_fft.slice(neg_target_indices, neg_sizes).device(device) =
165         full_fft.slice(neg_start_indices, neg_sizes)
166             .reverse(reverse_last_axis)
167             .conjugate();
168   }
169 
170   auto inner_axis = Eigen::array<int, 1>{FFTRank};
171   output.device(device) =
172       full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
173 }
174 
175 template <int FFTRank, typename EigenDevice>
EigenFftWithRank(const EigenDevice & device,void * out,void * operand,FftType fft_type,bool double_precision,int64_t input_batch,int64_t fft_length0,int64_t fft_length1,int64_t fft_length2)176 void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
177                       FftType fft_type, bool double_precision,
178                       int64_t input_batch, int64_t fft_length0,
179                       int64_t fft_length1, int64_t fft_length2) {
180   switch (fft_type) {
181     case FftType::FFT:
182       if (double_precision) {
183         EigenFftC2C<true, FFTRank, EigenDevice, complex128>(
184             device, static_cast<complex128*>(out),
185             static_cast<complex128*>(operand), input_batch, fft_length0,
186             fft_length1, fft_length2);
187       } else {
188         EigenFftC2C<true, FFTRank, EigenDevice, complex64>(
189             device, static_cast<complex64*>(out),
190             static_cast<complex64*>(operand), input_batch, fft_length0,
191             fft_length1, fft_length2);
192       }
193       break;
194     case FftType::IFFT:
195       if (double_precision) {
196         EigenFftC2C<false, FFTRank, EigenDevice, complex128>(
197             device, static_cast<complex128*>(out),
198             static_cast<complex128*>(operand), input_batch, fft_length0,
199             fft_length1, fft_length2);
200       } else {
201         EigenFftC2C<false, FFTRank, EigenDevice, complex64>(
202             device, static_cast<complex64*>(out),
203             static_cast<complex64*>(operand), input_batch, fft_length0,
204             fft_length1, fft_length2);
205       }
206       break;
207     case FftType::RFFT:
208       if (double_precision) {
209         EigenFftR2C<FFTRank, EigenDevice, double, complex128>(
210             device, static_cast<complex128*>(out),
211             static_cast<double*>(operand), input_batch, fft_length0,
212             fft_length1, fft_length2);
213       } else {
214         EigenFftR2C<FFTRank, EigenDevice, float, complex64>(
215             device, static_cast<complex64*>(out), static_cast<float*>(operand),
216             input_batch, fft_length0, fft_length1, fft_length2);
217       }
218       break;
219     case FftType::IRFFT:
220       if (double_precision) {
221         EigenFftC2R<FFTRank, EigenDevice, complex128, double>(
222             device, static_cast<double*>(out),
223             static_cast<complex128*>(operand), input_batch, fft_length0,
224             fft_length1, fft_length2);
225       } else {
226         EigenFftC2R<FFTRank, EigenDevice, complex64, float>(
227             device, static_cast<float*>(out), static_cast<complex64*>(operand),
228             input_batch, fft_length0, fft_length1, fft_length2);
229       }
230       break;
231     default:
232       // Unsupported FFT type
233       abort();
234   }
235 }
236 
237 }  // namespace internal
238 
239 template <typename EigenDevice>
EigenFftImpl(const EigenDevice & device,void * out,void * operand,internal::FftType fft_type,bool double_precision,int32_t fft_rank,int64_t input_batch,int64_t fft_length0,int64_t fft_length1,int64_t fft_length2)240 void EigenFftImpl(const EigenDevice& device, void* out, void* operand,
241                   internal::FftType fft_type, bool double_precision,
242                   int32_t fft_rank, int64_t input_batch, int64_t fft_length0,
243                   int64_t fft_length1, int64_t fft_length2) {
244   switch (fft_rank) {
245     case 1:
246       internal::EigenFftWithRank<1, EigenDevice>(device, out, operand, fft_type,
247                                                  double_precision, input_batch,
248                                                  fft_length0, 0, 0);
249       break;
250     case 2:
251       internal::EigenFftWithRank<2, EigenDevice>(device, out, operand, fft_type,
252                                                  double_precision, input_batch,
253                                                  fft_length0, fft_length1, 0);
254       break;
255     case 3:
256       internal::EigenFftWithRank<3, EigenDevice>(
257           device, out, operand, fft_type, double_precision, input_batch,
258           fft_length0, fft_length1, fft_length2);
259       break;
260     default:
261       // Unsupported FFT rank
262       abort();
263   }
264 }
265 
266 }  // namespace xla
267 
268 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
269