/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

// See docs in ../ops/spectral_ops.cc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif

namespace tensorflow {

class FFTBase : public OpKernel {
 public:
  explicit FFTBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in = ctx->input(0);
    const TensorShape& input_shape = in.shape();
    const int fft_rank = Rank();
    OP_REQUIRES(
        ctx, input_shape.dims() >= fft_rank,
        errors::InvalidArgument("Input must have rank of at least ", fft_rank,
                                " but got: ", input_shape.DebugString()));

    Tensor* out;
    TensorShape output_shape = input_shape;
    uint64 fft_shape[3] = {0, 0, 0};

    // In R2C or C2R mode, we use a second input to specify the FFT length
    // instead of inferring it from the input shape.
    if (IsReal()) {
      const Tensor& fft_length = ctx->input(1);
      OP_REQUIRES(ctx,
                  fft_length.shape().dims() == 1 &&
                      fft_length.shape().dim_size(0) == fft_rank,
                  errors::InvalidArgument("fft_length must have shape [",
                                          fft_rank, "]"));

      auto fft_length_as_vec = fft_length.vec<int32>();
      for (int i = 0; i < fft_rank; ++i) {
        fft_shape[i] = fft_length_as_vec(i);
        // Each input dimension must have length of at least fft_shape[i]. For
        // IRFFTs, the inner-most input dimension must have length of at least
        // fft_shape[i] / 2 + 1.
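        // For example, an IRFFT with fft_length 8 requires an inner-most
        // input dimension of at least 8 / 2 + 1 = 5 (the stored non-negative
        // frequency bins).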
        bool inner_most = (i == fft_rank - 1);
        uint64 min_input_dim_length =
            !IsForward() && inner_most ? fft_shape[i] / 2 + 1 : fft_shape[i];
        auto input_index = input_shape.dims() - fft_rank + i;
        OP_REQUIRES(
            ctx,
            // We pass through empty tensors, so special case them here.
            input_shape.dim_size(input_index) == 0 ||
                input_shape.dim_size(input_index) >= min_input_dim_length,
            errors::InvalidArgument(
                "Input dimension ", input_index,
                " must have length of at least ", min_input_dim_length,
                " but got: ", input_shape.dim_size(input_index)));
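        // The forward R2C transform keeps only the non-negative frequencies,
        // so the inner-most output dimension has fft_shape[i] / 2 + 1 entries
        // (e.g. fft_length 8 -> 5); an empty transform passes through as-is.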
        uint64 dim = IsForward() && inner_most && fft_shape[i] != 0
                         ? fft_shape[i] / 2 + 1
                         : fft_shape[i];
        output_shape.set_dim(output_shape.dims() - fft_rank + i, dim);
      }
    } else {
      for (int i = 0; i < fft_rank; ++i) {
        fft_shape[i] =
            output_shape.dim_size(output_shape.dims() - fft_rank + i);
      }
    }

    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));
    if (input_shape.num_elements() == 0) {
      return;
    }

    DoFFT(ctx, in, fft_shape, out);
  }

 protected:
  virtual int Rank() const = 0;
  virtual bool IsForward() const = 0;
  virtual bool IsReal() const = 0;

  // The function that actually computes the FFT.
  virtual void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
                     Tensor* out) = 0;
};

typedef Eigen::ThreadPoolDevice CPUDevice;

template <bool Forward, bool _Real, int FFTRank>
class FFTCPU : public FFTBase {
 public:
  using FFTBase::FFTBase;

 protected:
  int Rank() const override { return FFTRank; }
  bool IsForward() const override { return Forward; }
  bool IsReal() const override { return _Real; }

  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
             Tensor* out) override {
    // Create the axes (which are always trailing).
    const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
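    // flat_inner_dims below collapses all leading dimensions into a single
    // batch dimension 0, so the transformed axes are always 1..FFTRank
    // (e.g. {1, 2} for a 2-D FFT).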
    auto device = ctx->eigen_device<CPUDevice>();

    if (!IsReal()) {
      // Compute the FFT using Eigen.
      constexpr auto direction =
          Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
      if (in.dtype() == DT_COMPLEX64) {
        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
        auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
        auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
        output.device(device) =
            input.template fft<Eigen::BothParts, direction>(axes);
      } else {
        DCHECK_EQ(DT_COMPLEX128, in.dtype());
        DCHECK_EQ(DT_COMPLEX128, out->dtype());
        auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
        auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
        output.device(device) =
            input.template fft<Eigen::BothParts, direction>(axes);
      }
    } else {
      if (IsForward()) {
        auto input = Tensor(in).flat_inner_dims<float, FFTRank + 1>();
        const auto input_dims = input.dimensions();

        // Slice input to fft_shape on its inner-most dimensions.
        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
        input_slice_sizes[0] = input_dims[0];
        TensorShape temp_shape{input_dims[0]};
        for (int i = 1; i <= FFTRank; ++i) {
          input_slice_sizes[i] = fft_shape[i - 1];
          temp_shape.AddDim(fft_shape[i - 1]);
        }
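        // Samples beyond fft_shape are cropped by the slice; the length check
        // in Compute() guarantees each dimension is at least fft_shape. For
        // example, an input of shape [batch, 10] with fft_length 8 is sliced
        // to [batch, 8] before the transform.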

        auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
        const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;

        // Compute the full FFT using a temporary tensor.
        Tensor temp;
        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
                                               temp_shape, &temp));
        auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
        full_fft.device(device) =
            input.slice(zero_start_indices, input_slice_sizes)
                .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);

        // Slice away the negative frequency components.
        output.device(device) =
            full_fft.slice(zero_start_indices, output.dimensions());
      } else {
        // Reconstruct the full FFT and take the inverse.
        auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
        auto output = out->flat_inner_dims<float, FFTRank + 1>();
        const auto input_dims = input.dimensions();

        // Calculate the shape of the temporary tensor for the full FFT and the
        // region we will slice from input given fft_shape. We slice input to
        // fft_shape on its inner-most dimensions, except the last (which we
        // slice to fft_shape[-1] / 2 + 1).
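        // For example, with fft_shape = {4, 8} the input slice has shape
        // [batch, 4, 5] (the stored non-negative frequencies) and the
        // reconstructed full FFT has shape [batch, 4, 8].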
        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
        input_slice_sizes[0] = input_dims[0];
        TensorShape full_fft_shape;
        full_fft_shape.AddDim(input_dims[0]);
        for (auto i = 1; i <= FFTRank; i++) {
          input_slice_sizes[i] =
              i == FFTRank ? fft_shape[i - 1] / 2 + 1 : fft_shape[i - 1];
          full_fft_shape.AddDim(fft_shape[i - 1]);
        }

        Tensor temp;
        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
                                               full_fft_shape, &temp));
        auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();

        // Calculate the starting point and range of the source of the
        // negative frequency part.
        auto neg_sizes = input_slice_sizes;
        neg_sizes[FFTRank] =
            fft_shape[FFTRank - 1] - input_slice_sizes[FFTRank];
        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
        neg_target_indices[FFTRank] = input_slice_sizes[FFTRank];

        const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> start_indices;
        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
        neg_start_indices[FFTRank] = 1;
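        // Starting the source slice at bin 1 skips the DC bin, which has no
        // mirrored counterpart; for even lengths, neg_sizes also excludes the
        // Nyquist bin, which is its own conjugate.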

        full_fft.slice(start_indices, input_slice_sizes).device(device) =
            input.slice(start_indices, input_slice_sizes);

        // First, conduct IFFTs on outer dimensions. We save computation (and
        // avoid touching uninitialized memory) by slicing full_fft to the
        // subregion we wrote input to.
        if (FFTRank > 1) {
          const auto outer_axes =
              Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
          full_fft.slice(start_indices, input_slice_sizes).device(device) =
              full_fft.slice(start_indices, input_slice_sizes)
                  .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(
                      outer_axes);
        }

        // Reconstruct the full FFT by appending reversed and conjugated
        // spectrum as the negative frequency part.
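        // This uses the Hermitian symmetry of the spectrum of a real signal,
        // X[N - k] = conj(X[k]), so the missing negative-frequency bins are a
        // reversed, conjugated copy of bins 1 .. N - (N / 2 + 1).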
        Eigen::array<bool, FFTRank + 1> reverse_last_axis;
        for (auto i = 0; i <= FFTRank; i++) {
          reverse_last_axis[i] = i == FFTRank;
        }

        if (neg_sizes[FFTRank] != 0) {
          full_fft.slice(neg_target_indices, neg_sizes).device(device) =
              full_fft.slice(neg_start_indices, neg_sizes)
                  .reverse(reverse_last_axis)
                  .conjugate();
        }

        auto inner_axis = Eigen::array<int, 1>{FFTRank};
        output.device(device) =
            full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(
                inner_axis);
      }
    }
  }
};

// Use labels to distinguish between internal and open source versions
// of these kernels.
#ifdef PLATFORM_GOOGLE
#define FFT_LABEL "eigen"
#else
#define FFT_LABEL ""
#endif

REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, false, 3>);

REGISTER_KERNEL_BUILDER(Name("RFFT").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, true, 1>);
REGISTER_KERNEL_BUILDER(Name("IRFFT").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(Name("RFFT2D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, true, 2>);
REGISTER_KERNEL_BUILDER(Name("IRFFT2D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(Name("RFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, true, 3>);
REGISTER_KERNEL_BUILDER(Name("IRFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, true, 3>);

#undef FFT_LABEL

#if GOOGLE_CUDA

namespace {
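// Wrap raw device pointers in typed StreamExecutor DeviceMemory handles; the
// handles reference the memory without taking ownership of it.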
template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

// Provides a scratch-space allocator for the StreamExecutor cuFFT callback.
// TensorFlow is responsible for releasing the temporary buffers after the
// kernel finishes.
// TODO(yangzihao): Refactor redundant code in subclasses of ScratchAllocator
// into base class.
class CufftScratchAllocator : public se::ScratchAllocator {
 public:
  ~CufftScratchAllocator() override {}
  CufftScratchAllocator(int64 memory_limit, OpKernelContext* context)
      : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
  int64 GetMemoryLimitInBytes(se::Stream* stream) override {
    return memory_limit_;
  }
  se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
      se::Stream* stream, int64 byte_size) override {
    Tensor temporary_memory;
    if (byte_size > memory_limit_) {
      return se::port::StatusOr<se::DeviceMemory<uint8>>();
    }
    AllocationAttributes allocation_attr;
    allocation_attr.no_retry_on_failure = true;
    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory,
        AllocatorAttributes(), allocation_attr));
    if (!allocation_status.ok()) {
      return se::port::StatusOr<se::DeviceMemory<uint8>>();
    }
    // Hold a reference to the allocated tensors until the allocator is
    // destroyed.
    allocated_tensors_.push_back(temporary_memory);
    total_byte_size_ += byte_size;
    return se::port::StatusOr<se::DeviceMemory<uint8>>(
        AsDeviceMemory(temporary_memory.flat<uint8>().data(),
                       temporary_memory.flat<uint8>().size()));
  }
  int64 TotalByteSize() { return total_byte_size_; }

 private:
  int64 memory_limit_;
  int64 total_byte_size_;
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};

}  // end namespace

int64 GetCufftWorkspaceLimit(const string& envvar_in_mb,
                             int64 default_value_in_bytes) {
  const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
  if (workspace_limit_in_mb_str != nullptr &&
      strcmp(workspace_limit_in_mb_str, "") != 0) {
    int64 scratch_limit_in_mb = -1;
    Status status = ReadInt64FromEnvVar(envvar_in_mb, default_value_in_bytes,
                                        &scratch_limit_in_mb);
    if (!status.ok()) {
      LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
                   << workspace_limit_in_mb_str;
    } else {
      return scratch_limit_in_mb * (1 << 20);
    }
  }
  return default_value_in_bytes;
}
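// For example (illustrative), running with TF_CUFFT_WORKSPACE_LIMIT_IN_MB=512
// caps the cuFFT scratch space at 512 MiB, since the parsed megabyte value is
// multiplied by 2^20 above.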

class FFTGPUBase : public FFTBase {
 public:
  using FFTBase::FFTBase;

 protected:
  static int64 CufftScratchSize;
  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
             Tensor* out) override {
    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    const TensorShape& input_shape = in.shape();
    const TensorShape& output_shape = out->shape();

    const int fft_rank = Rank();
    int batch_size = 1;
    for (int i = 0; i < input_shape.dims() - fft_rank; ++i) {
      batch_size *= input_shape.dim_size(i);
    }
    uint64 input_embed[3];
    const uint64 input_stride = 1;
    uint64 input_distance = 1;
    uint64 output_embed[3];
    const uint64 output_stride = 1;
    uint64 output_distance = 1;

    for (int i = 0; i < fft_rank; ++i) {
      auto dim_offset = input_shape.dims() - fft_rank + i;
      input_embed[i] = input_shape.dim_size(dim_offset);
      input_distance *= input_shape.dim_size(dim_offset);
      output_embed[i] = output_shape.dim_size(dim_offset);
      output_distance *= output_shape.dim_size(dim_offset);
    }
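    // In cufftPlanMany terms, embed/stride/distance play the roles of
    // inembed/istride/idist: the spacing between consecutive elements is 1
    // and the spacing between consecutive batches is the size of one
    // transform. E.g. a complex input of shape [b, 4, 8] with fft_rank 2
    // gives input_embed = {4, 8}, input_stride = 1, input_distance = 32.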

    constexpr bool kInPlaceFft = false;
    const bool is_complex128 = in.dtype() == DT_COMPLEX128;
    // complex128 real FFT is not supported yet.
    DCHECK(!IsReal() || !is_complex128);

    const auto kFftType =
        IsReal() ? (IsForward() ? se::fft::Type::kR2C : se::fft::Type::kC2R)
                 : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
                                                 : se::fft::Type::kC2CForward)
                                : (is_complex128 ? se::fft::Type::kZ2ZInverse
                                                 : se::fft::Type::kC2CInverse));

    CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
    auto plan =
        stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
            stream, fft_rank, fft_shape, input_embed, input_stride,
            input_distance, output_embed, output_stride, output_distance,
            kFftType, kInPlaceFft, batch_size, &scratch_allocator);

    if (IsReal()) {
      if (IsForward()) {
        auto src = AsDeviceMemory<float>(in.flat<float>().data());
        auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
        OP_REQUIRES(
            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
                             " in.shape=", input_shape.DebugString()));
      } else {
        auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
        auto dst = AsDeviceMemory<float>(out->flat<float>().data());
        OP_REQUIRES(
            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
                             " in.shape=", input_shape.DebugString()));
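        // cuFFT transforms are unnormalized, so scale the inverse output by
        // 1 / N, where N (= output_distance) is the number of real output
        // samples per batch.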
        auto alpha = 1.f / output_distance;
        OP_REQUIRES(
            ctx,
            stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
                .ok(),
            errors::Internal("BlasScal failed : in.shape=",
                             input_shape.DebugString()));
      }
    } else {
      if (!is_complex128) {
        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
        auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
        auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
        OP_REQUIRES(
            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
                             " in.shape=", input_shape.DebugString()));
        if (!IsForward()) {
          float alpha = 1.f / output_distance;
          OP_REQUIRES(
              ctx,
              stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
                  .ok(),
              errors::Internal("BlasScal failed : in.shape=",
                               input_shape.DebugString()));
        }
      } else {
        DCHECK_EQ(in.dtype(), DT_COMPLEX128);
        DCHECK_EQ(out->dtype(), DT_COMPLEX128);
        auto src = AsDeviceMemory<complex128>(in.flat<complex128>().data());
        auto dst = AsDeviceMemory<complex128>(out->flat<complex128>().data());
        OP_REQUIRES(
            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
                             " in.shape=", input_shape.DebugString()));
        if (!IsForward()) {
          double alpha = 1.0 / output_distance;
          OP_REQUIRES(
              ctx,
              stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
                  .ok(),
              errors::Internal("BlasScal failed : in.shape=",
                               input_shape.DebugString()));
        }
      }
    }
  }
};

int64 FFTGPUBase::CufftScratchSize = GetCufftWorkspaceLimit(
    // default value is in bytes despite the name of the environment variable
    "TF_CUFFT_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
);

template <bool Forward, bool _Real, int FFTRank>
class FFTGPU : public FFTGPUBase {
 public:
  static_assert(FFTRank >= 1 && FFTRank <= 3,
                "Only 1D, 2D and 3D FFTs supported.");
  explicit FFTGPU(OpKernelConstruction* ctx) : FFTGPUBase(ctx) {}

 protected:
  int Rank() const override { return FFTRank; }
  bool IsForward() const override { return Forward; }
  bool IsReal() const override { return _Real; }
};

REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_GPU), FFTGPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_GPU),
                        FFTGPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_GPU),
                        FFTGPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_GPU),
                        FFTGPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_GPU),
                        FFTGPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_GPU),
                        FFTGPU<false, false, 3>);

REGISTER_KERNEL_BUILDER(
    Name("RFFT").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<true, true, 1>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(
    Name("RFFT2D").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<true, true, 2>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT2D").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(
    Name("RFFT3D").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<true, true, 3>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT3D").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<false, true, 3>);

// Deprecated kernels.
REGISTER_KERNEL_BUILDER(Name("BatchFFT").Device(DEVICE_GPU),
                        FFTGPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT").Device(DEVICE_GPU),
                        FFTGPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("BatchFFT2D").Device(DEVICE_GPU),
                        FFTGPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT2D").Device(DEVICE_GPU),
                        FFTGPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("BatchFFT3D").Device(DEVICE_GPU),
                        FFTGPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT3D").Device(DEVICE_GPU),
                        FFTGPU<false, false, 3>);
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow