• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Exposes the family of FFT routines as pre-canned high performance calls for
17 // use in conjunction with the StreamExecutor abstraction.
18 //
19 // Note that this interface is optionally supported by platforms; see
20 // StreamExecutor::SupportsFft() for details.
21 //
22 // This abstraction makes it simple to entrain FFT operations on GPU data into
23 // a Stream -- users typically will not use this API directly, but will use the
24 // Stream builder methods to entrain these operations "under the hood". For
25 // example:
26 //
27 //  DeviceMemory<std::complex<float>> x =
28 //    stream_exec->AllocateArray<std::complex<float>>(1024);
29 //  DeviceMemory<std::complex<float>> y =
30 //    stream_exec->AllocateArray<std::complex<float>>(1024);
31 //  // ... populate x and y ...
32 //  Stream stream{stream_exec};
//  std::unique_ptr<Plan> plan =
//     stream_exec->AsFft()->Create1dPlan(&stream, 1024, Type::kC2CForward,
//                                        /*in_place_fft=*/false);
35 //  stream
36 //    .Init()
37 //    .ThenFft(plan.get(), x, &y);
38 //  SE_CHECK_OK(stream.BlockHostUntilDone());
39 //
40 // By using stream operations in this manner the user can easily intermix custom
41 // kernel launches (via StreamExecutor::ThenLaunch()) with these pre-canned FFT
42 // routines.
43 
44 #ifndef TENSORFLOW_STREAM_EXECUTOR_FFT_H_
45 #define TENSORFLOW_STREAM_EXECUTOR_FFT_H_
46 
47 #include <complex>
48 #include <memory>
49 #include "tensorflow/stream_executor/platform/port.h"
50 
51 namespace stream_executor {
52 
53 class Stream;
54 template <typename ElemT>
55 class DeviceMemory;
56 class ScratchAllocator;
57 
58 namespace fft {
59 
// Selects the element types of an FFT's input and output together with the
// transform direction. The letters in the names encode the element type:
//   R = single-precision real      D = double-precision real
//   C = single-precision complex   Z = double-precision complex
enum class Type {
  kInvalid,     // Sentinel; does not name a usable transform.
  kC2CForward,  // SP complex -> SP complex, forward.
  kC2CInverse,  // SP complex -> SP complex, inverse.
  kC2R,         // SP complex -> SP real.
  kR2C,         // SP real -> SP complex.
  kZ2ZForward,  // DP complex -> DP complex, forward.
  kZ2ZInverse,  // DP complex -> DP complex, inverse.
  kZ2D,         // DP complex -> DP real.
  kD2Z,         // DP real -> DP complex.
};
73 
// Base type for FFT plans. Every FFT implementation is expected to supply its
// own plan class deriving from this one. No operations are declared here; the
// class exists only to give plans a common type that can be handed to the
// routines that execute them.
class Plan {
 public:
  // Virtual so that a concrete plan can be destroyed through a Plan pointer.
  virtual ~Plan() = default;
};
81 
82 // FFT support interface -- this can be derived from a GPU executor when the
83 // underlying platform has an FFT library implementation available. See
84 // StreamExecutor::AsFft().
85 //
86 // This support interface is not generally thread-safe; it is only thread-safe
87 // for the CUDA platform (cuFFT) usage; host side FFT support is known
88 // thread-compatible, but not thread-safe.
89 class FftSupport {
90  public:
~FftSupport()91   virtual ~FftSupport() {}
92 
93   // Creates a 1d FFT plan.
94   virtual std::unique_ptr<Plan> Create1dPlan(Stream *stream, uint64 num_x,
95                                              Type type, bool in_place_fft) = 0;
96 
97   // Creates a 2d FFT plan.
98   virtual std::unique_ptr<Plan> Create2dPlan(Stream *stream, uint64 num_x,
99                                              uint64 num_y, Type type,
100                                              bool in_place_fft) = 0;
101 
102   // Creates a 3d FFT plan.
103   virtual std::unique_ptr<Plan> Create3dPlan(Stream *stream, uint64 num_x,
104                                              uint64 num_y, uint64 num_z,
105                                              Type type, bool in_place_fft) = 0;
106 
107   // Creates a 1d FFT plan with scratch allocator.
108   virtual std::unique_ptr<Plan> Create1dPlanWithScratchAllocator(
109       Stream *stream, uint64 num_x, Type type, bool in_place_fft,
110       ScratchAllocator *scratch_allocator) = 0;
111 
112   // Creates a 2d FFT plan with scratch allocator.
113   virtual std::unique_ptr<Plan> Create2dPlanWithScratchAllocator(
114       Stream *stream, uint64 num_x, uint64 num_y, Type type, bool in_place_fft,
115       ScratchAllocator *scratch_allocator) = 0;
116 
117   // Creates a 3d FFT plan with scratch allocator.
118   virtual std::unique_ptr<Plan> Create3dPlanWithScratchAllocator(
119       Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, Type type,
120       bool in_place_fft, ScratchAllocator *scratch_allocator) = 0;
121 
122   // Creates a batched FFT plan.
123   //
124   // stream:          The GPU stream in which the FFT runs.
125   // rank:            Dimensionality of the transform (1, 2, or 3).
126   // elem_count:      Array of size rank, describing the size of each dimension.
127   // input_embed, output_embed:
128   //                  Pointer of size rank that indicates the storage dimensions
129   //                  of the input/output data in memory. If set to null_ptr all
130   //                  other advanced data layout parameters are ignored.
131   // input_stride:    Indicates the distance (number of elements; same below)
132   //                  between two successive input elements.
133   // input_distance:  Indicates the distance between the first element of two
134   //                  consecutive signals in a batch of the input data.
135   // output_stride:   Indicates the distance between two successive output
136   //                  elements.
137   // output_distance: Indicates the distance between the first element of two
138   //                  consecutive signals in a batch of the output data.
139   virtual std::unique_ptr<Plan> CreateBatchedPlan(
140       Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
141       uint64 input_stride, uint64 input_distance, uint64 *output_embed,
142       uint64 output_stride, uint64 output_distance, Type type,
143       bool in_place_fft, int batch_count) = 0;
144 
145   // Creates a batched FFT plan with scratch allocator.
146   //
147   // stream:          The GPU stream in which the FFT runs.
148   // rank:            Dimensionality of the transform (1, 2, or 3).
149   // elem_count:      Array of size rank, describing the size of each dimension.
150   // input_embed, output_embed:
151   //                  Pointer of size rank that indicates the storage dimensions
152   //                  of the input/output data in memory. If set to null_ptr all
153   //                  other advanced data layout parameters are ignored.
154   // input_stride:    Indicates the distance (number of elements; same below)
155   //                  between two successive input elements.
156   // input_distance:  Indicates the distance between the first element of two
157   //                  consecutive signals in a batch of the input data.
158   // output_stride:   Indicates the distance between two successive output
159   //                  elements.
160   // output_distance: Indicates the distance between the first element of two
161   //                  consecutive signals in a batch of the output data.
162   virtual std::unique_ptr<Plan> CreateBatchedPlanWithScratchAllocator(
163       Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
164       uint64 input_stride, uint64 input_distance, uint64 *output_embed,
165       uint64 output_stride, uint64 output_distance, Type type,
166       bool in_place_fft, int batch_count,
167       ScratchAllocator *scratch_allocator) = 0;
168 
169   // Updates the plan's work area with space allocated by a new scratch
170   // allocator. This facilitates plan reuse with scratch allocators.
171   //
172   // This requires that the plan was originally created using a scratch
173   // allocator, as otherwise scratch space will have been allocated internally
174   // by cuFFT.
175   virtual void UpdatePlanWithScratchAllocator(
176       Stream *stream, Plan *plan, ScratchAllocator *scratch_allocator) = 0;
177 
178   // Computes complex-to-complex FFT in the transform direction as specified
179   // by direction parameter.
180   virtual bool DoFft(Stream *stream, Plan *plan,
181                      const DeviceMemory<std::complex<float>> &input,
182                      DeviceMemory<std::complex<float>> *output) = 0;
183   virtual bool DoFft(Stream *stream, Plan *plan,
184                      const DeviceMemory<std::complex<double>> &input,
185                      DeviceMemory<std::complex<double>> *output) = 0;
186 
187   // Computes real-to-complex FFT in forward direction.
188   virtual bool DoFft(Stream *stream, Plan *plan,
189                      const DeviceMemory<float> &input,
190                      DeviceMemory<std::complex<float>> *output) = 0;
191   virtual bool DoFft(Stream *stream, Plan *plan,
192                      const DeviceMemory<double> &input,
193                      DeviceMemory<std::complex<double>> *output) = 0;
194 
195   // Computes complex-to-real FFT in inverse direction.
196   virtual bool DoFft(Stream *stream, Plan *plan,
197                      const DeviceMemory<std::complex<float>> &input,
198                      DeviceMemory<float> *output) = 0;
199   virtual bool DoFft(Stream *stream, Plan *plan,
200                      const DeviceMemory<std::complex<double>> &input,
201                      DeviceMemory<double> *output) = 0;
202 
203  protected:
FftSupport()204   FftSupport() {}
205 
206  private:
207   SE_DISALLOW_COPY_AND_ASSIGN(FftSupport);
208 };
209 
// Convenience macro that declares an override for every abstract virtual of
// the fft::FftSupport base class in one go. It assumes that it is expanded
// somewhere inside the ::stream_executor namespace. The expansion already
// ends with a semicolon.
#define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \
  std::unique_ptr<fft::Plan> Create1dPlan( \
      Stream *stream, uint64 num_x, fft::Type type, bool in_place_fft) \
      override; \
  std::unique_ptr<fft::Plan> Create2dPlan( \
      Stream *stream, uint64 num_x, uint64 num_y, fft::Type type, \
      bool in_place_fft) override; \
  std::unique_ptr<fft::Plan> Create3dPlan( \
      Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, \
      fft::Type type, bool in_place_fft) override; \
  std::unique_ptr<fft::Plan> Create1dPlanWithScratchAllocator( \
      Stream *stream, uint64 num_x, fft::Type type, bool in_place_fft, \
      ScratchAllocator *scratch_allocator) override; \
  std::unique_ptr<fft::Plan> Create2dPlanWithScratchAllocator( \
      Stream *stream, uint64 num_x, uint64 num_y, fft::Type type, \
      bool in_place_fft, ScratchAllocator *scratch_allocator) override; \
  std::unique_ptr<fft::Plan> Create3dPlanWithScratchAllocator( \
      Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, \
      fft::Type type, bool in_place_fft, \
      ScratchAllocator *scratch_allocator) override; \
  std::unique_ptr<fft::Plan> CreateBatchedPlan( \
      Stream *stream, int rank, uint64 *elem_count, \
      uint64 *input_embed, uint64 input_stride, uint64 input_distance, \
      uint64 *output_embed, uint64 output_stride, uint64 output_distance, \
      fft::Type type, bool in_place_fft, int batch_count) override; \
  std::unique_ptr<fft::Plan> CreateBatchedPlanWithScratchAllocator( \
      Stream *stream, int rank, uint64 *elem_count, \
      uint64 *input_embed, uint64 input_stride, uint64 input_distance, \
      uint64 *output_embed, uint64 output_stride, uint64 output_distance, \
      fft::Type type, bool in_place_fft, int batch_count, \
      ScratchAllocator *scratch_allocator) override; \
  void UpdatePlanWithScratchAllocator( \
      Stream *stream, fft::Plan *plan, \
      ScratchAllocator *scratch_allocator) override; \
  bool DoFft( \
      Stream *stream, fft::Plan *plan, \
      const DeviceMemory<std::complex<float>> &input, \
      DeviceMemory<std::complex<float>> *output) override; \
  bool DoFft( \
      Stream *stream, fft::Plan *plan, \
      const DeviceMemory<std::complex<double>> &input, \
      DeviceMemory<std::complex<double>> *output) override; \
  bool DoFft( \
      Stream *stream, fft::Plan *plan, \
      const DeviceMemory<float> &input, \
      DeviceMemory<std::complex<float>> *output) override; \
  bool DoFft( \
      Stream *stream, fft::Plan *plan, \
      const DeviceMemory<double> &input, \
      DeviceMemory<std::complex<double>> *output) override; \
  bool DoFft( \
      Stream *stream, fft::Plan *plan, \
      const DeviceMemory<std::complex<float>> &input, \
      DeviceMemory<float> *output) override; \
  bool DoFft( \
      Stream *stream, fft::Plan *plan, \
      const DeviceMemory<std::complex<double>> &input, \
      DeviceMemory<double> *output) override;
265 
266 }  // namespace fft
267 }  // namespace stream_executor
268 
269 #endif  // TENSORFLOW_STREAM_EXECUTOR_FFT_H_
270