1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/stream_executor/rocm/rocm_fft.h"

#include <complex>
#include <limits>

#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
34
35 namespace stream_executor {
36 namespace gpu {
37
// Plugin identifier under which the rocFFT support library is registered
// with the PluginRegistry (used by initialize_rocfft() in this file).
PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocFftPlugin);

// Callable shims around the hipFFT C API. Each shim takes the GpuExecutor as
// its first argument, activates that executor's context for the duration of
// the call, and forwards the remaining arguments to the hipFFT routine of the
// same name.
namespace wrap {

#ifdef PLATFORM_GOOGLE
// In this configuration the hipFFT symbols are referenced directly
// (::__name), i.e. the library is expected to be linked into the binary; no
// dynamic symbol lookup happens here.
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \
  struct WrapperShim__##__name { \
    template <typename... Args> \
    hipfftResult operator()(GpuExecutor *parent, Args... args) { \
      gpu::ScopedActivateExecutorContext sac{parent}; \
      return ::__name(args...); \
    } \
  } __name;

#else

// This macro wraps a global identifier, given by __name, in a callable
// structure that loads the symbol out of the rocFFT DSO handle on first use.
// The function pointer is cached in a function-local static, whose
// initialization C++ guarantees to be thread-safe. This dynamic loading
// technique avoids a hard link-time dependency on a vendor library that may
// or may not be available in the deployed binary environment. A missing DSO
// (ValueOrDie) or a missing symbol (CHECK) is fatal.
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \
  struct DynLoadShim__##__name { \
    static const char *kName; \
    using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
    static void *GetDsoHandle() { \
      auto s = internal::CachedDsoLoader::GetRocfftDsoHandle(); \
      return s.ValueOrDie(); \
    } \
    static FuncPtrT LoadOrDie() { \
      void *f; \
      auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
                                                          kName, &f); \
      CHECK(s.ok()) << "could not find " << kName \
                    << " in rocfft DSO; dlerror: " << s.error_message(); \
      return reinterpret_cast<FuncPtrT>(f); \
    } \
    static FuncPtrT DynLoad() { \
      static FuncPtrT f = LoadOrDie(); \
      return f; \
    } \
    template <typename... Args> \
    hipfftResult operator()(GpuExecutor *parent, Args... args) { \
      gpu::ScopedActivateExecutorContext sac{parent}; \
      return DynLoad()(args...); \
    } \
  } __name; \
  const char *DynLoadShim__##__name::kName = #__name;

#endif

// X-macro listing every hipFFT routine that gets a wrap:: shim. Adding a
// routine here makes wrap::<routine>(parent, args...) available below.
// clang-format off
#define ROCFFT_ROUTINE_EACH(__macro) \
  __macro(hipfftDestroy) \
  __macro(hipfftSetStream) \
  __macro(hipfftPlan1d) \
  __macro(hipfftPlan2d) \
  __macro(hipfftPlan3d) \
  __macro(hipfftPlanMany) \
  __macro(hipfftCreate) \
  __macro(hipfftSetAutoAllocation) \
  __macro(hipfftSetWorkArea) \
  __macro(hipfftGetSize1d) \
  __macro(hipfftMakePlan1d) \
  __macro(hipfftGetSize2d) \
  __macro(hipfftMakePlan2d) \
  __macro(hipfftGetSize3d) \
  __macro(hipfftMakePlan3d) \
  __macro(hipfftGetSizeMany) \
  __macro(hipfftMakePlanMany) \
  __macro(hipfftExecD2Z) \
  __macro(hipfftExecZ2D) \
  __macro(hipfftExecC2C) \
  __macro(hipfftExecC2R) \
  __macro(hipfftExecZ2Z) \
  __macro(hipfftExecR2C)

// clang-format on

