/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/platform/errors.h"
#define EIGEN_USE_THREADS

// See docs in ../ops/fft_ops.cc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/work_sharder.h"

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

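// Base kernel for all FFT ops. Validates input shapes and dtypes, computes
// the output shape, allocates the output tensor, and delegates the actual
// transform to the device-specific DoFFT() implementation.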
class FFTBase : public OpKernel {
 public:
  explicit FFTBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in = ctx->input(0);
    const TensorShape& input_shape = in.shape();
    const int fft_rank = Rank();
    OP_REQUIRES(
        ctx, input_shape.dims() >= fft_rank,
        errors::InvalidArgument("Input must have rank of at least ", fft_rank,
                                " but got: ", input_shape.DebugString()));

    Tensor* out;
    TensorShape output_shape = input_shape;
    uint64 fft_shape[3] = {0, 0, 0};

    // In R2C or C2R mode, we use a second input to specify the FFT length
    // instead of inferring it from the input shape.
    if (IsReal()) {
      const Tensor& fft_length = ctx->input(1);
      OP_REQUIRES(ctx,
                  fft_length.shape().dims() == 1 &&
                      fft_length.shape().dim_size(0) == fft_rank,
                  errors::InvalidArgument("fft_length must have shape [",
                                          fft_rank, "]"));

      auto fft_length_as_vec = fft_length.vec<int32>();
      for (int i = 0; i < fft_rank; ++i) {
        fft_shape[i] = fft_length_as_vec(i);
        // Each input dimension must have length of at least fft_shape[i]. For
        // IRFFTs, the inner-most input dimension must have length of at least
        // fft_shape[i] / 2 + 1.
        bool inner_most = (i == fft_rank - 1);
        uint64 min_input_dim_length =
            !IsForward() && inner_most ? fft_shape[i] / 2 + 1 : fft_shape[i];
        auto input_index = input_shape.dims() - fft_rank + i;
        OP_REQUIRES(
            ctx,
            // We pass through empty tensors, so special case them here.
            input_shape.dim_size(input_index) == 0 ||
                input_shape.dim_size(input_index) >= min_input_dim_length,
            errors::InvalidArgument(
                "Input dimension ", input_index,
                " must have length of at least ", min_input_dim_length,
                " but got: ", input_shape.dim_size(input_index)));
        uint64 dim = IsForward() && inner_most && fft_shape[i] != 0
                         ? fft_shape[i] / 2 + 1
                         : fft_shape[i];
        output_shape.set_dim(output_shape.dims() - fft_rank + i, dim);
      }
    } else {
      for (int i = 0; i < fft_rank; ++i) {
        fft_shape[i] =
            output_shape.dim_size(output_shape.dims() - fft_rank + i);
      }
    }

    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));

    if (IsReal()) {
      if (IsForward()) {
        OP_REQUIRES(
            ctx,
            (in.dtype() == DT_FLOAT && out->dtype() == DT_COMPLEX64) ||
                (in.dtype() == DT_DOUBLE && out->dtype() == DT_COMPLEX128),
            errors::InvalidArgument("Wrong types for forward real FFT: in=",
                                    in.dtype(), " out=", out->dtype()));
      } else {
        OP_REQUIRES(
            ctx,
            (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_FLOAT) ||
                (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_DOUBLE),
            errors::InvalidArgument("Wrong types for backward real FFT: in=",
                                    in.dtype(), " out=", out->dtype()));
      }
    } else {
      OP_REQUIRES(
          ctx,
          (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_COMPLEX64) ||
              (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_COMPLEX128),
          errors::InvalidArgument("Wrong types for FFT: in=", in.dtype(),
                                  " out=", out->dtype()));
    }

    if (input_shape.num_elements() == 0) {
      DCHECK_EQ(0, output_shape.num_elements());
      return;
    }

    DoFFT(ctx, in, fft_shape, out);
  }

 protected:
  virtual int Rank() const = 0;
  virtual bool IsForward() const = 0;
  virtual bool IsReal() const = 0;

  // The function that actually computes the FFT.
  virtual void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
                     Tensor* out) = 0;
};

typedef Eigen::ThreadPoolDevice CPUDevice;

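// Eigen-based CPU implementation. `Forward` selects forward vs. inverse
// transforms, `_Real` selects real-input (RFFT) / real-output (IRFFT)
// transforms, and `FFTRank` is the transform dimensionality (1, 2 or 3).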
template <bool Forward, bool _Real, int FFTRank>
class FFTCPU : public FFTBase {
 public:
  using FFTBase::FFTBase;

 protected:
  int Rank() const override { return FFTRank; }
  bool IsForward() const override { return Forward; }
  bool IsReal() const override { return _Real; }

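  // Runs the transform with Eigen on the CPU device. Complex transforms are
  // evaluated directly; real transforms dispatch to the helpers below.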
  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
             Tensor* out) override {
    // Create the axes (which are always trailing).
    const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
    auto device = ctx->eigen_device<CPUDevice>();

    const bool is_complex128 =
        in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;

    if (!IsReal()) {
      // Compute the FFT using Eigen.
      constexpr auto direction =
          Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
      if (is_complex128) {
        DCHECK_EQ(in.dtype(), DT_COMPLEX128);
        DCHECK_EQ(out->dtype(), DT_COMPLEX128);
        auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
        auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
        output.device(device) =
            input.template fft<Eigen::BothParts, direction>(axes);
      } else {
        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
        auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
        auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
        output.device(device) =
            input.template fft<Eigen::BothParts, direction>(axes);
      }
    } else {
      if (IsForward()) {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_DOUBLE);
          DCHECK_EQ(out->dtype(), DT_COMPLEX128);
          DoRealForwardFFT<double, complex128>(ctx, fft_shape, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_FLOAT);
          DCHECK_EQ(out->dtype(), DT_COMPLEX64);
          DoRealForwardFFT<float, complex64>(ctx, fft_shape, in, out);
        }
      } else {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_COMPLEX128);
          DCHECK_EQ(out->dtype(), DT_DOUBLE);
          DoRealBackwardFFT<complex128, double>(ctx, fft_shape, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_COMPLEX64);
          DCHECK_EQ(out->dtype(), DT_FLOAT);
          DoRealBackwardFFT<complex64, float>(ctx, fft_shape, in, out);
        }
      }
    }
  }

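  // Forward real-to-complex FFT: slices the input to fft_shape, computes a
  // full complex FFT into a temporary tensor, then keeps only the
  // non-negative frequency half of the inner-most axis.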
  template <typename RealT, typename ComplexT>
  void DoRealForwardFFT(OpKernelContext* ctx, uint64* fft_shape,
                        const Tensor& in, Tensor* out) {
    // Create the axes (which are always trailing).
    const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
    auto device = ctx->eigen_device<CPUDevice>();
    auto input = Tensor(in).flat_inner_dims<RealT, FFTRank + 1>();
    const auto input_dims = input.dimensions();

    // Slice input to fft_shape on its inner-most dimensions.
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
    input_slice_sizes[0] = input_dims[0];
    TensorShape temp_shape{input_dims[0]};
    for (int i = 1; i <= FFTRank; ++i) {
      input_slice_sizes[i] = fft_shape[i - 1];
      temp_shape.AddDim(fft_shape[i - 1]);
    }
    OP_REQUIRES(ctx, temp_shape.num_elements() > 0,
                errors::InvalidArgument("Obtained a FFT shape of 0 elements: ",
                                        temp_shape.DebugString()));

    auto output = out->flat_inner_dims<ComplexT, FFTRank + 1>();
    const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;

    // Compute the full FFT using a temporary tensor.
    Tensor temp;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
                                           temp_shape, &temp));
    auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();
    full_fft.device(device) =
        input.slice(zero_start_indices, input_slice_sizes)
            .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);

    // Slice away the negative frequency components.
    output.device(device) =
        full_fft.slice(zero_start_indices, output.dimensions());
  }

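  // Inverse complex-to-real FFT: reconstructs the full spectrum from the
  // half-spectrum input using conjugate symmetry, runs the inverse transform,
  // and keeps the real part of the result.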
  template <typename ComplexT, typename RealT>
  void DoRealBackwardFFT(OpKernelContext* ctx, uint64* fft_shape,
                         const Tensor& in, Tensor* out) {
    auto device = ctx->eigen_device<CPUDevice>();
    // Reconstruct the full FFT and take the inverse.
    auto input = Tensor(in).flat_inner_dims<ComplexT, FFTRank + 1>();
    auto output = out->flat_inner_dims<RealT, FFTRank + 1>();
    const auto input_dims = input.dimensions();

    // Calculate the shape of the temporary tensor for the full FFT and the
    // region we will slice from input given fft_shape. We slice input to
    // fft_shape on its inner-most dimensions, except the last (which we
    // slice to fft_shape[-1] / 2 + 1).
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
    input_slice_sizes[0] = input_dims[0];
    TensorShape full_fft_shape;
    full_fft_shape.AddDim(input_dims[0]);
    for (auto i = 1; i <= FFTRank; i++) {
      input_slice_sizes[i] =
          i == FFTRank ? fft_shape[i - 1] / 2 + 1 : fft_shape[i - 1];
      full_fft_shape.AddDim(fft_shape[i - 1]);
    }
    OP_REQUIRES(ctx, full_fft_shape.num_elements() > 0,
                errors::InvalidArgument("Obtained a FFT shape of 0 elements: ",
                                        full_fft_shape.DebugString()));

    Tensor temp;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
                                           full_fft_shape, &temp));
    auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();

    // Calculate the starting point and range of the source of the
    // negative frequency part.
    auto neg_sizes = input_slice_sizes;
    neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - input_slice_sizes[FFTRank];
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
    neg_target_indices[FFTRank] = input_slice_sizes[FFTRank];

    const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> start_indices;
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
    neg_start_indices[FFTRank] = 1;

    full_fft.slice(start_indices, input_slice_sizes).device(device) =
        input.slice(start_indices, input_slice_sizes);

    // First, conduct IFFTs on outer dimensions. We save computation (and
    // avoid touching uninitialized memory) by slicing full_fft to the
    // subregion we wrote input to.
    if (FFTRank > 1) {
      const auto outer_axes =
          Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
      full_fft.slice(start_indices, input_slice_sizes).device(device) =
          full_fft.slice(start_indices, input_slice_sizes)
              .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
    }

    // Reconstruct the full FFT by appending reversed and conjugated
    // spectrum as the negative frequency part.
    Eigen::array<bool, FFTRank + 1> reverse_last_axis;
    for (auto i = 0; i <= FFTRank; i++) {
      reverse_last_axis[i] = i == FFTRank;
    }

    if (neg_sizes[FFTRank] != 0) {
      full_fft.slice(neg_target_indices, neg_sizes).device(device) =
          full_fft.slice(neg_start_indices, neg_sizes)
              .reverse(reverse_last_axis)
              .conjugate();
    }

    auto inner_axis = Eigen::array<int, 1>{FFTRank};
    output.device(device) =
        full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
  }
};

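// CPU kernel registrations: complex (FFT/IFFT) and real (RFFT/IRFFT)
// transforms of rank 1 to 3.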
REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_CPU), FFTCPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_CPU),
                        FFTCPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_CPU),
                        FFTCPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_CPU),
                        FFTCPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_CPU),
                        FFTCPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_CPU),
                        FFTCPU<false, false, 3>);

REGISTER_KERNEL_BUILDER(Name("RFFT").Device(DEVICE_CPU), FFTCPU<true, true, 1>);
REGISTER_KERNEL_BUILDER(Name("IRFFT").Device(DEVICE_CPU),
                        FFTCPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(Name("RFFT2D").Device(DEVICE_CPU),
                        FFTCPU<true, true, 2>);
REGISTER_KERNEL_BUILDER(Name("IRFFT2D").Device(DEVICE_CPU),
                        FFTCPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(Name("RFFT3D").Device(DEVICE_CPU),
                        FFTCPU<true, true, 3>);
REGISTER_KERNEL_BUILDER(Name("IRFFT3D").Device(DEVICE_CPU),
                        FFTCPU<false, true, 3>);

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

namespace {
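// Wraps a raw device pointer (optionally with a size in elements) in a typed
// se::DeviceMemory handle without taking ownership.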
template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

// A class that provides a scratch-space allocator for the StreamExecutor cuFFT
// callback. TensorFlow is responsible for releasing the temporary buffers
// after the kernel finishes.
// TODO(yangzihao): Refactor redundant code in subclasses of ScratchAllocator
// into base class.
class CufftScratchAllocator : public se::ScratchAllocator {
 public:
  ~CufftScratchAllocator() override {}
  CufftScratchAllocator(int64_t memory_limit, OpKernelContext* context)
      : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
  int64 GetMemoryLimitInBytes() override { return memory_limit_; }
  se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
      int64_t byte_size) override {
    Tensor temporary_memory;
    if (byte_size > memory_limit_) {
      return se::port::StatusOr<se::DeviceMemory<uint8>>();
    }
    AllocationAttributes allocation_attr;
    allocation_attr.retry_on_failure = false;
    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory,
        AllocatorAttributes(), allocation_attr));
    if (!allocation_status.ok()) {
      return se::port::StatusOr<se::DeviceMemory<uint8>>();
    }
    // Hold references to the allocated tensors until the allocator is
    // destroyed, keeping the scratch memory alive.
    allocated_tensors_.push_back(temporary_memory);
    total_byte_size_ += byte_size;
    return se::port::StatusOr<se::DeviceMemory<uint8>>(
        AsDeviceMemory(temporary_memory.flat<uint8>().data(),
                       temporary_memory.flat<uint8>().size()));
  }
  int64 TotalByteSize() { return total_byte_size_; }

 private:
  int64 memory_limit_;
  int64 total_byte_size_;
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};

}  // end namespace

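// Returns the cuFFT scratch-space limit in bytes. The environment variable
// named by `envvar_in_mb` is interpreted in megabytes when set and valid;
// otherwise `default_value_in_bytes` is returned.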
int64 GetCufftWorkspaceLimit(const string& envvar_in_mb,
                             int64_t default_value_in_bytes) {
  const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
  if (workspace_limit_in_mb_str != nullptr &&
      strcmp(workspace_limit_in_mb_str, "") != 0) {
    int64_t scratch_limit_in_mb = -1;
    Status status = ReadInt64FromEnvVar(envvar_in_mb, default_value_in_bytes,
                                        &scratch_limit_in_mb);
    if (!status.ok()) {
      LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
                   << workspace_limit_in_mb_str;
    } else {
      return scratch_limit_in_mb * (1 << 20);
    }
  }
  return default_value_in_bytes;
}

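// GPU implementation backed by the StreamExecutor FFT interface (cuFFT on
// CUDA builds). Builds a batched plan over the trailing fft_rank dimensions
// and executes it on the op's GPU stream.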
class FFTGPUBase : public FFTBase {
 public:
  using FFTBase::FFTBase;

 protected:
  static int64_t CufftScratchSize;
  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
             Tensor* out) override {
    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    const TensorShape& input_shape = in.shape();
    const TensorShape& output_shape = out->shape();

    const int fft_rank = Rank();
    int batch_size = 1;
    for (int i = 0; i < input_shape.dims() - fft_rank; ++i) {
      batch_size *= input_shape.dim_size(i);
    }
    uint64 input_embed[3];
    const uint64 input_stride = 1;
    uint64 input_distance = 1;
    uint64 output_embed[3];
    const uint64 output_stride = 1;
    uint64 output_distance = 1;

    for (int i = 0; i < fft_rank; ++i) {
      auto dim_offset = input_shape.dims() - fft_rank + i;
      input_embed[i] = input_shape.dim_size(dim_offset);
      input_distance *= input_shape.dim_size(dim_offset);
      output_embed[i] = output_shape.dim_size(dim_offset);
      output_distance *= output_shape.dim_size(dim_offset);
    }

    constexpr bool kInPlaceFft = false;
    const bool is_complex128 =
        in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;

    const auto kFftType =
        IsReal()
            ? (IsForward()
                   ? (is_complex128 ? se::fft::Type::kD2Z : se::fft::Type::kR2C)
                   : (is_complex128 ? se::fft::Type::kZ2D
                                    : se::fft::Type::kC2R))
            : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
                                            : se::fft::Type::kC2CForward)
                           : (is_complex128 ? se::fft::Type::kZ2ZInverse
                                            : se::fft::Type::kC2CInverse));

    CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
    auto plan =
        stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
            stream, fft_rank, fft_shape, input_embed, input_stride,
            input_distance, output_embed, output_stride, output_distance,
            kFftType, kInPlaceFft, batch_size, &scratch_allocator);

    if (IsReal()) {
      if (IsForward()) {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_DOUBLE);
          DCHECK_EQ(out->dtype(), DT_COMPLEX128);
          DoFFTInternal<double, complex128>(ctx, stream, plan.get(), kFftType,
                                            output_distance, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_FLOAT);
          DCHECK_EQ(out->dtype(), DT_COMPLEX64);
          DoFFTInternal<float, complex64>(ctx, stream, plan.get(), kFftType,
                                          output_distance, in, out);
        }
      } else {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_COMPLEX128);
          DCHECK_EQ(out->dtype(), DT_DOUBLE);
          DoFFTInternal<complex128, double>(ctx, stream, plan.get(), kFftType,
                                            output_distance, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_COMPLEX64);
          DCHECK_EQ(out->dtype(), DT_FLOAT);
          DoFFTInternal<complex64, float>(ctx, stream, plan.get(), kFftType,
                                          output_distance, in, out);
        }
      }
    } else {
      if (is_complex128) {
        DCHECK_EQ(in.dtype(), DT_COMPLEX128);
        DCHECK_EQ(out->dtype(), DT_COMPLEX128);
        DoFFTInternal<complex128, complex128>(ctx, stream, plan.get(), kFftType,
                                              output_distance, in, out);
      } else {
        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
        DoFFTInternal<complex64, complex64>(ctx, stream, plan.get(), kFftType,
                                            output_distance, in, out);
      }
    }
  }

 private:
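  // Maps std::complex<T> to T; identity for real types.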
  template <typename T>
  struct RealTypeFromComplexType {
    typedef T RealT;
  };

  template <typename T>
  struct RealTypeFromComplexType<std::complex<T>> {
    typedef T RealT;
  };

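  // Enqueues the planned FFT on `stream` and, for inverse transforms, scales
  // the output by 1 / output_distance, since the underlying library computes
  // unnormalized inverse transforms.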
  template <typename InT, typename OutT>
  void DoFFTInternal(OpKernelContext* ctx, se::Stream* stream,
                     se::fft::Plan* plan, const se::fft::Type fft_type,
                     const uint64 output_distance, const Tensor& in,
                     Tensor* out) {
    const TensorShape& input_shape = in.shape();
    const TensorShape& output_shape = out->shape();
    auto src =
        AsDeviceMemory<InT>(in.flat<InT>().data(), input_shape.num_elements());
    auto dst = AsDeviceMemory<OutT>(out->flat<OutT>().data(),
                                    output_shape.num_elements());
    OP_REQUIRES(
        ctx, stream->ThenFft(plan, src, &dst).ok(),
        errors::Internal("fft failed : type=", static_cast<int>(fft_type),
                         " in.shape=", input_shape.DebugString()));
    if (!IsForward()) {
      typedef typename RealTypeFromComplexType<OutT>::RealT RealT;
      RealT alpha = 1.0 / output_distance;
      OP_REQUIRES(
          ctx,
          stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
              .ok(),
          errors::Internal("BlasScal failed : in.shape=",
                           input_shape.DebugString()));
    }
  }
};

int64_t FFTGPUBase::CufftScratchSize = GetCufftWorkspaceLimit(
    // default value is in bytes despite the name of the environment variable
    "TF_CUFFT_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
);

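// Thin wrapper that binds the compile-time Forward/_Real/FFTRank parameters
// to the virtual interface expected by FFTGPUBase.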
template <bool Forward, bool _Real, int FFTRank>
class FFTGPU : public FFTGPUBase {
 public:
  static_assert(FFTRank >= 1 && FFTRank <= 3,
                "Only 1D, 2D and 3D FFTs supported.");
  explicit FFTGPU(OpKernelConstruction* ctx) : FFTGPUBase(ctx) {}

 protected:
  int Rank() const override { return FFTRank; }
  bool IsForward() const override { return Forward; }
  bool IsReal() const override { return _Real; }
};

// Register GPU kernels with priority 1 so that if a custom FFT CPU kernel is
// registered with priority 1 (to override the default Eigen CPU kernel), the
// CPU kernel does not outrank the GPU kernel.
REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 3>);

REGISTER_KERNEL_BUILDER(
    Name("RFFT").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<true, true, 1>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(
    Name("RFFT2D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<true, true, 2>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT2D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(
    Name("RFFT3D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<true, true, 3>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT3D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<false, true, 3>);

// Deprecated kernels.
REGISTER_KERNEL_BUILDER(Name("BatchFFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("BatchFFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("BatchFFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 3>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow