/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/platform/errors.h"
#define EIGEN_USE_THREADS

// See docs in ../ops/fft_ops.cc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/work_sharder.h"

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

class FFTBase : public OpKernel {
 public:
  explicit FFTBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in = ctx->input(0);
    const TensorShape& input_shape = in.shape();
    const int fft_rank = Rank();
    OP_REQUIRES(
        ctx, input_shape.dims() >= fft_rank,
        errors::InvalidArgument("Input must have rank of at least ", fft_rank,
                                " but got: ", input_shape.DebugString()));

    Tensor* out;
    TensorShape output_shape = input_shape;
    uint64 fft_shape[3] = {0, 0, 0};

    // In R2C or C2R mode, we use a second input to specify the FFT length
    // instead of inferring it from the input shape.
    if (IsReal()) {
      const Tensor& fft_length = ctx->input(1);
      OP_REQUIRES(ctx,
                  fft_length.shape().dims() == 1 &&
                      fft_length.shape().dim_size(0) == fft_rank,
                  errors::InvalidArgument("fft_length must have shape [",
                                          fft_rank, "]"));

      auto fft_length_as_vec = fft_length.vec<int32>();
      for (int i = 0; i < fft_rank; ++i) {
        OP_REQUIRES(ctx, fft_length_as_vec(i) >= 0,
                    errors::InvalidArgument(
                        "fft_length[", i,
                        "] must be >= 0, but got: ", fft_length_as_vec(i)));
        fft_shape[i] = fft_length_as_vec(i);
        // Each input dimension must have length of at least fft_shape[i]. For
        // IRFFTs, the inner-most input dimension must have length of at least
        // fft_shape[i] / 2 + 1.
        bool inner_most = (i == fft_rank - 1);
        uint64 min_input_dim_length =
            !IsForward() && inner_most ? fft_shape[i] / 2 + 1 : fft_shape[i];
        auto input_index = input_shape.dims() - fft_rank + i;
        OP_REQUIRES(
            ctx,
            // We pass through empty tensors, so special-case them here.
            input_shape.dim_size(input_index) == 0 ||
                input_shape.dim_size(input_index) >= min_input_dim_length,
            errors::InvalidArgument(
                "Input dimension ", input_index,
                " must have length of at least ", min_input_dim_length,
                " but got: ", input_shape.dim_size(input_index)));
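        // For forward real FFTs, the output keeps only the non-negative
        // frequency components, so the inner-most output dimension is
        // fft_length / 2 + 1.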
        uint64 dim = IsForward() && inner_most && fft_shape[i] != 0
                         ? fft_shape[i] / 2 + 1
                         : fft_shape[i];
        output_shape.set_dim(output_shape.dims() - fft_rank + i, dim);
      }
    } else {
      for (int i = 0; i < fft_rank; ++i) {
        fft_shape[i] =
            output_shape.dim_size(output_shape.dims() - fft_rank + i);
      }
    }

    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));

    if (IsReal()) {
      if (IsForward()) {
        OP_REQUIRES(
            ctx,
            (in.dtype() == DT_FLOAT && out->dtype() == DT_COMPLEX64) ||
                (in.dtype() == DT_DOUBLE && out->dtype() == DT_COMPLEX128),
            errors::InvalidArgument("Wrong types for forward real FFT: in=",
                                    in.dtype(), " out=", out->dtype()));
      } else {
        OP_REQUIRES(
            ctx,
            (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_FLOAT) ||
                (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_DOUBLE),
            errors::InvalidArgument("Wrong types for backward real FFT: in=",
                                    in.dtype(), " out=", out->dtype()));
      }
    } else {
      OP_REQUIRES(
          ctx,
          (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_COMPLEX64) ||
              (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_COMPLEX128),
          errors::InvalidArgument("Wrong types for FFT: in=", in.dtype(),
                                  " out=", out->dtype()));
    }

    if (input_shape.num_elements() == 0) {
      DCHECK_EQ(0, output_shape.num_elements());
      return;
    }

    DoFFT(ctx, in, fft_shape, out);
  }

 protected:
  virtual int Rank() const = 0;
  virtual bool IsForward() const = 0;
  virtual bool IsReal() const = 0;

  // The function that actually computes the FFT.
  virtual void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
                     Tensor* out) = 0;
};

typedef Eigen::ThreadPoolDevice CPUDevice;

template <bool Forward, bool _Real, int FFTRank>
class FFTCPU : public FFTBase {
 public:
  using FFTBase::FFTBase;

 protected:
  int Rank() const override { return FFTRank; }
  bool IsForward() const override { return Forward; }
  bool IsReal() const override { return _Real; }

  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
             Tensor* out) override {
    // Create the axes (which are always trailing).
    const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
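    // (flat_inner_dims below collapses all leading dimensions into
    // dimension 0, the batch dimension, so the transform axes are
    // 1..FFTRank.)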
    auto device = ctx->eigen_device<CPUDevice>();

    const bool is_complex128 =
        in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;

    if (!IsReal()) {
      // Compute the FFT using Eigen.
      constexpr auto direction =
          Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
      if (is_complex128) {
        DCHECK_EQ(in.dtype(), DT_COMPLEX128);
        DCHECK_EQ(out->dtype(), DT_COMPLEX128);
        auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
        auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
        output.device(device) =
            input.template fft<Eigen::BothParts, direction>(axes);
      } else {
        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
        auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
        auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
        output.device(device) =
            input.template fft<Eigen::BothParts, direction>(axes);
      }
    } else {
      if (IsForward()) {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_DOUBLE);
          DCHECK_EQ(out->dtype(), DT_COMPLEX128);
          DoRealForwardFFT<double, complex128>(ctx, fft_shape, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_FLOAT);
          DCHECK_EQ(out->dtype(), DT_COMPLEX64);
          DoRealForwardFFT<float, complex64>(ctx, fft_shape, in, out);
        }
      } else {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_COMPLEX128);
          DCHECK_EQ(out->dtype(), DT_DOUBLE);
          DoRealBackwardFFT<complex128, double>(ctx, fft_shape, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_COMPLEX64);
          DCHECK_EQ(out->dtype(), DT_FLOAT);
          DoRealBackwardFFT<complex64, float>(ctx, fft_shape, in, out);
        }
      }
    }
  }

  template <typename RealT, typename ComplexT>
  void DoRealForwardFFT(OpKernelContext* ctx, uint64* fft_shape,
                        const Tensor& in, Tensor* out) {
    // Create the axes (which are always trailing).
    const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
    auto device = ctx->eigen_device<CPUDevice>();
    auto input = Tensor(in).flat_inner_dims<RealT, FFTRank + 1>();
    const auto input_dims = input.dimensions();

    // Slice input to fft_shape on its inner-most dimensions.
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
    input_slice_sizes[0] = input_dims[0];
    TensorShape temp_shape{input_dims[0]};
    for (int i = 1; i <= FFTRank; ++i) {
      input_slice_sizes[i] = fft_shape[i - 1];
      temp_shape.AddDim(fft_shape[i - 1]);
    }
    OP_REQUIRES(ctx, temp_shape.num_elements() > 0,
                errors::InvalidArgument("Obtained an FFT shape of 0 elements: ",
                                        temp_shape.DebugString()));

    auto output = out->flat_inner_dims<ComplexT, FFTRank + 1>();
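    // A default-constructed DSizes is zero-initialized, so the slice below
    // starts at the origin.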
    const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;

    // Compute the full FFT using a temporary tensor.
    Tensor temp;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
                                           temp_shape, &temp));
    auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();
    full_fft.device(device) =
        input.slice(zero_start_indices, input_slice_sizes)
            .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);

    // Slice away the negative frequency components.
    output.device(device) =
        full_fft.slice(zero_start_indices, output.dimensions());
  }

  template <typename ComplexT, typename RealT>
  void DoRealBackwardFFT(OpKernelContext* ctx, uint64* fft_shape,
                         const Tensor& in, Tensor* out) {
    auto device = ctx->eigen_device<CPUDevice>();
    // Reconstruct the full FFT and take the inverse.
    auto input = Tensor(in).flat_inner_dims<ComplexT, FFTRank + 1>();
    auto output = out->flat_inner_dims<RealT, FFTRank + 1>();
    const auto input_dims = input.dimensions();

    // Calculate the shape of the temporary tensor for the full FFT and the
    // region we will slice from input given fft_shape. We slice input to
    // fft_shape on its inner-most dimensions, except the last (which we
    // slice to fft_shape[-1] / 2 + 1).
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
    input_slice_sizes[0] = input_dims[0];
    TensorShape full_fft_shape;
    full_fft_shape.AddDim(input_dims[0]);
    for (auto i = 1; i <= FFTRank; i++) {
      input_slice_sizes[i] =
          i == FFTRank ? fft_shape[i - 1] / 2 + 1 : fft_shape[i - 1];
      full_fft_shape.AddDim(fft_shape[i - 1]);
    }
    OP_REQUIRES(ctx, full_fft_shape.num_elements() > 0,
                errors::InvalidArgument("Obtained an FFT shape of 0 elements: ",
                                        full_fft_shape.DebugString()));

    Tensor temp;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
                                           full_fft_shape, &temp));
    auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();

    // Calculate the starting point and range of the source of
    // negative frequency part.
    auto neg_sizes = input_slice_sizes;
    neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - input_slice_sizes[FFTRank];
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
    neg_target_indices[FFTRank] = input_slice_sizes[FFTRank];

    const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> start_indices;
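    // The mirrored (negative-frequency) slice starts at bin 1: the DC bin
    // (and, for even lengths, the Nyquist bin) has no distinct mirror image.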
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
    neg_start_indices[FFTRank] = 1;

    full_fft.slice(start_indices, input_slice_sizes).device(device) =
        input.slice(start_indices, input_slice_sizes);

    // First, conduct IFFTs on outer dimensions. We save computation (and
    // avoid touching uninitialized memory) by slicing full_fft to the
    // subregion we wrote input to.
    if (FFTRank > 1) {
      const auto outer_axes =
          Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
      full_fft.slice(start_indices, input_slice_sizes).device(device) =
          full_fft.slice(start_indices, input_slice_sizes)
              .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
    }

    // Reconstruct the full FFT by appending reversed and conjugated
    // spectrum as the negative frequency part.
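    // Build a reversal mask that flips only the inner-most axis.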
    Eigen::array<bool, FFTRank + 1> reverse_last_axis;
    for (auto i = 0; i <= FFTRank; i++) {
      reverse_last_axis[i] = i == FFTRank;
    }

    if (neg_sizes[FFTRank] != 0) {
      full_fft.slice(neg_target_indices, neg_sizes).device(device) =
          full_fft.slice(neg_start_indices, neg_sizes)
              .reverse(reverse_last_axis)
              .conjugate();
    }

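    // Finally, take the inverse FFT along the inner-most axis, keeping only
    // the real part of the result.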
    auto inner_axis = Eigen::array<int, 1>{FFTRank};
    output.device(device) =
        full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
  }
};

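// CPU kernel registrations for the complex (FFT/IFFT) and real (RFFT/IRFFT)
// transform ops, for ranks 1 through 3.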
REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_CPU), FFTCPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_CPU),
                        FFTCPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_CPU),
                        FFTCPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_CPU),
                        FFTCPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_CPU),
                        FFTCPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_CPU),
                        FFTCPU<false, false, 3>);

REGISTER_KERNEL_BUILDER(Name("RFFT").Device(DEVICE_CPU), FFTCPU<true, true, 1>);
REGISTER_KERNEL_BUILDER(Name("IRFFT").Device(DEVICE_CPU),
                        FFTCPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(Name("RFFT2D").Device(DEVICE_CPU),
                        FFTCPU<true, true, 2>);
REGISTER_KERNEL_BUILDER(Name("IRFFT2D").Device(DEVICE_CPU),
                        FFTCPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(Name("RFFT3D").Device(DEVICE_CPU),
                        FFTCPU<true, true, 3>);
REGISTER_KERNEL_BUILDER(Name("IRFFT3D").Device(DEVICE_CPU),
                        FFTCPU<false, true, 3>);

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

namespace {
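// Wraps a raw device pointer in a typed se::DeviceMemory view, optionally
// with a size given in elements.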
template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

// A class that provides a scratch-space allocator for the StreamExecutor
// cuFFT interface. TensorFlow is responsible for releasing the temporary
// buffers after the kernel finishes.
// TODO(yangzihao): Refactor redundant code in subclasses of ScratchAllocator
// into base class.
class CufftScratchAllocator : public se::ScratchAllocator {
 public:
  ~CufftScratchAllocator() override {}
  CufftScratchAllocator(int64_t memory_limit, OpKernelContext* context)
      : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
  int64_t GetMemoryLimitInBytes() override { return memory_limit_; }
  se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
      int64_t byte_size) override {
    Tensor temporary_memory;
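    // A default-constructed StatusOr carries an error status, which signals
    // to StreamExecutor that the requested scratch memory is unavailable.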
    if (byte_size > memory_limit_) {
      return se::port::StatusOr<se::DeviceMemory<uint8>>();
    }
    AllocationAttributes allocation_attr;
    allocation_attr.retry_on_failure = false;
    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory,
        AllocatorAttributes(), allocation_attr));
    if (!allocation_status.ok()) {
      return se::port::StatusOr<se::DeviceMemory<uint8>>();
    }
    // Hold references to the allocated tensors until the allocator is
    // destroyed.
    allocated_tensors_.push_back(temporary_memory);
    total_byte_size_ += byte_size;
    return se::port::StatusOr<se::DeviceMemory<uint8>>(
        AsDeviceMemory(temporary_memory.flat<uint8>().data(),
                       temporary_memory.flat<uint8>().size()));
  }
  int64_t TotalByteSize() { return total_byte_size_; }

 private:
  int64_t memory_limit_;
  int64_t total_byte_size_;
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};

}  // end namespace
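
// Returns the cuFFT scratch-space limit in bytes. The environment variable's
// value is interpreted in megabytes; if it is unset or cannot be parsed, the
// default (already in bytes) is used.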
int64_t GetCufftWorkspaceLimit(const string& envvar_in_mb,
                               int64_t default_value_in_bytes) {
  const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
  if (workspace_limit_in_mb_str != nullptr &&
      strcmp(workspace_limit_in_mb_str, "") != 0) {
    int64_t scratch_limit_in_mb = -1;
    Status status = ReadInt64FromEnvVar(envvar_in_mb, default_value_in_bytes,
                                        &scratch_limit_in_mb);
    if (!status.ok()) {
      LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
                   << workspace_limit_in_mb_str;
    } else {
      return scratch_limit_in_mb * (1 << 20);
    }
  }
  return default_value_in_bytes;
}

class FFTGPUBase : public FFTBase {
 public:
  using FFTBase::FFTBase;

 protected:
  static int64_t CufftScratchSize;
  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
             Tensor* out) override {
    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    const TensorShape& input_shape = in.shape();
    const TensorShape& output_shape = out->shape();

    const int fft_rank = Rank();
    int batch_size = 1;
    for (int i = 0; i < input_shape.dims() - fft_rank; ++i) {
      batch_size *= input_shape.dim_size(i);
    }
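    // cuFFT advanced data-layout parameters: *_embed describes the storage
    // dimensions of each transform, *_stride the spacing between consecutive
    // elements, and *_distance the spacing between consecutive batch members.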
    uint64 input_embed[3];
    const uint64 input_stride = 1;
    uint64 input_distance = 1;
    uint64 output_embed[3];
    const uint64 output_stride = 1;
    uint64 output_distance = 1;

    for (int i = 0; i < fft_rank; ++i) {
      auto dim_offset = input_shape.dims() - fft_rank + i;
      input_embed[i] = input_shape.dim_size(dim_offset);
      input_distance *= input_shape.dim_size(dim_offset);
      output_embed[i] = output_shape.dim_size(dim_offset);
      output_distance *= output_shape.dim_size(dim_offset);
    }

    constexpr bool kInPlaceFft = false;
    const bool is_complex128 =
        in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;

    const auto kFftType =
        IsReal()
            ? (IsForward()
                   ? (is_complex128 ? se::fft::Type::kD2Z : se::fft::Type::kR2C)
                   : (is_complex128 ? se::fft::Type::kZ2D
                                    : se::fft::Type::kC2R))
            : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
                                            : se::fft::Type::kC2CForward)
                           : (is_complex128 ? se::fft::Type::kZ2ZInverse
                                            : se::fft::Type::kC2CInverse));
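    // (R2C/C2R are single-precision real transforms, D2Z/Z2D their
    // double-precision counterparts; C2C/Z2Z are complex-to-complex.)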

    CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
    auto plan =
        stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
            stream, fft_rank, fft_shape, input_embed, input_stride,
            input_distance, output_embed, output_stride, output_distance,
            kFftType, kInPlaceFft, batch_size, &scratch_allocator);
    OP_REQUIRES(
        ctx, plan != nullptr,
        errors::Internal(
            "Failed to create cuFFT batched plan with scratch allocator"));

    if (IsReal()) {
      if (IsForward()) {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_DOUBLE);
          DCHECK_EQ(out->dtype(), DT_COMPLEX128);
          DoFFTInternal<double, complex128>(ctx, stream, plan.get(), kFftType,
                                            output_distance, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_FLOAT);
          DCHECK_EQ(out->dtype(), DT_COMPLEX64);
          DoFFTInternal<float, complex64>(ctx, stream, plan.get(), kFftType,
                                          output_distance, in, out);
        }
      } else {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_COMPLEX128);
          DCHECK_EQ(out->dtype(), DT_DOUBLE);
          DoFFTInternal<complex128, double>(ctx, stream, plan.get(), kFftType,
                                            output_distance, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_COMPLEX64);
          DCHECK_EQ(out->dtype(), DT_FLOAT);
          DoFFTInternal<complex64, float>(ctx, stream, plan.get(), kFftType,
                                          output_distance, in, out);
        }
      }
    } else {
      if (is_complex128) {
        DCHECK_EQ(in.dtype(), DT_COMPLEX128);
        DCHECK_EQ(out->dtype(), DT_COMPLEX128);
        DoFFTInternal<complex128, complex128>(ctx, stream, plan.get(), kFftType,
                                              output_distance, in, out);
      } else {
        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
        DoFFTInternal<complex64, complex64>(ctx, stream, plan.get(), kFftType,
                                            output_distance, in, out);
      }
    }
  }

 private:
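  // Maps std::complex<T> to T; leaves real types unchanged.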
  template <typename T>
  struct RealTypeFromComplexType {
    typedef T RealT;
  };

  template <typename T>
  struct RealTypeFromComplexType<std::complex<T>> {
    typedef T RealT;
  };

  template <typename InT, typename OutT>
  void DoFFTInternal(OpKernelContext* ctx, se::Stream* stream,
                     se::fft::Plan* plan, const se::fft::Type fft_type,
                     const uint64 output_distance, const Tensor& in,
                     Tensor* out) {
    const TensorShape& input_shape = in.shape();
    const TensorShape& output_shape = out->shape();
    auto src =
        AsDeviceMemory<InT>(in.flat<InT>().data(), input_shape.num_elements());
    auto dst = AsDeviceMemory<OutT>(out->flat<OutT>().data(),
                                    output_shape.num_elements());
    OP_REQUIRES(
        ctx, stream->ThenFft(plan, src, &dst).ok(),
        errors::Internal("fft failed : type=", static_cast<int>(fft_type),
                         " in.shape=", input_shape.DebugString()));
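    // cuFFT transforms are unnormalized, so scale inverse results by
    // 1 / output_distance (the number of elements per transform) to match
    // the usual inverse-FFT convention.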
    if (!IsForward()) {
      typedef typename RealTypeFromComplexType<OutT>::RealT RealT;
      RealT alpha = 1.0 / output_distance;
      OP_REQUIRES(
          ctx,
          stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
              .ok(),
          errors::Internal("BlasScal failed : in.shape=",
                           input_shape.DebugString()));
    }
  }
};

int64_t FFTGPUBase::CufftScratchSize = GetCufftWorkspaceLimit(
    // The default value is in bytes despite the name of the environment
    // variable.
    "TF_CUFFT_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
);

template <bool Forward, bool _Real, int FFTRank>
class FFTGPU : public FFTGPUBase {
 public:
  static_assert(FFTRank >= 1 && FFTRank <= 3,
                "Only 1D, 2D and 3D FFTs supported.");
  explicit FFTGPU(OpKernelConstruction* ctx) : FFTGPUBase(ctx) {}

 protected:
  int Rank() const override { return FFTRank; }
  bool IsForward() const override { return Forward; }
  bool IsReal() const override { return _Real; }
};

// Register GPU kernels with priority 1 so that if a custom FFT CPU kernel is
// registered with priority 1 (to override the default Eigen CPU kernel), the
// CPU kernel does not outrank the GPU kernel.
REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 3>);

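// fft_length is read on the host when the plan is constructed, so it must
// live in host memory.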
REGISTER_KERNEL_BUILDER(
    Name("RFFT").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<true, true, 1>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(
    Name("RFFT2D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<true, true, 2>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT2D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(
    Name("RFFT3D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<true, true, 3>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT3D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<false, true, 3>);

// Deprecated kernels.
REGISTER_KERNEL_BUILDER(Name("BatchFFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("BatchFFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("BatchFFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 3>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow