• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
16 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
17 
18 #include <array>
19 
20 #include "third_party/eigen3/Eigen/Core"
21 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
22 #include "tensorflow/compiler/xla/xla_data.pb.h"
23 #include "tensorflow/core/framework/numeric_types.h"
24 #include "tensorflow/core/platform/types.h"
25 
26 // 'tensorflow' namespace is used so that int64 and other types don't require
27 // qualification.
28 namespace tensorflow {
29 namespace xla {
30 
31 namespace internal {
32 
33 // Computes either a forward or reverse complex-to-complex FFT.
34 template <bool Forward, int FFTRank, typename EigenDevice>
EigenFftC2C(const EigenDevice & device,complex64 * out,complex64 * operand,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)35 void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand,
36                  int64 input_batch, int64 fft_length0, int64 fft_length1,
37                  int64 fft_length2) {
38   // Create the axes (which are always trailing).
39   const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
40   constexpr auto direction = Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
41 
42   const std::array<int64, 3> fft_shape = {
43       {fft_length0, fft_length1, fft_length2}};
44 
45   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> dims;
46   dims[0] = input_batch;
47   for (int i = 0; i < FFTRank; i++) {
48     dims[i + 1] = fft_shape[i];
49   }
50   const Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
51                          Eigen::Aligned>
52       input(operand, dims);
53   Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
54                    Eigen::Aligned>
55       output(out, dims);
56   output.device(device) = input.template fft<Eigen::BothParts, direction>(axes);
57 }
58 
59 // Computes a forward real->complex FFT, slicing out redundant negative
60 // frequencies from the innermost dimension.
61 template <int FFTRank, typename EigenDevice>
EigenFftR2C(const EigenDevice & device,complex64 * out,float * operand,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)62 void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand,
63                  int64 input_batch, int64 fft_length0, int64 fft_length1,
64                  int64 fft_length2) {
65   const std::array<int64, 3> fft_shape = {
66       {fft_length0, fft_length1, fft_length2}};
67 
68   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims;
69   in_dims[0] = input_batch;
70   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
71   out_dims[0] = input_batch;
72   for (int i = 0; i < FFTRank; i++) {
73     in_dims[i + 1] = fft_shape[i];
74     out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
75   }
76   const Eigen::TensorMap<Eigen::Tensor<float, FFTRank + 1, Eigen::RowMajor>,
77                          Eigen::Aligned>
78       input(operand, in_dims);
79   Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
80                    Eigen::Aligned>
81       output(out, out_dims);
82 
83   // Create the axes (which are always trailing).
84   const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
85 
86   // Compute the full FFT using a temporary tensor.
87   Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor> full_fft(in_dims);
88 
89   const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
90   full_fft.device(device) =
91       input.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
92 
93   // Slice away the negative frequency components.
94   output.device(device) = full_fft.slice(zero_start_indices, out_dims);
95 }
96 
97 // Computes a reverse complex->real FFT, reconstructing redundant negative
98 // frequencies using reverse conjugate on innermost dimension after doing IFFT
99 // on outer dimensions.
100 template <int FFTRank, typename EigenDevice>
EigenFftC2R(const EigenDevice & device,float * out,complex64 * operand,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)101 void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand,
102                  int64 input_batch, int64 fft_length0, int64 fft_length1,
103                  int64 fft_length2) {
104   const std::array<int64, 3> fft_shape = {
105       {fft_length0, fft_length1, fft_length2}};
106 
107   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims;
108   in_dims[0] = input_batch;
109   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
110   out_dims[0] = input_batch;
111   for (int i = 0; i < FFTRank; i++) {
112     in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
113     out_dims[i + 1] = fft_shape[i];
114   }
115   const Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
116                          Eigen::Aligned>
117       input(operand, in_dims);
118   Eigen::TensorMap<Eigen::Tensor<float, FFTRank + 1, Eigen::RowMajor>,
119                    Eigen::Aligned>
120       output(out, out_dims);
121 
122   // Calculate the shape of the temporary tensor for the full FFT and the
123   // region we will slice from input given fft_shape. We slice input to
124   // fft_shape on its inner-most dimensions, except the last (which we
125   // slice to fft_shape[-1] / 2 + 1).
126   Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor> full_fft(out_dims);
127 
128   // Calculate the starting point and range of the source of
129   // negative frequency part.
130   auto neg_sizes = in_dims;
131   neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - in_dims[FFTRank];
132   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
133   neg_target_indices[FFTRank] = in_dims[FFTRank];
134 
135   const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
136   Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
137   neg_start_indices[FFTRank] = 1;
138 
139   full_fft.slice(zero_start_indices, in_dims).device(device) = input;
140 
141   // First, conduct IFFTs on outer dimensions. We save computation (and
142   // avoid touching uninitialized memory) by slicing full_fft to the
143   // subregion we wrote input to.
144   if (FFTRank > 1) {
145     const auto outer_axes =
146         Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
147     full_fft.slice(zero_start_indices, in_dims).device(device) =
148         full_fft.slice(zero_start_indices, in_dims)
149             .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
150   }
151 
152   // Reconstruct the full FFT by appending reversed and conjugated
153   // spectrum as the negative frequency part.
154   Eigen::array<bool, FFTRank + 1> reverse_last_axis;
155   for (auto i = 0; i <= FFTRank; i++) {
156     reverse_last_axis[i] = i == FFTRank;
157   }
158 
159   if (neg_sizes[FFTRank] != 0) {
160     full_fft.slice(neg_target_indices, neg_sizes).device(device) =
161         full_fft.slice(neg_start_indices, neg_sizes)
162             .reverse(reverse_last_axis)
163             .conjugate();
164   }
165 
166   auto inner_axis = Eigen::array<int, 1>{FFTRank};
167   output.device(device) =
168       full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
169 }
170 
171 template <int FFTRank, typename EigenDevice>
EigenFftWithRank(const EigenDevice & device,void * out,void * operand,int32 fft_type,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)172 void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
173                       int32 fft_type, int64 input_batch, int64 fft_length0,
174                       int64 fft_length1, int64 fft_length2) {
175   switch (fft_type) {
176     case ::xla::FftType::FFT:
177       EigenFftC2C<true, FFTRank, EigenDevice>(
178           device, static_cast<complex64*>(out),
179           static_cast<complex64*>(operand), input_batch, fft_length0,
180           fft_length1, fft_length2);
181       break;
182     case ::xla::FftType::IFFT:
183       EigenFftC2C<false, FFTRank, EigenDevice>(
184           device, static_cast<complex64*>(out),
185           static_cast<complex64*>(operand), input_batch, fft_length0,
186           fft_length1, fft_length2);
187       break;
188     case ::xla::FftType::RFFT:
189       EigenFftR2C<FFTRank, EigenDevice>(
190           device, static_cast<complex64*>(out), static_cast<float*>(operand),
191           input_batch, fft_length0, fft_length1, fft_length2);
192       break;
193     case ::xla::FftType::IRFFT:
194       EigenFftC2R<FFTRank, EigenDevice>(
195           device, static_cast<float*>(out), static_cast<complex64*>(operand),
196           input_batch, fft_length0, fft_length1, fft_length2);
197       break;
198     default:
199       // Unsupported FFT type
200       abort();
201   }
202 }
203 
204 }  // namespace internal
205 
206 template <typename EigenDevice>
EigenFftImpl(const EigenDevice & device,void * out,void * operand,int32 fft_type,int32 fft_rank,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)207 void EigenFftImpl(const EigenDevice& device, void* out, void* operand,
208                   int32 fft_type, int32 fft_rank, int64 input_batch,
209                   int64 fft_length0, int64 fft_length1, int64 fft_length2) {
210   switch (fft_rank) {
211     case 1:
212       internal::EigenFftWithRank<1, EigenDevice>(
213           device, out, operand, fft_type, input_batch, fft_length0, 0, 0);
214       break;
215     case 2:
216       internal::EigenFftWithRank<2, EigenDevice>(device, out, operand, fft_type,
217                                                  input_batch, fft_length0,
218                                                  fft_length1, 0);
219       break;
220     case 3:
221       internal::EigenFftWithRank<3, EigenDevice>(device, out, operand, fft_type,
222                                                  input_batch, fft_length0,
223                                                  fft_length1, fft_length2);
224       break;
225     default:
226       // Unsupported FFT rank
227       abort();
228   }
229 }
230 
231 }  // namespace xla
232 }  // namespace tensorflow
233 
234 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
235