// Instantiate one shim object per routine listed above.
ROCFFT_ROUTINE_EACH(STREAM_EXECUTOR_ROCFFT_WRAP)

}  // namespace wrap
120
121 namespace {
122
123 // A helper function transforming gpu_fft arguments into rocFFT arguments.
ROCMFftType(fft::Type type)124 hipfftType ROCMFftType(fft::Type type) {
125 switch (type) {
126 case fft::Type::kC2CForward:
127 case fft::Type::kC2CInverse:
128 return HIPFFT_C2C;
129 case fft::Type::kC2R:
130 return HIPFFT_C2R;
131 case fft::Type::kR2C:
132 return HIPFFT_R2C;
133 case fft::Type::kZ2ZForward:
134 case fft::Type::kZ2ZInverse:
135 return HIPFFT_Z2Z;
136 case fft::Type::kZ2D:
137 return HIPFFT_Z2D;
138 case fft::Type::kD2Z:
139 return HIPFFT_D2Z;
140 default:
141 LOG(FATAL) << "Invalid value of fft::Type.";
142 }
143 }
144
145 // Associates the given stream with the given rocFFT plan.
SetStream(GpuExecutor * parent,hipfftHandle plan,Stream * stream)146 bool SetStream(GpuExecutor *parent, hipfftHandle plan, Stream *stream) {
147 auto ret = wrap::hipfftSetStream(parent, plan, AsGpuStreamValue(stream));
148 if (ret != HIPFFT_SUCCESS) {
149 LOG(ERROR) << "failed to run rocFFT routine hipfftSetStream: " << ret;
150 return false;
151 }
152 return true;
153 }
154
155 } // namespace
156
Initialize(GpuExecutor * parent,Stream * stream,int rank,uint64 * elem_count,uint64 * input_embed,uint64 input_stride,uint64 input_distance,uint64 * output_embed,uint64 output_stride,uint64 output_distance,fft::Type type,int batch_count,ScratchAllocator * scratch_allocator)157 port::Status ROCMFftPlan::Initialize(
158 GpuExecutor *parent, Stream *stream, int rank, uint64 *elem_count,
159 uint64 *input_embed, uint64 input_stride, uint64 input_distance,
160 uint64 *output_embed, uint64 output_stride, uint64 output_distance,
161 fft::Type type, int batch_count, ScratchAllocator *scratch_allocator) {
162 if (IsInitialized()) {
163 LOG(FATAL) << "Try to repeatedly initialize.";
164 }
165 is_initialized_ = true;
166 int elem_count_[3], input_embed_[3], output_embed_[3];
167 for (int i = 0; i < rank; ++i) {
168 elem_count_[i] = elem_count[i];
169 if (input_embed) {
170 input_embed_[i] = input_embed[i];
171 }
172 if (output_embed) {
173 output_embed_[i] = output_embed[i];
174 }
175 }
176 parent_ = parent;
177 fft_type_ = type;
178 if (batch_count == 1 && input_embed == nullptr && output_embed == nullptr) {
179 hipfftResult_t ret;
180 if (scratch_allocator == nullptr) {
181 switch (rank) {
182 case 1:
183 // hipfftPlan1d
184 ret = wrap::hipfftPlan1d(parent, &plan_, elem_count_[0],
185 ROCMFftType(type), 1 /* = batch */);
186 if (ret != HIPFFT_SUCCESS) {
187 LOG(ERROR) << "failed to create rocFFT 1d plan:" << ret;
188 return port::Status{port::error::INTERNAL,
189 "Failed to create rocFFT 1d plan."};
190 }
191 return port::Status::OK();
192 case 2:
193 // hipfftPlan2d
194 ret = wrap::hipfftPlan2d(parent, &plan_, elem_count_[0],
195 elem_count_[1], ROCMFftType(type));
196 if (ret != HIPFFT_SUCCESS) {
197 LOG(ERROR) << "failed to create rocFFT 2d plan:" << ret;
198 return port::Status{port::error::INTERNAL,
199 "Failed to create rocFFT 2d plan."};
200 }
201 return port::Status::OK();
202 case 3:
203 // hipfftPlan3d
204 ret =
205 wrap::hipfftPlan3d(parent, &plan_, elem_count_[0], elem_count_[1],
206 elem_count_[2], ROCMFftType(type));
207 if (ret != HIPFFT_SUCCESS) {
208 LOG(ERROR) << "failed to create rocFFT 3d plan:" << ret;
209 return port::Status{port::error::INTERNAL,
210 "Failed to create rocFFT 3d plan."};
211 }
212 return port::Status::OK();
213 default:
214 LOG(ERROR) << "Invalid rank value for hipfftPlan. "
215 "Requested 1, 2, or 3, given: "
216 << rank;
217 return port::Status{port::error::INVALID_ARGUMENT,
218 "hipfftPlan only takes rank 1, 2, or 3."};
219 }
220 } else {
221 ret = wrap::hipfftCreate(parent, &plan_);
222 if (ret != HIPFFT_SUCCESS) {
223 LOG(ERROR) << "failed to create rocFFT plan:" << ret;
224 return port::Status{port::error::INTERNAL,
225 "Failed to create rocFFT plan."};
226 }
227 ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0);
228 if (ret != HIPFFT_SUCCESS) {
229 LOG(ERROR) << "failed to set auto allocation for rocFFT plan:" << ret;
230 return port::Status{port::error::INTERNAL,
231 "Failed to set auto allocation for rocFFT plan."};
232 }
233 switch (rank) {
234 case 1:
235 ret = wrap::hipfftMakePlan1d(parent, plan_, elem_count_[0],
236 ROCMFftType(type), /*batch=*/1,
237 &scratch_size_bytes_);
238 if (ret != HIPFFT_SUCCESS) {
239 LOG(ERROR) << "failed to make rocFFT 1d plan:" << ret;
240 return port::Status{port::error::INTERNAL,
241 "Failed to make rocFFT 1d plan."};
242 }
243 break;
244 case 2:
245 ret = wrap::hipfftMakePlan2d(parent, plan_, elem_count_[0],
246 elem_count_[1], ROCMFftType(type),
247 &scratch_size_bytes_);
248 if (ret != HIPFFT_SUCCESS) {
249 LOG(ERROR) << "failed to make rocFFT 2d plan:" << ret;
250 return port::Status{port::error::INTERNAL,
251 "Failed to make rocFFT 2d plan."};
252 }
253 break;
254 case 3:
255 ret = wrap::hipfftMakePlan3d(parent, plan_, elem_count_[0],
256 elem_count_[1], elem_count_[2],
257 ROCMFftType(type), &scratch_size_bytes_);
258 if (ret != HIPFFT_SUCCESS) {
259 LOG(ERROR) << "failed to make rocFFT 3d plan:" << ret;
260 return port::Status{port::error::INTERNAL,
261 "Failed to make rocFFT 3d plan."};
262 }
263 break;
264 default:
265 LOG(ERROR) << "Invalid rank value for hipfftPlan. "
266 "Requested 1, 2, or 3, given: "
267 << rank;
268 return port::Status{port::error::INVALID_ARGUMENT,
269 "hipfftPlan only takes rank 1, 2, or 3."};
270 }
271 return UpdateScratchAllocator(stream, scratch_allocator);
272 }
273 } else {
274 // For either multiple batches or rank higher than 3, use hipfftPlanMany().
275 if (scratch_allocator == nullptr) {
276 auto ret = wrap::hipfftPlanMany(
277 parent, &plan_, rank, elem_count_,
278 input_embed ? input_embed_ : nullptr, input_stride, input_distance,
279 output_embed ? output_embed_ : nullptr, output_stride,
280 output_distance, ROCMFftType(type), batch_count);
281 if (ret != HIPFFT_SUCCESS) {
282 LOG(ERROR) << "failed to create rocFFT batched plan:" << ret;
283 return port::Status{port::error::INTERNAL,
284 "Failed to create rocFFT batched plan."};
285 }
286 } else {
287 auto ret = wrap::hipfftCreate(parent, &plan_);
288 if (ret != HIPFFT_SUCCESS) {
289 LOG(ERROR) << "failed to create rocFFT batched plan:" << ret;
290 return port::Status{port::error::INTERNAL,
291 "Failed to create rocFFT batched plan."};
292 }
293 ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0);
294 if (ret != HIPFFT_SUCCESS) {
295 LOG(ERROR) << "failed to set auto allocation for rocFFT batched plan:"
296 << ret;
297 return port::Status{
298 port::error::INTERNAL,
299 "Failed to set auto allocation for rocFFT batched plan."};
300 }
301 ret = wrap::hipfftMakePlanMany(
302 parent, plan_, rank, elem_count_,
303 input_embed ? input_embed_ : nullptr, input_stride, input_distance,
304 output_embed ? output_embed_ : nullptr, output_stride,
305 output_distance, ROCMFftType(type), batch_count,
306 &scratch_size_bytes_);
307 if (ret != HIPFFT_SUCCESS) {
308 LOG(ERROR) << "failed to make rocFFT batched plan:" << ret;
309 return port::Status{port::error::INTERNAL,
310 "Failed to make rocFFT batched plan."};
311 }
312 return UpdateScratchAllocator(stream, scratch_allocator);
313 }
314 }
315 return port::Status::OK();
316 }
317
Initialize(GpuExecutor * parent,Stream * stream,int rank,uint64 * elem_count,fft::Type type,ScratchAllocator * scratch_allocator)318 port::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream,
319 int rank, uint64 *elem_count,
320 fft::Type type,
321 ScratchAllocator *scratch_allocator) {
322 return Initialize(parent_, stream, rank, elem_count,
323 /*input_embed=*/nullptr, /*input_stride=*/0,
324 /*input_distance=*/0,
325 /*output_embed=*/nullptr, /*output_stride=*/0,
326 /*output_distance=*/0, type, 1, scratch_allocator);
327 }
328
UpdateScratchAllocator(Stream * stream,ScratchAllocator * scratch_allocator)329 port::Status ROCMFftPlan::UpdateScratchAllocator(
330 Stream *stream, ScratchAllocator *scratch_allocator) {
331 if (scratch_size_bytes_ != 0) {
332 auto allocated = scratch_allocator->AllocateBytes(scratch_size_bytes_);
333 if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
334 LOG(ERROR) << "failed to allocate work area.";
335 return allocated.status();
336 }
337 }
338 // Connect work area with allocated space.
339 auto ret = wrap::hipfftSetWorkArea(parent_, plan_, scratch_.opaque());
340 if (ret != HIPFFT_SUCCESS) {
341 LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret;
342 return port::Status(port::error::INTERNAL,
343 "Failed to set work area for rocFFT plan.");
344 }
345 return port::Status::OK();
346 }
347
~ROCMFftPlan()348 ROCMFftPlan::~ROCMFftPlan() { wrap::hipfftDestroy(parent_, plan_); }
349
GetFftDirection() const350 int ROCMFftPlan::GetFftDirection() const {
351 if (!IsInitialized()) {
352 LOG(FATAL) << "Try to get fft direction before initialization.";
353 } else {
354 switch (fft_type_) {
355 case fft::Type::kC2CForward:
356 case fft::Type::kZ2ZForward:
357 case fft::Type::kR2C:
358 case fft::Type::kD2Z:
359 return HIPFFT_FORWARD;
360 case fft::Type::kC2CInverse:
361 case fft::Type::kZ2ZInverse:
362 case fft::Type::kC2R:
363 case fft::Type::kZ2D:
364 return HIPFFT_BACKWARD;
365 default:
366 LOG(FATAL) << "Invalid value of fft::Type.";
367 }
368 }
369 }
370
Create1dPlan(Stream * stream,uint64 num_x,fft::Type type,bool in_place_fft)371 std::unique_ptr<fft::Plan> ROCMFft::Create1dPlan(Stream *stream, uint64 num_x,
372 fft::Type type,
373 bool in_place_fft) {
374 std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
375 uint64 elem_count[1] = {num_x};
376 port::Status status = fft_plan_ptr->Initialize(
377 parent_, stream, 1, elem_count, type, /*scratch_allocator=*/nullptr);
378 // TODO(yangzihao): In the future, send error msg back to TensorFlow
379 // so it can fail gracefully,
380 if (!status.ok()) {
381 LOG(FATAL) << "failed to initialize hipfft 1d plan: "
382 << status.error_message();
383 }
384 return std::move(fft_plan_ptr);
385 }
386
Create1dPlanWithScratchAllocator(Stream * stream,uint64 num_x,fft::Type type,bool in_place_fft,ScratchAllocator * scratch_allocator)387 std::unique_ptr<fft::Plan> ROCMFft::Create1dPlanWithScratchAllocator(
388 Stream *stream, uint64 num_x, fft::Type type, bool in_place_fft,
389 ScratchAllocator *scratch_allocator) {
390 std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
391 uint64 elem_count[1] = {num_x};
392 port::Status status = fft_plan_ptr->Initialize(parent_, stream, 1, elem_count,
393 type, scratch_allocator);
394 if (!status.ok()) {
395 LOG(FATAL)
396 << "failed to initialize hipfft 1d plan with customized allocator: "
397 << status.error_message();
398 }
399 return std::move(fft_plan_ptr);
400 }
401
Create2dPlan(Stream * stream,uint64 num_x,uint64 num_y,fft::Type type,bool in_place_fft)402 std::unique_ptr<fft::Plan> ROCMFft::Create2dPlan(Stream *stream, uint64 num_x,
403 uint64 num_y, fft::Type type,
404 bool in_place_fft) {
405 std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
406 uint64 elem_count[2] = {num_x, num_y};
407 port::Status status = fft_plan_ptr->Initialize(
408 parent_, stream, 1, elem_count, type, /*scratch_allocator=*/nullptr);
409 if (!status.ok()) {
410 LOG(FATAL) << "failed to initialize hipfft 2d plan: "
411 << status.error_message();
412 }
413 return std::move(fft_plan_ptr);
414 }
415
Create2dPlanWithScratchAllocator(Stream * stream,uint64 num_x,uint64 num_y,fft::Type type,bool in_place_fft,ScratchAllocator * scratch_allocator)416 std::unique_ptr<fft::Plan> ROCMFft::Create2dPlanWithScratchAllocator(
417 Stream *stream, uint64 num_x, uint64 num_y, fft::Type type,
418 bool in_place_fft, ScratchAllocator *scratch_allocator) {
419 std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
420 uint64 elem_count[2] = {num_x, num_y};
421 port::Status status = fft_plan_ptr->Initialize(parent_, stream, 2, elem_count,
422 type, scratch_allocator);
423 if (!status.ok()) {
424 LOG(FATAL)
425 << "failed to initialize hipfft 2d plan with customized allocator: "
426 << status.error_message();
427 }
428 return std::move(fft_plan_ptr);
429 }
430
Create3dPlan(Stream * stream,uint64 num_x,uint64 num_y,uint64 num_z,fft::Type type,bool in_place_fft)431 std::unique_ptr<fft::Plan> ROCMFft::Create3dPlan(Stream *stream, uint64 num_x,
432 uint64 num_y, uint64 num_z,
433 fft::Type type,
434 bool in_place_fft) {
435 std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
436 uint64 elem_count[3] = {num_x, num_y, num_z};
437 port::Status status = fft_plan_ptr->Initialize(
438 parent_, stream, 3, elem_count, type, /*scratch_allocator=*/nullptr);
439 if (!status.ok()) {
440 LOG(FATAL) << "failed to initialize hipfft 3d plan: "
441 << status.error_message();
442 }
443 return std::move(fft_plan_ptr);
444 }
445
Create3dPlanWithScratchAllocator(Stream * stream,uint64 num_x,uint64 num_y,uint64 num_z,fft::Type type,bool in_place_fft,ScratchAllocator * scratch_allocator)446 std::unique_ptr<fft::Plan> ROCMFft::Create3dPlanWithScratchAllocator(
447 Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, fft::Type type,
448 bool in_place_fft, ScratchAllocator *scratch_allocator) {
449 std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
450 uint64 elem_count[3] = {num_x, num_y, num_z};
451 port::Status status = fft_plan_ptr->Initialize(parent_, stream, 3, elem_count,
452 type, scratch_allocator);
453 if (!status.ok()) {
454 LOG(FATAL)
455 << "failed to initialize hipfft 3d plan with customized allocator: "
456 << status.error_message();
457 }
458 return std::move(fft_plan_ptr);
459 }
460
CreateBatchedPlan(Stream * stream,int rank,uint64 * elem_count,uint64 * input_embed,uint64 input_stride,uint64 input_distance,uint64 * output_embed,uint64 output_stride,uint64 output_distance,fft::Type type,bool in_place_fft,int batch_count)461 std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlan(
462 Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
463 uint64 input_stride, uint64 input_distance, uint64 *output_embed,
464 uint64 output_stride, uint64 output_distance, fft::Type type,
465 bool in_place_fft, int batch_count) {
466 std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
467 port::Status status = fft_plan_ptr->Initialize(
468 parent_, stream, rank, elem_count, input_embed, input_stride,
469 input_distance, output_embed, output_stride, output_distance, type,
470 batch_count, /*scratch_allocator=*/nullptr);
471 if (!status.ok()) {
472 LOG(FATAL) << "failed to initialize batched hipfft plan: "
473 << status.error_message();
474 }
475
476 return std::move(fft_plan_ptr);
477 }
478
CreateBatchedPlanWithScratchAllocator(Stream * stream,int rank,uint64 * elem_count,uint64 * input_embed,uint64 input_stride,uint64 input_distance,uint64 * output_embed,uint64 output_stride,uint64 output_distance,fft::Type type,bool in_place_fft,int batch_count,ScratchAllocator * scratch_allocator)479 std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlanWithScratchAllocator(
480 Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
481 uint64 input_stride, uint64 input_distance, uint64 *output_embed,
482 uint64 output_stride, uint64 output_distance, fft::Type type,
483 bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) {
484 std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
485 port::Status status = fft_plan_ptr->Initialize(
486 parent_, stream, rank, elem_count, input_embed, input_stride,
487 input_distance, output_embed, output_stride, output_distance, type,
488 batch_count, scratch_allocator);
489 if (!status.ok()) {
490 LOG(FATAL) << "failed to initialize batched hipfft plan with customized "
491 "allocator: "
492 << status.error_message();
493 }
494 return std::move(fft_plan_ptr);
495 }
496
UpdatePlanWithScratchAllocator(Stream * stream,fft::Plan * plan,ScratchAllocator * scratch_allocator)497 void ROCMFft::UpdatePlanWithScratchAllocator(
498 Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) {
499 ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
500 port::Status status =
501 rocm_fft_plan->UpdateScratchAllocator(stream, scratch_allocator);
502 if (!status.ok()) {
503 LOG(FATAL) << "failed to update custom allocator for hipfft plan: "
504 << status.error_message();
505 }
506 }
507
508 template <typename FuncT, typename InputT, typename OutputT>
DoFftInternal(Stream * stream,fft::Plan * plan,FuncT hipfftExec,const DeviceMemory<InputT> & input,DeviceMemory<OutputT> * output)509 bool ROCMFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfftExec,
510 const DeviceMemory<InputT> &input,
511 DeviceMemory<OutputT> *output) {
512 ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
513 if (rocm_fft_plan == nullptr) {
514 LOG(ERROR) << "the passed-in plan is not a ROCMFftPlan object.";
515 return false;
516 }
517
518 if (!SetStream(parent_, rocm_fft_plan->GetPlan(), stream)) {
519 return false;
520 }
521
522 auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
523 GpuComplex(const_cast<InputT *>(GpuMemory(input))),
524 GpuComplex(GpuMemoryMutable(output)));
525
526 if (ret != HIPFFT_SUCCESS) {
527 LOG(ERROR) << "failed to run rocFFT routine: " << ret;
528 return false;
529 }
530
531 return true;
532 }
533
534 template <typename FuncT, typename InputT, typename OutputT>
DoFftWithDirectionInternal(Stream * stream,fft::Plan * plan,FuncT hipfftExec,const DeviceMemory<InputT> & input,DeviceMemory<OutputT> * output)535 bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
536 FuncT hipfftExec,
537 const DeviceMemory<InputT> &input,
538 DeviceMemory<OutputT> *output) {
539 ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
540 if (rocm_fft_plan == nullptr) {
541 LOG(ERROR) << "the passed-in plan is not a ROCMFftPlan object.";
542 return false;
543 }
544
545 if (!SetStream(parent_, rocm_fft_plan->GetPlan(), stream)) {
546 return false;
547 }
548
549 auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
550 GpuComplex(const_cast<InputT *>(GpuMemory(input))),
551 GpuComplex(GpuMemoryMutable(output)),
552 rocm_fft_plan->GetFftDirection());
553
554 if (ret != HIPFFT_SUCCESS) {
555 LOG(ERROR) << "failed to run rocFFT routine: " << ret;
556 return false;
557 }
558
559 return true;
560 }
561
// Defines the three ROCMFft::DoFft overloads for one floating-point __type:
//   complex -> complex: wrap::hipfftExec##__fft_type1, with the direction
//                       taken from the plan (DoFftWithDirectionInternal);
//   real -> complex:    wrap::hipfftExec##__fft_type2 (DoFftInternal);
//   complex -> real:    wrap::hipfftExec##__fft_type3 (DoFftInternal).
#define STREAM_EXECUTOR_ROCM_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
                                        __fft_type3) \
  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
                      const DeviceMemory<std::complex<__type>> &input, \
                      DeviceMemory<std::complex<__type>> *output) { \
    return DoFftWithDirectionInternal( \
        stream, plan, wrap::hipfftExec##__fft_type1, input, output); \
  } \
  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
                      const DeviceMemory<__type> &input, \
                      DeviceMemory<std::complex<__type>> *output) { \
    return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type2, input, \
                         output); \
  } \
  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
                      const DeviceMemory<std::complex<__type>> &input, \
                      DeviceMemory<__type> *output) { \
    return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type3, input, \
                         output); \
  }

// Single precision: C2C / R2C / C2R. Double precision: Z2Z / D2Z / Z2D.
STREAM_EXECUTOR_ROCM_DEFINE_FFT(float, C2C, R2C, C2R)
STREAM_EXECUTOR_ROCM_DEFINE_FFT(double, Z2Z, D2Z, Z2D)

#undef STREAM_EXECUTOR_ROCM_DEFINE_FFT
587
588 } // namespace gpu
589
initialize_rocfft()590 void initialize_rocfft() {
591 auto rocFftAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
592 rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
593
594 if (!rocFftAlreadyRegistered) {
595 port::Status status =
596 PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
597 rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
598 [](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
599 gpu::GpuExecutor *rocm_executor =
600 dynamic_cast<gpu::GpuExecutor *>(parent);
601 if (rocm_executor == nullptr) {
602 LOG(ERROR)
603 << "Attempting to initialize an instance of the rocFFT "
604 << "support library with a non-ROCM StreamExecutor";
605 return nullptr;
606 }
607
608 return new gpu::ROCMFft(rocm_executor);
609 });
610 if (!status.ok()) {
611 LOG(ERROR) << "Unable to register rocFFT factory: "
612 << status.error_message();
613 }
614
615 PluginRegistry::Instance()->SetDefaultFactory(
616 rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
617 }
618 }
619
620 } // namespace stream_executor
621
// Registers stream_executor::initialize_rocfft() as the module initializer
// named register_rocfft (see lib/initialize.h for when initializers run).
REGISTER_MODULE_INITIALIZER(register_rocfft,
                            { stream_executor::initialize_rocfft(); });
624