/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

// See docs in ../ops/spectral_ops.cc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif

namespace tensorflow {

class FFTBase : public OpKernel {
 public:
  explicit FFTBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in = ctx->input(0);
    const TensorShape& input_shape = in.shape();
    const int fft_rank = Rank();
    OP_REQUIRES(
        ctx, input_shape.dims() >= fft_rank,
        errors::InvalidArgument("Input must have rank of at least ", fft_rank,
                                " but got: ", input_shape.DebugString()));

    Tensor* out;
    TensorShape output_shape = input_shape;
    uint64 fft_shape[3] = {0, 0, 0};

    // In R2C or C2R mode, we use a second input to specify the FFT length
    // instead of inferring it from the input shape.
    if (IsReal()) {
      const Tensor& fft_length = ctx->input(1);
      OP_REQUIRES(ctx,
                  fft_length.shape().dims() == 1 &&
                      fft_length.shape().dim_size(0) == fft_rank,
                  errors::InvalidArgument("fft_length must have shape [",
                                          fft_rank, "]"));

      auto fft_length_as_vec = fft_length.vec<int32>();
      for (int i = 0; i < fft_rank; ++i) {
        fft_shape[i] = fft_length_as_vec(i);
        // Each input dimension must have length of at least fft_shape[i]. For
        // IRFFTs, the inner-most input dimension must have length of at least
        // fft_shape[i] / 2 + 1.
        bool inner_most = (i == fft_rank - 1);
        uint64 min_input_dim_length =
            !IsForward() && inner_most ? fft_shape[i] / 2 + 1 : fft_shape[i];
        auto input_index = input_shape.dims() - fft_rank + i;
        OP_REQUIRES(
            ctx,
            // We pass through empty tensors, so special case them here.
            input_shape.dim_size(input_index) == 0 ||
                input_shape.dim_size(input_index) >= min_input_dim_length,
            errors::InvalidArgument(
                "Input dimension ", input_index,
                " must have length of at least ", min_input_dim_length,
                " but got: ", input_shape.dim_size(input_index)));
        uint64 dim = IsForward() && inner_most && fft_shape[i] != 0
                         ? fft_shape[i] / 2 + 1
                         : fft_shape[i];
        output_shape.set_dim(output_shape.dims() - fft_rank + i, dim);
      }
    } else {
      for (int i = 0; i < fft_rank; ++i) {
        fft_shape[i] =
            output_shape.dim_size(output_shape.dims() - fft_rank + i);
      }
    }

    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));
    if (input_shape.num_elements() == 0) {
      return;
    }

    DoFFT(ctx, in, fft_shape, out);
  }

 protected:
  virtual int Rank() const = 0;
  virtual bool IsForward() const = 0;
  virtual bool IsReal() const = 0;

  // The function that actually computes the FFT.
  virtual void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
                     Tensor* out) = 0;
};

typedef Eigen::ThreadPoolDevice CPUDevice;

template <bool Forward, bool _Real, int FFTRank>
class FFTCPU : public FFTBase {
 public:
  using FFTBase::FFTBase;

 protected:
  int Rank() const override { return FFTRank; }
  bool IsForward() const override { return Forward; }
  bool IsReal() const override { return _Real; }

  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
             Tensor* out) override {
    // Create the axes (which are always trailing).
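    // flat_inner_dims below collapses all leading dimensions into dimension
    // 0, so the transformed axes are {1, ..., FFTRank} (e.g. {1, 2} for a 2-D
    // FFT) and the batch dimension is left untouched.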
    const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
    auto device = ctx->eigen_device<CPUDevice>();

    if (!IsReal()) {
      // Compute the FFT using Eigen.
      constexpr auto direction =
          Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
      if (in.dtype() == DT_COMPLEX64) {
        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
        auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
        auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
        output.device(device) =
            input.template fft<Eigen::BothParts, direction>(axes);
      } else {
        DCHECK_EQ(DT_COMPLEX128, in.dtype());
        DCHECK_EQ(DT_COMPLEX128, out->dtype());
        auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
        auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
        output.device(device) =
            input.template fft<Eigen::BothParts, direction>(axes);
      }
    } else {
      if (IsForward()) {
        auto input = Tensor(in).flat_inner_dims<float, FFTRank + 1>();
        const auto input_dims = input.dimensions();

        // Slice input to fft_shape on its inner-most dimensions.
        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
        input_slice_sizes[0] = input_dims[0];
        TensorShape temp_shape{input_dims[0]};
        for (int i = 1; i <= FFTRank; ++i) {
          input_slice_sizes[i] = fft_shape[i - 1];
          temp_shape.AddDim(fft_shape[i - 1]);
        }

        auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
        const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;

        // Compute the full FFT using a temporary tensor.
        Tensor temp;
        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
                                               temp_shape, &temp));
        auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();
        full_fft.device(device) =
            input.slice(zero_start_indices, input_slice_sizes)
                .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);

        // Slice away the negative frequency components.
        output.device(device) =
            full_fft.slice(zero_start_indices, output.dimensions());
      } else {
        // Reconstruct the full FFT and take the inverse.
        auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
        auto output = out->flat_inner_dims<float, FFTRank + 1>();
        const auto input_dims = input.dimensions();

        // Calculate the shape of the temporary tensor for the full FFT and
        // the region we will slice from input given fft_shape. We slice input
        // to fft_shape on its inner-most dimensions, except the last (which
        // we slice to fft_shape[-1] / 2 + 1).
        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
        input_slice_sizes[0] = input_dims[0];
        TensorShape full_fft_shape;
        full_fft_shape.AddDim(input_dims[0]);
        for (auto i = 1; i <= FFTRank; i++) {
          input_slice_sizes[i] =
              i == FFTRank ? fft_shape[i - 1] / 2 + 1 : fft_shape[i - 1];
          full_fft_shape.AddDim(fft_shape[i - 1]);
        }

        Tensor temp;
        OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<complex64>::v(),
                                               full_fft_shape, &temp));
        auto full_fft = temp.flat_inner_dims<complex64, FFTRank + 1>();

        // Calculate the starting point and range of the source of
        // negative frequency part.
        auto neg_sizes = input_slice_sizes;
        neg_sizes[FFTRank] =
            fft_shape[FFTRank - 1] - input_slice_sizes[FFTRank];
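        // For an even fft length N this leaves N / 2 - 1 negative-frequency
        // bins to fill in; for an odd N it leaves (N - 1) / 2.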
        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
        neg_target_indices[FFTRank] = input_slice_sizes[FFTRank];

        const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> start_indices;
        Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
        neg_start_indices[FFTRank] = 1;

        full_fft.slice(start_indices, input_slice_sizes).device(device) =
            input.slice(start_indices, input_slice_sizes);

        // First, conduct IFFTs on outer dimensions. We save computation (and
        // avoid touching uninitialized memory) by slicing full_fft to the
        // subregion we wrote input to.
        if (FFTRank > 1) {
          const auto outer_axes =
              Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
          full_fft.slice(start_indices, input_slice_sizes).device(device) =
              full_fft.slice(start_indices, input_slice_sizes)
                  .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(
                      outer_axes);
        }

        // Reconstruct the full FFT by appending reversed and conjugated
        // spectrum as the negative frequency part.
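        // This uses the Hermitian symmetry of a real signal's spectrum,
        // X[N - k] = conj(X[k]): the positive-frequency bins (excluding the
        // DC bin at k = 0) are reversed and conjugated to fill in the
        // negative-frequency half of full_fft.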
        Eigen::array<bool, FFTRank + 1> reverse_last_axis;
        for (auto i = 0; i <= FFTRank; i++) {
          reverse_last_axis[i] = i == FFTRank;
        }

        if (neg_sizes[FFTRank] != 0) {
          full_fft.slice(neg_target_indices, neg_sizes).device(device) =
              full_fft.slice(neg_start_indices, neg_sizes)
                  .reverse(reverse_last_axis)
                  .conjugate();
        }

        auto inner_axis = Eigen::array<int, 1>{FFTRank};
        output.device(device) =
            full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(
                inner_axis);
      }
    }
  }
};

// Use labels to distinguish between internal and open source versions
// of these kernels.
#ifdef PLATFORM_GOOGLE
#define FFT_LABEL "eigen"
#else
#define FFT_LABEL ""
#endif

REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, false, 3>);

REGISTER_KERNEL_BUILDER(Name("RFFT").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, true, 1>);
REGISTER_KERNEL_BUILDER(Name("IRFFT").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(Name("RFFT2D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, true, 2>);
REGISTER_KERNEL_BUILDER(Name("IRFFT2D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(Name("RFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<true, true, 3>);
REGISTER_KERNEL_BUILDER(Name("IRFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL),
                        FFTCPU<false, true, 3>);

#undef FFT_LABEL

#if GOOGLE_CUDA

namespace {
template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

// A class to provide scratch-space allocator for Stream-Executor Cufft
// callback. Tensorflow is responsible for releasing the temporary buffers
// after the kernel finishes.
// TODO(yangzihao): Refactor redundant code in subclasses of ScratchAllocator
// into base class.
class CufftScratchAllocator : public se::ScratchAllocator {
 public:
  ~CufftScratchAllocator() override {}
  CufftScratchAllocator(int64 memory_limit, OpKernelContext* context)
      : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
  int64 GetMemoryLimitInBytes(se::Stream* stream) override {
    return memory_limit_;
  }
  se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
      se::Stream* stream, int64 byte_size) override {
    Tensor temporary_memory;
    if (byte_size > memory_limit_) {
      return se::port::StatusOr<se::DeviceMemory<uint8>>();
    }
    AllocationAttributes allocation_attr;
    allocation_attr.no_retry_on_failure = true;
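    // Fail immediately rather than retrying on OOM; a failed allocation is
    // reported back to the caller as an empty StatusOr below.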
    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory,
        AllocatorAttributes(), allocation_attr));
    if (!allocation_status.ok()) {
      return se::port::StatusOr<se::DeviceMemory<uint8>>();
    }
    // Hold the reference of the allocated tensors until the end of the
    // allocator.
    allocated_tensors_.push_back(temporary_memory);
    total_byte_size_ += byte_size;
    return se::port::StatusOr<se::DeviceMemory<uint8>>(
        AsDeviceMemory(temporary_memory.flat<uint8>().data(),
                       temporary_memory.flat<uint8>().size()));
  }
  int64 TotalByteSize() { return total_byte_size_; }

 private:
  int64 memory_limit_;
  int64 total_byte_size_;
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};

}  // end namespace

int64 GetCufftWorkspaceLimit(const string& envvar_in_mb,
                             int64 default_value_in_bytes) {
  const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
  if (workspace_limit_in_mb_str != nullptr &&
      strcmp(workspace_limit_in_mb_str, "") != 0) {
    int64 scratch_limit_in_mb = -1;
    Status status = ReadInt64FromEnvVar(envvar_in_mb, default_value_in_bytes,
                                        &scratch_limit_in_mb);
    if (!status.ok()) {
      LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
                   << workspace_limit_in_mb_str;
    } else {
      return scratch_limit_in_mb * (1 << 20);
    }
  }
  return default_value_in_bytes;
}
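// For example, setting TF_CUFFT_WORKSPACE_LIMIT_IN_MB=4096 limits the cuFFT
// scratch allocations below to 4096 MiB (4 GiB); the environment value is
// read in megabytes and converted to bytes above.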

class FFTGPUBase : public FFTBase {
 public:
  using FFTBase::FFTBase;

 protected:
  static int64 CufftScratchSize;
  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
             Tensor* out) override {
    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    const TensorShape& input_shape = in.shape();
    const TensorShape& output_shape = out->shape();

    const int fft_rank = Rank();
    int batch_size = 1;
    for (int i = 0; i < input_shape.dims() - fft_rank; ++i) {
      batch_size *= input_shape.dim_size(i);
    }
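    // The embed/stride/distance values below describe the cuFFT "advanced"
    // data layout: elements along the transformed axes are contiguous
    // (stride 1), and consecutive batch entries are `distance` elements
    // apart.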
    uint64 input_embed[3];
    const uint64 input_stride = 1;
    uint64 input_distance = 1;
    uint64 output_embed[3];
    const uint64 output_stride = 1;
    uint64 output_distance = 1;

    for (int i = 0; i < fft_rank; ++i) {
      auto dim_offset = input_shape.dims() - fft_rank + i;
      input_embed[i] = input_shape.dim_size(dim_offset);
      input_distance *= input_shape.dim_size(dim_offset);
      output_embed[i] = output_shape.dim_size(dim_offset);
      output_distance *= output_shape.dim_size(dim_offset);
    }

    constexpr bool kInPlaceFft = false;
    const bool is_complex128 = in.dtype() == DT_COMPLEX128;
    // complex128 real FFT is not supported yet.
    DCHECK(!IsReal() || !is_complex128);

    const auto kFftType =
        IsReal() ? (IsForward() ? se::fft::Type::kR2C : se::fft::Type::kC2R)
                 : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
                                                 : se::fft::Type::kC2CForward)
                                : (is_complex128 ? se::fft::Type::kZ2ZInverse
                                                 : se::fft::Type::kC2CInverse));

    CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
    auto plan =
        stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
            stream, fft_rank, fft_shape, input_embed, input_stride,
            input_distance, output_embed, output_stride, output_distance,
            kFftType, kInPlaceFft, batch_size, &scratch_allocator);

    if (IsReal()) {
      if (IsForward()) {
        auto src = AsDeviceMemory<float>(in.flat<float>().data());
        auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
        OP_REQUIRES(
            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
                             " in.shape=", input_shape.DebugString()));
      } else {
        auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
        auto dst = AsDeviceMemory<float>(out->flat<float>().data());
        OP_REQUIRES(
            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
                             " in.shape=", input_shape.DebugString()));
        auto alpha = 1.f / output_distance;
        OP_REQUIRES(
            ctx,
            stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
                .ok(),
            errors::Internal("BlasScal failed : in.shape=",
                             input_shape.DebugString()));
      }
    } else {
      if (!is_complex128) {
        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
        auto src = AsDeviceMemory<complex64>(in.flat<complex64>().data());
        auto dst = AsDeviceMemory<complex64>(out->flat<complex64>().data());
        OP_REQUIRES(
            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
                             " in.shape=", input_shape.DebugString()));
        if (!IsForward()) {
          float alpha = 1.f / output_distance;
          OP_REQUIRES(
              ctx,
              stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
                  .ok(),
              errors::Internal("BlasScal failed : in.shape=",
                               input_shape.DebugString()));
        }
      } else {
        DCHECK_EQ(in.dtype(), DT_COMPLEX128);
        DCHECK_EQ(out->dtype(), DT_COMPLEX128);
        auto src = AsDeviceMemory<complex128>(in.flat<complex128>().data());
        auto dst = AsDeviceMemory<complex128>(out->flat<complex128>().data());
        OP_REQUIRES(
            ctx, stream->ThenFft(plan.get(), src, &dst).ok(),
            errors::Internal("fft failed : type=", static_cast<int>(kFftType),
                             " in.shape=", input_shape.DebugString()));
        if (!IsForward()) {
          double alpha = 1.0 / output_distance;
          OP_REQUIRES(
              ctx,
              stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
                  .ok(),
              errors::Internal("BlasScal failed : in.shape=",
                               input_shape.DebugString()));
        }
      }
    }
  }
};

int64 FFTGPUBase::CufftScratchSize = GetCufftWorkspaceLimit(
    // default value is in bytes despite the name of the environment variable
    "TF_CUFFT_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
);

template <bool Forward, bool _Real, int FFTRank>
class FFTGPU : public FFTGPUBase {
 public:
  static_assert(FFTRank >= 1 && FFTRank <= 3,
                "Only 1D, 2D and 3D FFTs supported.");
  explicit FFTGPU(OpKernelConstruction* ctx) : FFTGPUBase(ctx) {}

 protected:
  int Rank() const override { return FFTRank; }
  bool IsForward() const override { return Forward; }
  bool IsReal() const override { return _Real; }
};

REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_GPU), FFTGPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_GPU),
                        FFTGPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_GPU),
                        FFTGPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_GPU),
                        FFTGPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_GPU),
                        FFTGPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_GPU),
                        FFTGPU<false, false, 3>);

REGISTER_KERNEL_BUILDER(
    Name("RFFT").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<true, true, 1>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(
    Name("RFFT2D").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<true, true, 2>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT2D").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(
    Name("RFFT3D").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<true, true, 3>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT3D").Device(DEVICE_GPU).HostMemory("fft_length"),
    FFTGPU<false, true, 3>);

// Deprecated kernels.
REGISTER_KERNEL_BUILDER(Name("BatchFFT").Device(DEVICE_GPU),
                        FFTGPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT").Device(DEVICE_GPU),
                        FFTGPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("BatchFFT2D").Device(DEVICE_GPU),
                        FFTGPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT2D").Device(DEVICE_GPU),
                        FFTGPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("BatchFFT3D").Device(DEVICE_GPU),
                        FFTGPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT3D").Device(DEVICE_GPU),
                        FFTGPU<false, false, 3>);
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow
553