1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
16 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
17
#include <array>
#include <cstdlib>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/platform/types.h"
25
26 // 'tensorflow' namespace is used so that int64 and other types don't require
27 // qualification.
28 namespace tensorflow {
29 namespace xla {
30
31 namespace internal {
32
33 // Computes either a forward or reverse complex-to-complex FFT.
34 template <bool Forward, int FFTRank, typename EigenDevice>
EigenFftC2C(const EigenDevice & device,complex64 * out,complex64 * operand,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)35 void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand,
36 int64 input_batch, int64 fft_length0, int64 fft_length1,
37 int64 fft_length2) {
38 // Create the axes (which are always trailing).
39 const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
40 constexpr auto direction = Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
41
42 const std::array<int64, 3> fft_shape = {
43 {fft_length0, fft_length1, fft_length2}};
44
45 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> dims;
46 dims[0] = input_batch;
47 for (int i = 0; i < FFTRank; i++) {
48 dims[i + 1] = fft_shape[i];
49 }
50 const Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
51 Eigen::Aligned>
52 input(operand, dims);
53 Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
54 Eigen::Aligned>
55 output(out, dims);
56 output.device(device) = input.template fft<Eigen::BothParts, direction>(axes);
57 }
58
59 // Computes a forward real->complex FFT, slicing out redundant negative
60 // frequencies from the innermost dimension.
61 template <int FFTRank, typename EigenDevice>
EigenFftR2C(const EigenDevice & device,complex64 * out,float * operand,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)62 void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand,
63 int64 input_batch, int64 fft_length0, int64 fft_length1,
64 int64 fft_length2) {
65 const std::array<int64, 3> fft_shape = {
66 {fft_length0, fft_length1, fft_length2}};
67
68 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims;
69 in_dims[0] = input_batch;
70 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
71 out_dims[0] = input_batch;
72 for (int i = 0; i < FFTRank; i++) {
73 in_dims[i + 1] = fft_shape[i];
74 out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
75 }
76 const Eigen::TensorMap<Eigen::Tensor<float, FFTRank + 1, Eigen::RowMajor>,
77 Eigen::Aligned>
78 input(operand, in_dims);
79 Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
80 Eigen::Aligned>
81 output(out, out_dims);
82
83 // Create the axes (which are always trailing).
84 const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
85
86 // Compute the full FFT using a temporary tensor.
87 Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor> full_fft(in_dims);
88
89 const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
90 full_fft.device(device) =
91 input.template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);
92
93 // Slice away the negative frequency components.
94 output.device(device) = full_fft.slice(zero_start_indices, out_dims);
95 }
96
97 // Computes a reverse complex->real FFT, reconstructing redundant negative
98 // frequencies using reverse conjugate on innermost dimension after doing IFFT
99 // on outer dimensions.
100 template <int FFTRank, typename EigenDevice>
EigenFftC2R(const EigenDevice & device,float * out,complex64 * operand,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)101 void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand,
102 int64 input_batch, int64 fft_length0, int64 fft_length1,
103 int64 fft_length2) {
104 const std::array<int64, 3> fft_shape = {
105 {fft_length0, fft_length1, fft_length2}};
106
107 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> in_dims;
108 in_dims[0] = input_batch;
109 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> out_dims;
110 out_dims[0] = input_batch;
111 for (int i = 0; i < FFTRank; i++) {
112 in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i];
113 out_dims[i + 1] = fft_shape[i];
114 }
115 const Eigen::TensorMap<Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor>,
116 Eigen::Aligned>
117 input(operand, in_dims);
118 Eigen::TensorMap<Eigen::Tensor<float, FFTRank + 1, Eigen::RowMajor>,
119 Eigen::Aligned>
120 output(out, out_dims);
121
122 // Calculate the shape of the temporary tensor for the full FFT and the
123 // region we will slice from input given fft_shape. We slice input to
124 // fft_shape on its inner-most dimensions, except the last (which we
125 // slice to fft_shape[-1] / 2 + 1).
126 Eigen::Tensor<complex64, FFTRank + 1, Eigen::RowMajor> full_fft(out_dims);
127
128 // Calculate the starting point and range of the source of
129 // negative frequency part.
130 auto neg_sizes = in_dims;
131 neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - in_dims[FFTRank];
132 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
133 neg_target_indices[FFTRank] = in_dims[FFTRank];
134
135 const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;
136 Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
137 neg_start_indices[FFTRank] = 1;
138
139 full_fft.slice(zero_start_indices, in_dims).device(device) = input;
140
141 // First, conduct IFFTs on outer dimensions. We save computation (and
142 // avoid touching uninitialized memory) by slicing full_fft to the
143 // subregion we wrote input to.
144 if (FFTRank > 1) {
145 const auto outer_axes =
146 Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
147 full_fft.slice(zero_start_indices, in_dims).device(device) =
148 full_fft.slice(zero_start_indices, in_dims)
149 .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
150 }
151
152 // Reconstruct the full FFT by appending reversed and conjugated
153 // spectrum as the negative frequency part.
154 Eigen::array<bool, FFTRank + 1> reverse_last_axis;
155 for (auto i = 0; i <= FFTRank; i++) {
156 reverse_last_axis[i] = i == FFTRank;
157 }
158
159 if (neg_sizes[FFTRank] != 0) {
160 full_fft.slice(neg_target_indices, neg_sizes).device(device) =
161 full_fft.slice(neg_start_indices, neg_sizes)
162 .reverse(reverse_last_axis)
163 .conjugate();
164 }
165
166 auto inner_axis = Eigen::array<int, 1>{FFTRank};
167 output.device(device) =
168 full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
169 }
170
171 template <int FFTRank, typename EigenDevice>
EigenFftWithRank(const EigenDevice & device,void * out,void * operand,int32 fft_type,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)172 void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
173 int32 fft_type, int64 input_batch, int64 fft_length0,
174 int64 fft_length1, int64 fft_length2) {
175 switch (fft_type) {
176 case ::xla::FftType::FFT:
177 EigenFftC2C<true, FFTRank, EigenDevice>(
178 device, static_cast<complex64*>(out),
179 static_cast<complex64*>(operand), input_batch, fft_length0,
180 fft_length1, fft_length2);
181 break;
182 case ::xla::FftType::IFFT:
183 EigenFftC2C<false, FFTRank, EigenDevice>(
184 device, static_cast<complex64*>(out),
185 static_cast<complex64*>(operand), input_batch, fft_length0,
186 fft_length1, fft_length2);
187 break;
188 case ::xla::FftType::RFFT:
189 EigenFftR2C<FFTRank, EigenDevice>(
190 device, static_cast<complex64*>(out), static_cast<float*>(operand),
191 input_batch, fft_length0, fft_length1, fft_length2);
192 break;
193 case ::xla::FftType::IRFFT:
194 EigenFftC2R<FFTRank, EigenDevice>(
195 device, static_cast<float*>(out), static_cast<complex64*>(operand),
196 input_batch, fft_length0, fft_length1, fft_length2);
197 break;
198 default:
199 // Unsupported FFT type
200 abort();
201 }
202 }
203
204 } // namespace internal
205
206 template <typename EigenDevice>
EigenFftImpl(const EigenDevice & device,void * out,void * operand,int32 fft_type,int32 fft_rank,int64 input_batch,int64 fft_length0,int64 fft_length1,int64 fft_length2)207 void EigenFftImpl(const EigenDevice& device, void* out, void* operand,
208 int32 fft_type, int32 fft_rank, int64 input_batch,
209 int64 fft_length0, int64 fft_length1, int64 fft_length2) {
210 switch (fft_rank) {
211 case 1:
212 internal::EigenFftWithRank<1, EigenDevice>(
213 device, out, operand, fft_type, input_batch, fft_length0, 0, 0);
214 break;
215 case 2:
216 internal::EigenFftWithRank<2, EigenDevice>(device, out, operand, fft_type,
217 input_batch, fft_length0,
218 fft_length1, 0);
219 break;
220 case 3:
221 internal::EigenFftWithRank<3, EigenDevice>(device, out, operand, fft_type,
222 input_batch, fft_length0,
223 fft_length1, fft_length2);
224 break;
225 default:
226 // Unsupported FFT rank
227 abort();
228 }
229 }
230
231 } // namespace xla
232 } // namespace tensorflow
233
234 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_
235