• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // ROCM-specific support for FFT functionality -- this wraps the rocFFT library
17 // capabilities, and is only included into ROCM implementation code -- it will
18 // not introduce rocm headers into other code.
19 
20 #ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_
21 #define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_
22 
23 #include "rocm/include/rocfft/hipfft.h"
24 #include "tensorflow/stream_executor/fft.h"
25 #include "tensorflow/stream_executor/platform/port.h"
26 #include "tensorflow/stream_executor/plugin_registry.h"
27 #include "tensorflow/stream_executor/scratch_allocator.h"
28 
29 namespace stream_executor {
30 
31 class Stream;
32 
33 namespace gpu {
34 
35 class GpuExecutor;
36 
// Opaque and unique identifier for the rocFFT plugin.
38 extern const PluginId kRocFftPlugin;
39 
40 // ROCMFftPlan uses deferred initialization. Only a single call of
41 // Initialize() is allowed to properly create hipfft plan and set member
42 // variable is_initialized_ to true. Newly added interface that uses member
43 // variables should first check is_initialized_ to make sure that the values of
44 // member variables are valid.
45 class ROCMFftPlan : public fft::Plan {
46  public:
ROCMFftPlan()47   ROCMFftPlan()
48       : parent_(nullptr),
49         plan_(),
50         fft_type_(fft::Type::kInvalid),
51         scratch_(nullptr),
52         scratch_size_bytes_(0),
53         is_initialized_(false) {}
54   ~ROCMFftPlan() override;
55 
56   // Get FFT direction in hipFFT based on FFT type.
57   int GetFftDirection() const;
GetPlan()58   hipfftHandle GetPlan() const {
59     if (IsInitialized()) {
60       return plan_;
61     } else {
62       LOG(FATAL) << "Try to get hipfftHandle value before initialization.";
63     }
64   }
65 
66   // Initialize function for batched plan
67   port::Status Initialize(GpuExecutor *parent, Stream *stream, int rank,
68                           uint64 *elem_count, uint64 *input_embed,
69                           uint64 input_stride, uint64 input_distance,
70                           uint64 *output_embed, uint64 output_stride,
71                           uint64 output_distance, fft::Type type,
72                           int batch_count, ScratchAllocator *scratch_allocator);
73 
74   // Initialize function for 1d,2d, and 3d plan
75   port::Status Initialize(GpuExecutor *parent, Stream *stream, int rank,
76                           uint64 *elem_count, fft::Type type,
77                           ScratchAllocator *scratch_allocator);
78 
79   port::Status UpdateScratchAllocator(Stream *stream,
80                                       ScratchAllocator *scratch_allocator);
81 
82  protected:
IsInitialized()83   bool IsInitialized() const { return is_initialized_; }
84 
85  private:
86   GpuExecutor *parent_;
87   hipfftHandle plan_;
88   fft::Type fft_type_;
89   DeviceMemory<uint8> scratch_;
90   size_t scratch_size_bytes_;
91   bool is_initialized_;
92 };
93 
94 // FFT support for ROCM platform via rocFFT library.
95 //
96 // This satisfies the platform-agnostic FftSupport interface.
97 //
98 // Note that the hipFFT handle that this encapsulates is implicitly tied to the
99 // context (and, as a result, the device) that the parent GpuExecutor is tied
100 // to. This simply happens as an artifact of creating the hipFFT handle when a
101 // ROCM context is active.
102 //
103 // Thread-safe. The ROCM context associated with all operations is the ROCM
104 // context of parent_, so all context is explicit.
105 class ROCMFft : public fft::FftSupport {
106  public:
ROCMFft(GpuExecutor * parent)107   explicit ROCMFft(GpuExecutor *parent) : parent_(parent) {}
~ROCMFft()108   ~ROCMFft() override {}
109 
110   TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES
111 
112  private:
113   GpuExecutor *parent_;
114 
115   // Two helper functions that execute dynload::hipfftExec?2?.
116 
117   // This is for complex to complex FFT, when the direction is required.
118   template <typename FuncT, typename InputT, typename OutputT>
119   bool DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
120                                   FuncT hipfft_exec,
121                                   const DeviceMemory<InputT> &input,
122                                   DeviceMemory<OutputT> *output);
123 
124   // This is for complex to real or real to complex FFT, when the direction
125   // is implied.
126   template <typename FuncT, typename InputT, typename OutputT>
127   bool DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfft_exec,
128                      const DeviceMemory<InputT> &input,
129                      DeviceMemory<OutputT> *output);
130 
131   SE_DISALLOW_COPY_AND_ASSIGN(ROCMFft);
132 };
133 
134 }  // namespace gpu
135 }  // namespace stream_executor
136 
137 #endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_
138