/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/platform/errors.h"
#define EIGEN_USE_THREADS

// See docs in ../ops/fft_ops.cc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/work_sharder.h"

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

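// Base class for the FFT kernels. Compute() validates the input (and, for
// real transforms, the fft_length argument), derives the output shape,
// allocates the output, checks dtypes, and delegates the actual transform to
// the device-specific DoFFT() implementation.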
class FFTBase : public OpKernel {
 public:
  explicit FFTBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in = ctx->input(0);
    const TensorShape& input_shape = in.shape();
    const int fft_rank = Rank();
    OP_REQUIRES(
        ctx, input_shape.dims() >= fft_rank,
        errors::InvalidArgument("Input must have rank of at least ", fft_rank,
                                " but got: ", input_shape.DebugString()));

    Tensor* out;
    TensorShape output_shape = input_shape;
    uint64 fft_shape[3] = {0, 0, 0};

    // In R2C or C2R mode, we use a second input to specify the FFT length
    // instead of inferring it from the input shape.
    if (IsReal()) {
      const Tensor& fft_length = ctx->input(1);
      OP_REQUIRES(ctx,
                  fft_length.shape().dims() == 1 &&
                      fft_length.shape().dim_size(0) == fft_rank,
                  errors::InvalidArgument("fft_length must have shape [",
                                          fft_rank, "]"));

      auto fft_length_as_vec = fft_length.vec<int32>();
      for (int i = 0; i < fft_rank; ++i) {
        fft_shape[i] = fft_length_as_vec(i);
        // Each input dimension must have length of at least fft_shape[i]. For
        // IRFFTs, the inner-most input dimension must have length of at least
        // fft_shape[i] / 2 + 1.
        bool inner_most = (i == fft_rank - 1);
        uint64 min_input_dim_length =
            !IsForward() && inner_most ? fft_shape[i] / 2 + 1 : fft_shape[i];
        auto input_index = input_shape.dims() - fft_rank + i;
        OP_REQUIRES(
            ctx,
            // We pass through empty tensors, so special case them here.
            input_shape.dim_size(input_index) == 0 ||
                input_shape.dim_size(input_index) >= min_input_dim_length,
            errors::InvalidArgument(
                "Input dimension ", input_index,
                " must have length of at least ", min_input_dim_length,
                " but got: ", input_shape.dim_size(input_index)));
        uint64 dim = IsForward() && inner_most && fft_shape[i] != 0
                         ? fft_shape[i] / 2 + 1
                         : fft_shape[i];
        output_shape.set_dim(output_shape.dims() - fft_rank + i, dim);
      }
    } else {
      for (int i = 0; i < fft_rank; ++i) {
        fft_shape[i] =
            output_shape.dim_size(output_shape.dims() - fft_rank + i);
      }
    }

    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));

    if (IsReal()) {
      if (IsForward()) {
        OP_REQUIRES(
            ctx,
            (in.dtype() == DT_FLOAT && out->dtype() == DT_COMPLEX64) ||
                (in.dtype() == DT_DOUBLE && out->dtype() == DT_COMPLEX128),
            errors::InvalidArgument("Wrong types for forward real FFT: in=",
                                    in.dtype(), " out=", out->dtype()));
      } else {
        OP_REQUIRES(
            ctx,
            (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_FLOAT) ||
                (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_DOUBLE),
            errors::InvalidArgument("Wrong types for backward real FFT: in=",
                                    in.dtype(), " out=", out->dtype()));
      }
    } else {
      OP_REQUIRES(
          ctx,
          (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_COMPLEX64) ||
              (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_COMPLEX128),
          errors::InvalidArgument("Wrong types for FFT: in=", in.dtype(),
                                  " out=", out->dtype()));
    }

    if (input_shape.num_elements() == 0) {
      DCHECK_EQ(0, output_shape.num_elements());
      return;
    }

    DoFFT(ctx, in, fft_shape, out);
  }

 protected:
  virtual int Rank() const = 0;
  virtual bool IsForward() const = 0;
  virtual bool IsReal() const = 0;

  // The function that actually computes the FFT.
  virtual void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
                     Tensor* out) = 0;
};

typedef Eigen::ThreadPoolDevice CPUDevice;

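// CPU implementation of the FFT kernels, computed with Eigen's Tensor FFT.
// Template parameters: Forward selects forward vs. inverse transforms, _Real
// selects real-input/real-output transforms (RFFT/IRFFT), and FFTRank is the
// dimensionality of the transform (1, 2, or 3).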
template <bool Forward, bool _Real, int FFTRank>
class FFTCPU : public FFTBase {
 public:
  using FFTBase::FFTBase;

 protected:
  int Rank() const override { return FFTRank; }
  bool IsForward() const override { return Forward; }
  bool IsReal() const override { return _Real; }

  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
             Tensor* out) override {
    // Create the axes (which are always trailing).
    const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
    auto device = ctx->eigen_device<CPUDevice>();

    const bool is_complex128 =
        in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;

    if (!IsReal()) {
      // Compute the FFT using Eigen.
      constexpr auto direction =
          Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
      if (is_complex128) {
        DCHECK_EQ(in.dtype(), DT_COMPLEX128);
        DCHECK_EQ(out->dtype(), DT_COMPLEX128);
        auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
        auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
        output.device(device) =
            input.template fft<Eigen::BothParts, direction>(axes);
      } else {
        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
        auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
        auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
        output.device(device) =
            input.template fft<Eigen::BothParts, direction>(axes);
      }
    } else {
      if (IsForward()) {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_DOUBLE);
          DCHECK_EQ(out->dtype(), DT_COMPLEX128);
          DoRealForwardFFT<double, complex128>(ctx, fft_shape, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_FLOAT);
          DCHECK_EQ(out->dtype(), DT_COMPLEX64);
          DoRealForwardFFT<float, complex64>(ctx, fft_shape, in, out);
        }
      } else {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_COMPLEX128);
          DCHECK_EQ(out->dtype(), DT_DOUBLE);
          DoRealBackwardFFT<complex128, double>(ctx, fft_shape, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_COMPLEX64);
          DCHECK_EQ(out->dtype(), DT_FLOAT);
          DoRealBackwardFFT<complex64, float>(ctx, fft_shape, in, out);
        }
      }
    }
  }

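  // Forward real-to-complex FFT: slices the input to fft_shape on its
  // inner-most dimensions, computes the full complex FFT into a temporary
  // tensor, and then copies only the non-negative frequency components of the
  // inner-most dimension to the output.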
  template <typename RealT, typename ComplexT>
  void DoRealForwardFFT(OpKernelContext* ctx, uint64* fft_shape,
                        const Tensor& in, Tensor* out) {
    // Create the axes (which are always trailing).
    const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
    auto device = ctx->eigen_device<CPUDevice>();
    auto input = Tensor(in).flat_inner_dims<RealT, FFTRank + 1>();
    const auto input_dims = input.dimensions();

    // Slice input to fft_shape on its inner-most dimensions.
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
    input_slice_sizes[0] = input_dims[0];
    TensorShape temp_shape{input_dims[0]};
    for (int i = 1; i <= FFTRank; ++i) {
      input_slice_sizes[i] = fft_shape[i - 1];
      temp_shape.AddDim(fft_shape[i - 1]);
    }
    OP_REQUIRES(ctx, temp_shape.num_elements() > 0,
                errors::InvalidArgument("Obtained a FFT shape of 0 elements: ",
                                        temp_shape.DebugString()));

    auto output = out->flat_inner_dims<ComplexT, FFTRank + 1>();
    const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;

    // Compute the full FFT using a temporary tensor.
    Tensor temp;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
                                           temp_shape, &temp));
    auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();
    full_fft.device(device) =
        input.slice(zero_start_indices, input_slice_sizes)
            .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);

    // Slice away the negative frequency components.
    output.device(device) =
        full_fft.slice(zero_start_indices, output.dimensions());
  }

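  // Inverse complex-to-real FFT: reconstructs the full complex spectrum from
  // the half-spectrum input using conjugate (Hermitian) symmetry, performs the
  // inverse transform, and keeps only the real part of the result.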
  template <typename ComplexT, typename RealT>
  void DoRealBackwardFFT(OpKernelContext* ctx, uint64* fft_shape,
                         const Tensor& in, Tensor* out) {
    auto device = ctx->eigen_device<CPUDevice>();
    // Reconstruct the full FFT and take the inverse.
    auto input = Tensor(in).flat_inner_dims<ComplexT, FFTRank + 1>();
    auto output = out->flat_inner_dims<RealT, FFTRank + 1>();
    const auto input_dims = input.dimensions();

    // Calculate the shape of the temporary tensor for the full FFT and the
    // region we will slice from input given fft_shape. We slice input to
    // fft_shape on its inner-most dimensions, except the last (which we
    // slice to fft_shape[-1] / 2 + 1).
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
    input_slice_sizes[0] = input_dims[0];
    TensorShape full_fft_shape;
    full_fft_shape.AddDim(input_dims[0]);
    for (auto i = 1; i <= FFTRank; i++) {
      input_slice_sizes[i] =
          i == FFTRank ? fft_shape[i - 1] / 2 + 1 : fft_shape[i - 1];
      full_fft_shape.AddDim(fft_shape[i - 1]);
    }
    OP_REQUIRES(ctx, full_fft_shape.num_elements() > 0,
                errors::InvalidArgument("Obtained a FFT shape of 0 elements: ",
                                        full_fft_shape.DebugString()));

    Tensor temp;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
                                           full_fft_shape, &temp));
    auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();

    // Calculate the starting point and range of the source of
    // negative frequency part.
    auto neg_sizes = input_slice_sizes;
    neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - input_slice_sizes[FFTRank];
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
    neg_target_indices[FFTRank] = input_slice_sizes[FFTRank];

    const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> start_indices;
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
    neg_start_indices[FFTRank] = 1;

    full_fft.slice(start_indices, input_slice_sizes).device(device) =
        input.slice(start_indices, input_slice_sizes);

    // First, conduct IFFTs on outer dimensions. We save computation (and
    // avoid touching uninitialized memory) by slicing full_fft to the
    // subregion we wrote input to.
    if (FFTRank > 1) {
      const auto outer_axes =
          Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
      full_fft.slice(start_indices, input_slice_sizes).device(device) =
          full_fft.slice(start_indices, input_slice_sizes)
              .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
    }

    // Reconstruct the full FFT by appending reversed and conjugated
    // spectrum as the negative frequency part.
    Eigen::array<bool, FFTRank + 1> reverse_last_axis;
    for (auto i = 0; i <= FFTRank; i++) {
      reverse_last_axis[i] = i == FFTRank;
    }

    if (neg_sizes[FFTRank] != 0) {
      full_fft.slice(neg_target_indices, neg_sizes).device(device) =
          full_fft.slice(neg_start_indices, neg_sizes)
              .reverse(reverse_last_axis)
              .conjugate();
    }

    auto inner_axis = Eigen::array<int, 1>{FFTRank};
    output.device(device) =
        full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
  }
};

REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_CPU), FFTCPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_CPU),
                        FFTCPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_CPU),
                        FFTCPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_CPU),
                        FFTCPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_CPU),
                        FFTCPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_CPU),
                        FFTCPU<false, false, 3>);

REGISTER_KERNEL_BUILDER(Name("RFFT").Device(DEVICE_CPU), FFTCPU<true, true, 1>);
REGISTER_KERNEL_BUILDER(Name("IRFFT").Device(DEVICE_CPU),
                        FFTCPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(Name("RFFT2D").Device(DEVICE_CPU),
                        FFTCPU<true, true, 2>);
REGISTER_KERNEL_BUILDER(Name("IRFFT2D").Device(DEVICE_CPU),
                        FFTCPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(Name("RFFT3D").Device(DEVICE_CPU),
                        FFTCPU<true, true, 3>);
REGISTER_KERNEL_BUILDER(Name("IRFFT3D").Device(DEVICE_CPU),
                        FFTCPU<false, true, 3>);

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

namespace {
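// Wraps a raw device pointer (optionally with a size, in elements) in a
// StreamExecutor DeviceMemory handle. Does not take ownership of the memory.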
template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

// A scratch-space allocator for StreamExecutor's cuFFT plans. TensorFlow is
// responsible for releasing the temporary buffers after the kernel finishes.
// TODO(yangzihao): Refactor redundant code in subclasses of ScratchAllocator
// into base class.
class CufftScratchAllocator : public se::ScratchAllocator {
 public:
  ~CufftScratchAllocator() override {}
  CufftScratchAllocator(int64_t memory_limit, OpKernelContext* context)
      : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
  int64 GetMemoryLimitInBytes() override { return memory_limit_; }
  se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
      int64_t byte_size) override {
    Tensor temporary_memory;
    if (byte_size > memory_limit_) {
      return se::port::StatusOr<se::DeviceMemory<uint8>>();
    }
    AllocationAttributes allocation_attr;
    allocation_attr.retry_on_failure = false;
    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory,
        AllocatorAttributes(), allocation_attr));
    if (!allocation_status.ok()) {
      return se::port::StatusOr<se::DeviceMemory<uint8>>();
    }
    // Hold the reference of the allocated tensors until the end of the
    // allocator.
    allocated_tensors_.push_back(temporary_memory);
    total_byte_size_ += byte_size;
    return se::port::StatusOr<se::DeviceMemory<uint8>>(
        AsDeviceMemory(temporary_memory.flat<uint8>().data(),
                       temporary_memory.flat<uint8>().size()));
  }
  int64 TotalByteSize() { return total_byte_size_; }

 private:
  int64 memory_limit_;
  int64 total_byte_size_;
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};

}  // end namespace

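// Returns the cuFFT scratch-space limit in bytes. The limit is read from the
// environment variable named by `envvar_in_mb` (interpreted as megabytes,
// e.g. TF_CUFFT_WORKSPACE_LIMIT_IN_MB=4096 for a 4 GiB limit); if the variable
// is unset or invalid, `default_value_in_bytes` is returned instead.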
int64 GetCufftWorkspaceLimit(const string& envvar_in_mb,
                             int64_t default_value_in_bytes) {
  const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
  if (workspace_limit_in_mb_str != nullptr &&
      strcmp(workspace_limit_in_mb_str, "") != 0) {
    int64_t scratch_limit_in_mb = -1;
    Status status = ReadInt64FromEnvVar(envvar_in_mb, default_value_in_bytes,
                                        &scratch_limit_in_mb);
    if (!status.ok()) {
      LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
                   << workspace_limit_in_mb_str;
    } else {
      return scratch_limit_in_mb * (1 << 20);
    }
  }
  return default_value_in_bytes;
}

class FFTGPUBase : public FFTBase {
 public:
  using FFTBase::FFTBase;

 protected:
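  // Upper bound, in bytes, on the scratch memory cuFFT may allocate per plan.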
  static int64_t CufftScratchSize;

  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
             Tensor* out) override {
    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    const TensorShape& input_shape = in.shape();
    const TensorShape& output_shape = out->shape();

    const int fft_rank = Rank();
    int batch_size = 1;
    for (int i = 0; i < input_shape.dims() - fft_rank; ++i) {
      batch_size *= input_shape.dim_size(i);
    }
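    // Layout parameters for the batched cuFFT plan: *_embed holds the sizes
    // of the transformed (inner-most) dimensions, *_stride is the stride
    // between consecutive elements, and *_distance is the stride between
    // consecutive batches (the product of the transformed dimensions).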
    uint64 input_embed[3];
    const uint64 input_stride = 1;
    uint64 input_distance = 1;
    uint64 output_embed[3];
    const uint64 output_stride = 1;
    uint64 output_distance = 1;

    for (int i = 0; i < fft_rank; ++i) {
      auto dim_offset = input_shape.dims() - fft_rank + i;
      input_embed[i] = input_shape.dim_size(dim_offset);
      input_distance *= input_shape.dim_size(dim_offset);
      output_embed[i] = output_shape.dim_size(dim_offset);
      output_distance *= output_shape.dim_size(dim_offset);
    }

    constexpr bool kInPlaceFft = false;
    const bool is_complex128 =
        in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;

    const auto kFftType =
        IsReal()
            ? (IsForward()
                   ? (is_complex128 ? se::fft::Type::kD2Z : se::fft::Type::kR2C)
                   : (is_complex128 ? se::fft::Type::kZ2D
                                    : se::fft::Type::kC2R))
            : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
                                            : se::fft::Type::kC2CForward)
                           : (is_complex128 ? se::fft::Type::kZ2ZInverse
                                            : se::fft::Type::kC2CInverse));

    CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
    auto plan =
        stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
            stream, fft_rank, fft_shape, input_embed, input_stride,
            input_distance, output_embed, output_stride, output_distance,
            kFftType, kInPlaceFft, batch_size, &scratch_allocator);

    if (IsReal()) {
      if (IsForward()) {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_DOUBLE);
          DCHECK_EQ(out->dtype(), DT_COMPLEX128);
          DoFFTInternal<double, complex128>(ctx, stream, plan.get(), kFftType,
                                            output_distance, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_FLOAT);
          DCHECK_EQ(out->dtype(), DT_COMPLEX64);
          DoFFTInternal<float, complex64>(ctx, stream, plan.get(), kFftType,
                                          output_distance, in, out);
        }
      } else {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_COMPLEX128);
          DCHECK_EQ(out->dtype(), DT_DOUBLE);
          DoFFTInternal<complex128, double>(ctx, stream, plan.get(), kFftType,
                                            output_distance, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_COMPLEX64);
          DCHECK_EQ(out->dtype(), DT_FLOAT);
          DoFFTInternal<complex64, float>(ctx, stream, plan.get(), kFftType,
                                          output_distance, in, out);
        }
      }
    } else {
      if (is_complex128) {
        DCHECK_EQ(in.dtype(), DT_COMPLEX128);
        DCHECK_EQ(out->dtype(), DT_COMPLEX128);
        DoFFTInternal<complex128, complex128>(ctx, stream, plan.get(), kFftType,
                                              output_distance, in, out);
      } else {
        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
        DoFFTInternal<complex64, complex64>(ctx, stream, plan.get(), kFftType,
                                            output_distance, in, out);
      }
    }
  }

 private:
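  // Maps std::complex<T> to T and leaves real types unchanged; used to pick
  // the element type for the post-IFFT scaling below.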
  template <typename T>
  struct RealTypeFromComplexType {
    typedef T RealT;
  };

  template <typename T>
  struct RealTypeFromComplexType<std::complex<T>> {
    typedef T RealT;
  };

  template <typename InT, typename OutT>
  void DoFFTInternal(OpKernelContext* ctx, se::Stream* stream,
                     se::fft::Plan* plan, const se::fft::Type fft_type,
                     const uint64 output_distance, const Tensor& in,
                     Tensor* out) {
    const TensorShape& input_shape = in.shape();
    const TensorShape& output_shape = out->shape();
    auto src =
        AsDeviceMemory<InT>(in.flat<InT>().data(), input_shape.num_elements());
    auto dst = AsDeviceMemory<OutT>(out->flat<OutT>().data(),
                                    output_shape.num_elements());
    OP_REQUIRES(
        ctx, stream->ThenFft(plan, src, &dst).ok(),
        errors::Internal("fft failed : type=", static_cast<int>(fft_type),
                         " in.shape=", input_shape.DebugString()));
    if (!IsForward()) {
      typedef typename RealTypeFromComplexType<OutT>::RealT RealT;
      RealT alpha = 1.0 / output_distance;
      OP_REQUIRES(
          ctx,
          stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
              .ok(),
          errors::Internal("BlasScal failed : in.shape=",
                           input_shape.DebugString()));
    }
  }
};

int64_t FFTGPUBase::CufftScratchSize = GetCufftWorkspaceLimit(
    // default value is in bytes despite the name of the environment variable
    "TF_CUFFT_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
);

template <bool Forward, bool _Real, int FFTRank>
class FFTGPU : public FFTGPUBase {
 public:
  static_assert(FFTRank >= 1 && FFTRank <= 3,
                "Only 1D, 2D and 3D FFTs supported.");
  explicit FFTGPU(OpKernelConstruction* ctx) : FFTGPUBase(ctx) {}

 protected:
  int Rank() const override { return FFTRank; }
  bool IsForward() const override { return Forward; }
  bool IsReal() const override { return _Real; }
};

// Register GPU kernels with priority 1 so that if a custom FFT CPU kernel is
// registered with priority 1 (to override the default Eigen CPU kernel), the
// CPU kernel does not outrank the GPU kernel.
REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 3>);

REGISTER_KERNEL_BUILDER(
    Name("RFFT").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<true, true, 1>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(
    Name("RFFT2D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<true, true, 2>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT2D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(
    Name("RFFT3D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<true, true, 3>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT3D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<false, true, 3>);

// Deprecated kernels.
REGISTER_KERNEL_BUILDER(Name("BatchFFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("BatchFFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("BatchFFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 3>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow