• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/rocm/rocm_fft.h"
17 
18 #include <complex>
19 
20 #include "tensorflow/stream_executor/device_memory.h"
21 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
22 #include "tensorflow/stream_executor/gpu/gpu_executor.h"
23 #include "tensorflow/stream_executor/gpu/gpu_helpers.h"
24 #include "tensorflow/stream_executor/gpu/gpu_stream.h"
25 #include "tensorflow/stream_executor/lib/env.h"
26 #include "tensorflow/stream_executor/lib/initialize.h"
27 #include "tensorflow/stream_executor/lib/status.h"
28 #include "tensorflow/stream_executor/platform/dso_loader.h"
29 #include "tensorflow/stream_executor/platform/logging.h"
30 #include "tensorflow/stream_executor/platform/port.h"
31 #include "tensorflow/stream_executor/plugin_registry.h"
32 #include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
33 #include "tensorflow/stream_executor/stream_executor_internal.h"
34 
35 namespace stream_executor {
36 namespace gpu {
37 
// Plugin id under which the rocFFT factory is registered with the
// PluginRegistry (see initialize_rocfft at the end of this file).
PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocFftPlugin);
39 
namespace wrap {

// Every rocFFT routine used by this file is exposed in this namespace as a
// same-named callable object. Each wrapper takes the owning GpuExecutor as its
// first argument, activates that executor's context for the duration of the
// call (ScopedActivateExecutorContext), and forwards the remaining arguments
// to the real hipfft function.

#ifdef PLATFORM_GOOGLE
// In PLATFORM_GOOGLE builds the rocFFT symbols are linked directly: this macro
// wraps the global identifier __name in a callable structure that activates
// the executor context and then calls ::__name itself.
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name)                      \
  struct WrapperShim__##__name {                                 \
    template <typename... Args>                                  \
    hipfftResult operator()(GpuExecutor *parent, Args... args) { \
      gpu::ScopedActivateExecutorContext sac{parent};            \
      return ::__name(args...);                                  \
    }                                                            \
  } __name;

#else

// This macro wraps a global identifier, given by __name, in a callable
// structure that loads the symbol out of the rocFFT DSO handle on first use.
// The resolved pointer is cached in a function-local static inside DynLoad().
// Failure to open the DSO (ValueOrDie) or to find the symbol (CHECK) aborts
// the process. This dynamic loading technique is used to avoid DSO
// dependencies on vendor libraries which may or may not be available in the
// deployed binary environment.
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name)                               \
  struct DynLoadShim__##__name {                                          \
    static const char *kName;                                             \
    using FuncPtrT = std::add_pointer<decltype(::__name)>::type;          \
    static void *GetDsoHandle() {                                         \
      auto s = internal::CachedDsoLoader::GetRocfftDsoHandle();           \
      return s.ValueOrDie();                                              \
    }                                                                     \
    static FuncPtrT LoadOrDie() {                                         \
      void *f;                                                            \
      auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
                                                          kName, &f);     \
      CHECK(s.ok()) << "could not find " << kName                         \
                    << " in rocfft DSO; dlerror: " << s.error_message();  \
      return reinterpret_cast<FuncPtrT>(f);                               \
    }                                                                     \
    static FuncPtrT DynLoad() {                                           \
      static FuncPtrT f = LoadOrDie();                                    \
      return f;                                                           \
    }                                                                     \
    template <typename... Args>                                           \
    hipfftResult operator()(GpuExecutor *parent, Args... args) {          \
      gpu::ScopedActivateExecutorContext sac{parent};                     \
      return DynLoad()(args...);                                          \
    }                                                                     \
  } __name;                                                               \
  const char *DynLoadShim__##__name::kName = #__name;

#endif

// X-macro listing every rocFFT routine this file calls; add a routine here
// before using it through wrap::.
// clang-format off
#define ROCFFT_ROUTINE_EACH(__macro) \
  __macro(hipfftDestroy)             \
  __macro(hipfftSetStream)           \
  __macro(hipfftPlan1d)              \
  __macro(hipfftPlan2d)              \
  __macro(hipfftPlan3d)              \
  __macro(hipfftPlanMany)            \
  __macro(hipfftCreate)              \
  __macro(hipfftSetAutoAllocation)   \
  __macro(hipfftSetWorkArea)         \
  __macro(hipfftGetSize1d)           \
  __macro(hipfftMakePlan1d)          \
  __macro(hipfftGetSize2d)           \
  __macro(hipfftMakePlan2d)          \
  __macro(hipfftGetSize3d)           \
  __macro(hipfftMakePlan3d)          \
  __macro(hipfftGetSizeMany)         \
  __macro(hipfftMakePlanMany)        \
  __macro(hipfftExecD2Z)             \
  __macro(hipfftExecZ2D)             \
  __macro(hipfftExecC2C)             \
  __macro(hipfftExecC2R)             \
  __macro(hipfftExecZ2Z)             \
  __macro(hipfftExecR2C)

// clang-format on

// Instantiate one wrapper object per routine listed above.
ROCFFT_ROUTINE_EACH(STREAM_EXECUTOR_ROCFFT_WRAP)

}  // namespace wrap
120 
121 namespace {
122 
123 // A helper function transforming gpu_fft arguments into rocFFT arguments.
ROCMFftType(fft::Type type)124 hipfftType ROCMFftType(fft::Type type) {
125   switch (type) {
126     case fft::Type::kC2CForward:
127     case fft::Type::kC2CInverse:
128       return HIPFFT_C2C;
129     case fft::Type::kC2R:
130       return HIPFFT_C2R;
131     case fft::Type::kR2C:
132       return HIPFFT_R2C;
133     case fft::Type::kZ2ZForward:
134     case fft::Type::kZ2ZInverse:
135       return HIPFFT_Z2Z;
136     case fft::Type::kZ2D:
137       return HIPFFT_Z2D;
138     case fft::Type::kD2Z:
139       return HIPFFT_D2Z;
140     default:
141       LOG(FATAL) << "Invalid value of fft::Type.";
142   }
143 }
144 
145 // Associates the given stream with the given rocFFT plan.
SetStream(GpuExecutor * parent,hipfftHandle plan,Stream * stream)146 bool SetStream(GpuExecutor *parent, hipfftHandle plan, Stream *stream) {
147   auto ret = wrap::hipfftSetStream(parent, plan, AsGpuStreamValue(stream));
148   if (ret != HIPFFT_SUCCESS) {
149     LOG(ERROR) << "failed to run rocFFT routine hipfftSetStream: " << ret;
150     return false;
151   }
152   return true;
153 }
154 
155 }  // namespace
156 
Initialize(GpuExecutor * parent,Stream * stream,int rank,uint64 * elem_count,uint64 * input_embed,uint64 input_stride,uint64 input_distance,uint64 * output_embed,uint64 output_stride,uint64 output_distance,fft::Type type,int batch_count,ScratchAllocator * scratch_allocator)157 port::Status ROCMFftPlan::Initialize(
158     GpuExecutor *parent, Stream *stream, int rank, uint64 *elem_count,
159     uint64 *input_embed, uint64 input_stride, uint64 input_distance,
160     uint64 *output_embed, uint64 output_stride, uint64 output_distance,
161     fft::Type type, int batch_count, ScratchAllocator *scratch_allocator) {
162   if (IsInitialized()) {
163     LOG(FATAL) << "Try to repeatedly initialize.";
164   }
165   is_initialized_ = true;
166   int elem_count_[3], input_embed_[3], output_embed_[3];
167   for (int i = 0; i < rank; ++i) {
168     elem_count_[i] = elem_count[i];
169     if (input_embed) {
170       input_embed_[i] = input_embed[i];
171     }
172     if (output_embed) {
173       output_embed_[i] = output_embed[i];
174     }
175   }
176   parent_ = parent;
177   fft_type_ = type;
178   if (batch_count == 1 && input_embed == nullptr && output_embed == nullptr) {
179     hipfftResult_t ret;
180     if (scratch_allocator == nullptr) {
181       switch (rank) {
182         case 1:
183           // hipfftPlan1d
184           ret = wrap::hipfftPlan1d(parent, &plan_, elem_count_[0],
185                                    ROCMFftType(type), 1 /* = batch */);
186           if (ret != HIPFFT_SUCCESS) {
187             LOG(ERROR) << "failed to create rocFFT 1d plan:" << ret;
188             return port::Status{port::error::INTERNAL,
189                                 "Failed to create rocFFT 1d plan."};
190           }
191           return port::Status::OK();
192         case 2:
193           // hipfftPlan2d
194           ret = wrap::hipfftPlan2d(parent, &plan_, elem_count_[0],
195                                    elem_count_[1], ROCMFftType(type));
196           if (ret != HIPFFT_SUCCESS) {
197             LOG(ERROR) << "failed to create rocFFT 2d plan:" << ret;
198             return port::Status{port::error::INTERNAL,
199                                 "Failed to create rocFFT 2d plan."};
200           }
201           return port::Status::OK();
202         case 3:
203           // hipfftPlan3d
204           ret =
205               wrap::hipfftPlan3d(parent, &plan_, elem_count_[0], elem_count_[1],
206                                  elem_count_[2], ROCMFftType(type));
207           if (ret != HIPFFT_SUCCESS) {
208             LOG(ERROR) << "failed to create rocFFT 3d plan:" << ret;
209             return port::Status{port::error::INTERNAL,
210                                 "Failed to create rocFFT 3d plan."};
211           }
212           return port::Status::OK();
213         default:
214           LOG(ERROR) << "Invalid rank value for hipfftPlan. "
215                         "Requested 1, 2, or 3, given: "
216                      << rank;
217           return port::Status{port::error::INVALID_ARGUMENT,
218                               "hipfftPlan only takes rank 1, 2, or 3."};
219       }
220     } else {
221       ret = wrap::hipfftCreate(parent, &plan_);
222       if (ret != HIPFFT_SUCCESS) {
223         LOG(ERROR) << "failed to create rocFFT plan:" << ret;
224         return port::Status{port::error::INTERNAL,
225                             "Failed to create rocFFT plan."};
226       }
227       ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0);
228       if (ret != HIPFFT_SUCCESS) {
229         LOG(ERROR) << "failed to set auto allocation for rocFFT plan:" << ret;
230         return port::Status{port::error::INTERNAL,
231                             "Failed to set auto allocation for rocFFT plan."};
232       }
233       switch (rank) {
234         case 1:
235           ret = wrap::hipfftMakePlan1d(parent, plan_, elem_count_[0],
236                                        ROCMFftType(type), /*batch=*/1,
237                                        &scratch_size_bytes_);
238           if (ret != HIPFFT_SUCCESS) {
239             LOG(ERROR) << "failed to make rocFFT 1d plan:" << ret;
240             return port::Status{port::error::INTERNAL,
241                                 "Failed to make rocFFT 1d plan."};
242           }
243           break;
244         case 2:
245           ret = wrap::hipfftMakePlan2d(parent, plan_, elem_count_[0],
246                                        elem_count_[1], ROCMFftType(type),
247                                        &scratch_size_bytes_);
248           if (ret != HIPFFT_SUCCESS) {
249             LOG(ERROR) << "failed to make rocFFT 2d plan:" << ret;
250             return port::Status{port::error::INTERNAL,
251                                 "Failed to make rocFFT 2d plan."};
252           }
253           break;
254         case 3:
255           ret = wrap::hipfftMakePlan3d(parent, plan_, elem_count_[0],
256                                        elem_count_[1], elem_count_[2],
257                                        ROCMFftType(type), &scratch_size_bytes_);
258           if (ret != HIPFFT_SUCCESS) {
259             LOG(ERROR) << "failed to make rocFFT 3d plan:" << ret;
260             return port::Status{port::error::INTERNAL,
261                                 "Failed to make rocFFT 3d plan."};
262           }
263           break;
264         default:
265           LOG(ERROR) << "Invalid rank value for hipfftPlan. "
266                         "Requested 1, 2, or 3, given: "
267                      << rank;
268           return port::Status{port::error::INVALID_ARGUMENT,
269                               "hipfftPlan only takes rank 1, 2, or 3."};
270       }
271       return UpdateScratchAllocator(stream, scratch_allocator);
272     }
273   } else {
274     // For either multiple batches or rank higher than 3, use hipfftPlanMany().
275     if (scratch_allocator == nullptr) {
276       auto ret = wrap::hipfftPlanMany(
277           parent, &plan_, rank, elem_count_,
278           input_embed ? input_embed_ : nullptr, input_stride, input_distance,
279           output_embed ? output_embed_ : nullptr, output_stride,
280           output_distance, ROCMFftType(type), batch_count);
281       if (ret != HIPFFT_SUCCESS) {
282         LOG(ERROR) << "failed to create rocFFT batched plan:" << ret;
283         return port::Status{port::error::INTERNAL,
284                             "Failed to create rocFFT batched plan."};
285       }
286     } else {
287       auto ret = wrap::hipfftCreate(parent, &plan_);
288       if (ret != HIPFFT_SUCCESS) {
289         LOG(ERROR) << "failed to create rocFFT batched plan:" << ret;
290         return port::Status{port::error::INTERNAL,
291                             "Failed to create rocFFT batched plan."};
292       }
293       ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0);
294       if (ret != HIPFFT_SUCCESS) {
295         LOG(ERROR) << "failed to set auto allocation for rocFFT batched plan:"
296                    << ret;
297         return port::Status{
298             port::error::INTERNAL,
299             "Failed to set auto allocation for rocFFT batched plan."};
300       }
301       ret = wrap::hipfftMakePlanMany(
302           parent, plan_, rank, elem_count_,
303           input_embed ? input_embed_ : nullptr, input_stride, input_distance,
304           output_embed ? output_embed_ : nullptr, output_stride,
305           output_distance, ROCMFftType(type), batch_count,
306           &scratch_size_bytes_);
307       if (ret != HIPFFT_SUCCESS) {
308         LOG(ERROR) << "failed to make rocFFT batched plan:" << ret;
309         return port::Status{port::error::INTERNAL,
310                             "Failed to make rocFFT batched plan."};
311       }
312       return UpdateScratchAllocator(stream, scratch_allocator);
313     }
314   }
315   return port::Status::OK();
316 }
317 
Initialize(GpuExecutor * parent,Stream * stream,int rank,uint64 * elem_count,fft::Type type,ScratchAllocator * scratch_allocator)318 port::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream,
319                                      int rank, uint64 *elem_count,
320                                      fft::Type type,
321                                      ScratchAllocator *scratch_allocator) {
322   return Initialize(parent_, stream, rank, elem_count,
323                     /*input_embed=*/nullptr, /*input_stride=*/0,
324                     /*input_distance=*/0,
325                     /*output_embed=*/nullptr, /*output_stride=*/0,
326                     /*output_distance=*/0, type, 1, scratch_allocator);
327 }
328 
UpdateScratchAllocator(Stream * stream,ScratchAllocator * scratch_allocator)329 port::Status ROCMFftPlan::UpdateScratchAllocator(
330     Stream *stream, ScratchAllocator *scratch_allocator) {
331   if (scratch_size_bytes_ != 0) {
332     auto allocated = scratch_allocator->AllocateBytes(scratch_size_bytes_);
333     if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
334       LOG(ERROR) << "failed to allocate work area.";
335       return allocated.status();
336     }
337   }
338   // Connect work area with allocated space.
339   auto ret = wrap::hipfftSetWorkArea(parent_, plan_, scratch_.opaque());
340   if (ret != HIPFFT_SUCCESS) {
341     LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret;
342     return port::Status(port::error::INTERNAL,
343                         "Failed to set work area for rocFFT plan.");
344   }
345   return port::Status::OK();
346 }
347 
~ROCMFftPlan()348 ROCMFftPlan::~ROCMFftPlan() { wrap::hipfftDestroy(parent_, plan_); }
349 
GetFftDirection() const350 int ROCMFftPlan::GetFftDirection() const {
351   if (!IsInitialized()) {
352     LOG(FATAL) << "Try to get fft direction before initialization.";
353   } else {
354     switch (fft_type_) {
355       case fft::Type::kC2CForward:
356       case fft::Type::kZ2ZForward:
357       case fft::Type::kR2C:
358       case fft::Type::kD2Z:
359         return HIPFFT_FORWARD;
360       case fft::Type::kC2CInverse:
361       case fft::Type::kZ2ZInverse:
362       case fft::Type::kC2R:
363       case fft::Type::kZ2D:
364         return HIPFFT_BACKWARD;
365       default:
366         LOG(FATAL) << "Invalid value of fft::Type.";
367     }
368   }
369 }
370 
Create1dPlan(Stream * stream,uint64 num_x,fft::Type type,bool in_place_fft)371 std::unique_ptr<fft::Plan> ROCMFft::Create1dPlan(Stream *stream, uint64 num_x,
372                                                  fft::Type type,
373                                                  bool in_place_fft) {
374   std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
375   uint64 elem_count[1] = {num_x};
376   port::Status status = fft_plan_ptr->Initialize(
377       parent_, stream, 1, elem_count, type, /*scratch_allocator=*/nullptr);
378   // TODO(yangzihao): In the future, send error msg back to TensorFlow
379   // so it can fail gracefully,
380   if (!status.ok()) {
381     LOG(FATAL) << "failed to initialize hipfft 1d plan: "
382                << status.error_message();
383   }
384   return std::move(fft_plan_ptr);
385 }
386 
Create1dPlanWithScratchAllocator(Stream * stream,uint64 num_x,fft::Type type,bool in_place_fft,ScratchAllocator * scratch_allocator)387 std::unique_ptr<fft::Plan> ROCMFft::Create1dPlanWithScratchAllocator(
388     Stream *stream, uint64 num_x, fft::Type type, bool in_place_fft,
389     ScratchAllocator *scratch_allocator) {
390   std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
391   uint64 elem_count[1] = {num_x};
392   port::Status status = fft_plan_ptr->Initialize(parent_, stream, 1, elem_count,
393                                                  type, scratch_allocator);
394   if (!status.ok()) {
395     LOG(FATAL)
396         << "failed to initialize hipfft 1d plan with customized allocator: "
397         << status.error_message();
398   }
399   return std::move(fft_plan_ptr);
400 }
401 
Create2dPlan(Stream * stream,uint64 num_x,uint64 num_y,fft::Type type,bool in_place_fft)402 std::unique_ptr<fft::Plan> ROCMFft::Create2dPlan(Stream *stream, uint64 num_x,
403                                                  uint64 num_y, fft::Type type,
404                                                  bool in_place_fft) {
405   std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
406   uint64 elem_count[2] = {num_x, num_y};
407   port::Status status = fft_plan_ptr->Initialize(
408       parent_, stream, 1, elem_count, type, /*scratch_allocator=*/nullptr);
409   if (!status.ok()) {
410     LOG(FATAL) << "failed to initialize hipfft 2d plan: "
411                << status.error_message();
412   }
413   return std::move(fft_plan_ptr);
414 }
415 
Create2dPlanWithScratchAllocator(Stream * stream,uint64 num_x,uint64 num_y,fft::Type type,bool in_place_fft,ScratchAllocator * scratch_allocator)416 std::unique_ptr<fft::Plan> ROCMFft::Create2dPlanWithScratchAllocator(
417     Stream *stream, uint64 num_x, uint64 num_y, fft::Type type,
418     bool in_place_fft, ScratchAllocator *scratch_allocator) {
419   std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
420   uint64 elem_count[2] = {num_x, num_y};
421   port::Status status = fft_plan_ptr->Initialize(parent_, stream, 2, elem_count,
422                                                  type, scratch_allocator);
423   if (!status.ok()) {
424     LOG(FATAL)
425         << "failed to initialize hipfft 2d plan with customized allocator: "
426         << status.error_message();
427   }
428   return std::move(fft_plan_ptr);
429 }
430 
Create3dPlan(Stream * stream,uint64 num_x,uint64 num_y,uint64 num_z,fft::Type type,bool in_place_fft)431 std::unique_ptr<fft::Plan> ROCMFft::Create3dPlan(Stream *stream, uint64 num_x,
432                                                  uint64 num_y, uint64 num_z,
433                                                  fft::Type type,
434                                                  bool in_place_fft) {
435   std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
436   uint64 elem_count[3] = {num_x, num_y, num_z};
437   port::Status status = fft_plan_ptr->Initialize(
438       parent_, stream, 3, elem_count, type, /*scratch_allocator=*/nullptr);
439   if (!status.ok()) {
440     LOG(FATAL) << "failed to initialize hipfft 3d plan: "
441                << status.error_message();
442   }
443   return std::move(fft_plan_ptr);
444 }
445 
Create3dPlanWithScratchAllocator(Stream * stream,uint64 num_x,uint64 num_y,uint64 num_z,fft::Type type,bool in_place_fft,ScratchAllocator * scratch_allocator)446 std::unique_ptr<fft::Plan> ROCMFft::Create3dPlanWithScratchAllocator(
447     Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, fft::Type type,
448     bool in_place_fft, ScratchAllocator *scratch_allocator) {
449   std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
450   uint64 elem_count[3] = {num_x, num_y, num_z};
451   port::Status status = fft_plan_ptr->Initialize(parent_, stream, 3, elem_count,
452                                                  type, scratch_allocator);
453   if (!status.ok()) {
454     LOG(FATAL)
455         << "failed to initialize hipfft 3d plan with customized allocator: "
456         << status.error_message();
457   }
458   return std::move(fft_plan_ptr);
459 }
460 
CreateBatchedPlan(Stream * stream,int rank,uint64 * elem_count,uint64 * input_embed,uint64 input_stride,uint64 input_distance,uint64 * output_embed,uint64 output_stride,uint64 output_distance,fft::Type type,bool in_place_fft,int batch_count)461 std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlan(
462     Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
463     uint64 input_stride, uint64 input_distance, uint64 *output_embed,
464     uint64 output_stride, uint64 output_distance, fft::Type type,
465     bool in_place_fft, int batch_count) {
466   std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
467   port::Status status = fft_plan_ptr->Initialize(
468       parent_, stream, rank, elem_count, input_embed, input_stride,
469       input_distance, output_embed, output_stride, output_distance, type,
470       batch_count, /*scratch_allocator=*/nullptr);
471   if (!status.ok()) {
472     LOG(FATAL) << "failed to initialize batched hipfft plan: "
473                << status.error_message();
474   }
475 
476   return std::move(fft_plan_ptr);
477 }
478 
CreateBatchedPlanWithScratchAllocator(Stream * stream,int rank,uint64 * elem_count,uint64 * input_embed,uint64 input_stride,uint64 input_distance,uint64 * output_embed,uint64 output_stride,uint64 output_distance,fft::Type type,bool in_place_fft,int batch_count,ScratchAllocator * scratch_allocator)479 std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlanWithScratchAllocator(
480     Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
481     uint64 input_stride, uint64 input_distance, uint64 *output_embed,
482     uint64 output_stride, uint64 output_distance, fft::Type type,
483     bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) {
484   std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
485   port::Status status = fft_plan_ptr->Initialize(
486       parent_, stream, rank, elem_count, input_embed, input_stride,
487       input_distance, output_embed, output_stride, output_distance, type,
488       batch_count, scratch_allocator);
489   if (!status.ok()) {
490     LOG(FATAL) << "failed to initialize batched hipfft plan with customized "
491                   "allocator: "
492                << status.error_message();
493   }
494   return std::move(fft_plan_ptr);
495 }
496 
UpdatePlanWithScratchAllocator(Stream * stream,fft::Plan * plan,ScratchAllocator * scratch_allocator)497 void ROCMFft::UpdatePlanWithScratchAllocator(
498     Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) {
499   ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
500   port::Status status =
501       rocm_fft_plan->UpdateScratchAllocator(stream, scratch_allocator);
502   if (!status.ok()) {
503     LOG(FATAL) << "failed to update custom allocator for hipfft plan: "
504                << status.error_message();
505   }
506 }
507 
508 template <typename FuncT, typename InputT, typename OutputT>
DoFftInternal(Stream * stream,fft::Plan * plan,FuncT hipfftExec,const DeviceMemory<InputT> & input,DeviceMemory<OutputT> * output)509 bool ROCMFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfftExec,
510                             const DeviceMemory<InputT> &input,
511                             DeviceMemory<OutputT> *output) {
512   ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
513   if (rocm_fft_plan == nullptr) {
514     LOG(ERROR) << "the passed-in plan is not a ROCMFftPlan object.";
515     return false;
516   }
517 
518   if (!SetStream(parent_, rocm_fft_plan->GetPlan(), stream)) {
519     return false;
520   }
521 
522   auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
523                         GpuComplex(const_cast<InputT *>(GpuMemory(input))),
524                         GpuComplex(GpuMemoryMutable(output)));
525 
526   if (ret != HIPFFT_SUCCESS) {
527     LOG(ERROR) << "failed to run rocFFT routine: " << ret;
528     return false;
529   }
530 
531   return true;
532 }
533 
534 template <typename FuncT, typename InputT, typename OutputT>
DoFftWithDirectionInternal(Stream * stream,fft::Plan * plan,FuncT hipfftExec,const DeviceMemory<InputT> & input,DeviceMemory<OutputT> * output)535 bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
536                                          FuncT hipfftExec,
537                                          const DeviceMemory<InputT> &input,
538                                          DeviceMemory<OutputT> *output) {
539   ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
540   if (rocm_fft_plan == nullptr) {
541     LOG(ERROR) << "the passed-in plan is not a ROCMFftPlan object.";
542     return false;
543   }
544 
545   if (!SetStream(parent_, rocm_fft_plan->GetPlan(), stream)) {
546     return false;
547   }
548 
549   auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
550                         GpuComplex(const_cast<InputT *>(GpuMemory(input))),
551                         GpuComplex(GpuMemoryMutable(output)),
552                         rocm_fft_plan->GetFftDirection());
553 
554   if (ret != HIPFFT_SUCCESS) {
555     LOG(ERROR) << "failed to run rocFFT routine: " << ret;
556     return false;
557   }
558 
559   return true;
560 }
561 
// Defines the three DoFft overloads for one floating-point precision __type:
//   complex -> complex: hipfftExec##__fft_type1 via DoFftWithDirectionInternal
//                       (the direction comes from the plan);
//   real    -> complex: hipfftExec##__fft_type2 via DoFftInternal;
//   complex -> real:    hipfftExec##__fft_type3 via DoFftInternal.
#define STREAM_EXECUTOR_ROCM_DEFINE_FFT(__type, __fft_type1, __fft_type2,    \
                                        __fft_type3)                         \
  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan,                       \
                      const DeviceMemory<std::complex<__type>> &input,       \
                      DeviceMemory<std::complex<__type>> *output) {          \
    return DoFftWithDirectionInternal(                                       \
        stream, plan, wrap::hipfftExec##__fft_type1, input, output);         \
  }                                                                          \
  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan,                       \
                      const DeviceMemory<__type> &input,                     \
                      DeviceMemory<std::complex<__type>> *output) {          \
    return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type2, input, \
                         output);                                            \
  }                                                                          \
  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan,                       \
                      const DeviceMemory<std::complex<__type>> &input,       \
                      DeviceMemory<__type> *output) {                        \
    return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type3, input, \
                         output);                                            \
  }

// float uses the single-precision C2C/R2C/C2R routines; double uses the
// double-precision Z2Z/D2Z/Z2D routines.
STREAM_EXECUTOR_ROCM_DEFINE_FFT(float, C2C, R2C, C2R)
STREAM_EXECUTOR_ROCM_DEFINE_FFT(double, Z2Z, D2Z, Z2D)

#undef STREAM_EXECUTOR_ROCM_DEFINE_FFT
587 
588 }  // namespace gpu
589 
initialize_rocfft()590 void initialize_rocfft() {
591   auto rocFftAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
592       rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
593 
594   if (!rocFftAlreadyRegistered) {
595     port::Status status =
596         PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
597             rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
598             [](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
599               gpu::GpuExecutor *rocm_executor =
600                   dynamic_cast<gpu::GpuExecutor *>(parent);
601               if (rocm_executor == nullptr) {
602                 LOG(ERROR)
603                     << "Attempting to initialize an instance of the rocFFT "
604                     << "support library with a non-ROCM StreamExecutor";
605                 return nullptr;
606               }
607 
608               return new gpu::ROCMFft(rocm_executor);
609             });
610     if (!status.ok()) {
611       LOG(ERROR) << "Unable to register rocFFT factory: "
612                  << status.error_message();
613     }
614 
615     PluginRegistry::Instance()->SetDefaultFactory(
616         rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
617   }
618 }
619 
620 }  // namespace stream_executor
621 
// Runs initialize_rocfft() as a module initializer so the rocFFT plugin is
// registered without an explicit call from user code.
REGISTER_MODULE_INITIALIZER(register_rocfft,
                            { stream_executor::initialize_rocfft(); });
624