• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
17 
18 #include <string>
19 
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_format.h"
22 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
23 #include "tensorflow/compiler/xla/types.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
27 
28 namespace xla {
29 namespace gpu {
30 
FftScratchAllocator(int device_ordinal,se::DeviceMemoryAllocator * memory_allocator)31 FftScratchAllocator::FftScratchAllocator(
32     int device_ordinal, se::DeviceMemoryAllocator* memory_allocator)
33     : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
34 
GetMemoryLimitInBytes()35 int64 FftScratchAllocator::GetMemoryLimitInBytes() {
36   constexpr int64 kFftScratchSize = 1LL << 32;  // 4GB by default.
37   return kFftScratchSize;
38 }
39 
AllocateBytes(int64 byte_size)40 StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
41     int64 byte_size) {
42   CHECK_GE(byte_size, 0) << "byte_size must be positive.";
43   if (byte_size > GetMemoryLimitInBytes()) {
44     return se::port::Status(
45         se::port::error::RESOURCE_EXHAUSTED,
46         absl::StrFormat(
47             "Allocating %d bytes exceeds the memory limit of %d bytes.",
48             byte_size, GetMemoryLimitInBytes()));
49   }
50 
51   TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer,
52                       memory_allocator_->Allocate(device_ordinal_, byte_size,
53                                                   /*retry_on_failure=*/false));
54   total_allocated_bytes_ += byte_size;
55 
56   se::DeviceMemoryBase buffer_addr = *allocated_buffer;
57   allocated_buffers_.push_back(std::move(allocated_buffer));
58   return se::DeviceMemory<uint8>(buffer_addr);
59 }
60 
61 namespace {
62 
FftTypeToSeType(FftType type,bool double_precision)63 se::fft::Type FftTypeToSeType(FftType type, bool double_precision) {
64   switch (type) {
65     case FftType::FFT:
66       return double_precision ? se::fft::Type::kZ2ZForward
67                               : se::fft::Type::kC2CForward;
68     case FftType::IFFT:
69       return double_precision ? se::fft::Type::kZ2ZInverse
70                               : se::fft::Type::kC2CInverse;
71     case FftType::IRFFT:
72       return double_precision ? se::fft::Type::kZ2D : se::fft::Type::kC2R;
73     case FftType::RFFT:
74       return double_precision ? se::fft::Type::kD2Z : se::fft::Type::kR2C;
75     default:
76       LOG(FATAL) << "unsupported fft type";
77   }
78 }
79 
FftTypeToString(se::fft::Type type)80 string FftTypeToString(se::fft::Type type) {
81   switch (type) {
82     case se::fft::Type::kC2CForward:
83     case se::fft::Type::kZ2ZForward:
84       return "FFT";
85     case se::fft::Type::kC2CInverse:
86     case se::fft::Type::kZ2ZInverse:
87       return "IFFT";
88     case se::fft::Type::kC2R:
89     case se::fft::Type::kZ2D:
90       return "IRFFT";
91     case se::fft::Type::kR2C:
92     case se::fft::Type::kD2Z:
93       return "RFFT";
94     default:
95       LOG(FATAL) << "unknown fft type";
96   }
97 }
98 
99 }  // namespace
100 
FftThunk(ThunkInfo thunk_info,FftType fft_type,absl::Span<const int64> fft_length,const BufferAllocation::Slice & input_buffer,const BufferAllocation::Slice & output_buffer,const Shape & input_shape,const Shape & output_shape)101 FftThunk::FftThunk(ThunkInfo thunk_info, FftType fft_type,
102                    absl::Span<const int64> fft_length,
103                    const BufferAllocation::Slice& input_buffer,
104                    const BufferAllocation::Slice& output_buffer,
105                    const Shape& input_shape, const Shape& output_shape)
106     : Thunk(Kind::kFft, thunk_info),
107       fft_type_(
108           FftTypeToSeType(fft_type, input_shape.element_type() == F64 ||
109                                         input_shape.element_type() == C128)),
110       fft_length_(fft_length.begin(), fft_length.end()),
111       scale_factor_(1.0f),
112       input_buffer_(input_buffer),
113       output_buffer_(output_buffer),
114       input_shape_(input_shape),
115       output_shape_(output_shape) {}
116 
ExecuteOnStream(const ExecuteParams & params)117 Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
118   auto& stream = *params.stream;
119   auto& buffer_allocations = *params.buffer_allocations;
120 
121   VLOG(3) << "FFT type: " << FftTypeToString(fft_type_);
122   VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_);
123   VLOG(3) << "Output shape: "
124           << ShapeUtil::HumanStringWithLayout(output_shape_);
125 
126   FftScratchAllocator scratch_allocator(buffer_allocations.device_ordinal(),
127                                         buffer_allocations.memory_allocator());
128 
129   auto op_profiler =
130       params.profiler->MakeScopedInstructionProfiler(profile_index());
131   FftPlan* fft_plan_ptr;
132   {
133     absl::MutexLock lock(&mu_);
134     std::unique_ptr<FftPlan>& plan =
135         fft_plans_[buffer_allocations.device_ordinal()];
136     if (!plan) {
137       plan = std::make_unique<FftPlan>();
138     }
139     fft_plan_ptr = plan.get();
140   }
141   // CuFFT thread-safety requires that separate host threads not share plans;
142   // protect each plan with a mutex.
143   absl::MutexLock lock(&fft_plan_ptr->mu);
144   std::unique_ptr<se::fft::Plan>& fft_plan = fft_plan_ptr->plan;
145   if (fft_plan == nullptr) {
146     const int64 fft_rank = fft_length_.size();
147     CHECK_LE(fft_rank, 3);
148     int batch_size = 1;
149     for (int i = 0; i < input_shape_.dimensions_size() - fft_rank; ++i) {
150       batch_size *= input_shape_.dimensions(i);
151     }
152     uint64 fft_length[3];
153     uint64 input_embed[3];
154     const uint64 input_stride = 1;
155     uint64 input_distance = 1;
156     uint64 output_embed[3];
157     const uint64 output_stride = 1;
158     uint64 output_distance = 1;
159 
160     for (int i = 0; i < fft_rank; ++i) {
161       auto dim_offset = input_shape_.dimensions_size() - fft_rank + i;
162       fft_length[i] = static_cast<uint64>(fft_length_[i]);
163       input_embed[i] = input_shape_.dimensions(dim_offset);
164       input_distance *= input_shape_.dimensions(dim_offset);
165       output_embed[i] = output_shape_.dimensions(dim_offset);
166       output_distance *= output_shape_.dimensions(dim_offset);
167     }
168 
169     constexpr bool kInPlaceFft = false;
170     fft_plan = stream.parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
171         &stream, fft_rank, fft_length, input_embed, input_stride,
172         input_distance, output_embed, output_stride, output_distance, fft_type_,
173         kInPlaceFft, batch_size, &scratch_allocator);
174     scale_factor_ = 1.0f / output_distance;
175   } else {
176     stream.parent()->AsFft()->UpdatePlanWithScratchAllocator(
177         &stream, fft_plan.get(), &scratch_allocator);
178   }
179 
180   bool launch_ok;
181   switch (fft_type_) {
182     case se::fft::Type::kC2CForward: {
183       se::DeviceMemory<complex64> input_data(
184           buffer_allocations.GetDeviceAddress(input_buffer_));
185       se::DeviceMemory<complex64> output_data(
186           buffer_allocations.GetDeviceAddress(output_buffer_));
187       launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
188       break;
189     }
190     case se::fft::Type::kZ2ZForward: {
191       se::DeviceMemory<complex128> input_data(
192           buffer_allocations.GetDeviceAddress(input_buffer_));
193       se::DeviceMemory<complex128> output_data(
194           buffer_allocations.GetDeviceAddress(output_buffer_));
195       launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
196       break;
197     }
198     case se::fft::Type::kC2CInverse: {
199       se::DeviceMemory<complex64> input_data(
200           buffer_allocations.GetDeviceAddress(input_buffer_));
201       se::DeviceMemory<complex64> output_data(
202           buffer_allocations.GetDeviceAddress(output_buffer_));
203       launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
204       if (launch_ok) {
205         launch_ok = stream
206                         .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
207                                       complex64(scale_factor_), &output_data, 1)
208                         .ok();
209       }
210       break;
211     }
212     case se::fft::Type::kZ2ZInverse: {
213       se::DeviceMemory<complex128> input_data(
214           buffer_allocations.GetDeviceAddress(input_buffer_));
215       se::DeviceMemory<complex128> output_data(
216           buffer_allocations.GetDeviceAddress(output_buffer_));
217       launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
218       if (launch_ok) {
219         launch_ok =
220             stream
221                 .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
222                               complex128(scale_factor_), &output_data, 1)
223                 .ok();
224       }
225       break;
226     }
227     case se::fft::Type::kR2C: {
228       se::DeviceMemory<float> input_data(
229           buffer_allocations.GetDeviceAddress(input_buffer_));
230       se::DeviceMemory<complex64> output_data(
231           buffer_allocations.GetDeviceAddress(output_buffer_));
232       launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
233       break;
234     }
235     case se::fft::Type::kD2Z: {
236       se::DeviceMemory<double> input_data(
237           buffer_allocations.GetDeviceAddress(input_buffer_));
238       se::DeviceMemory<complex128> output_data(
239           buffer_allocations.GetDeviceAddress(output_buffer_));
240       launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
241       break;
242     }
243     case se::fft::Type::kC2R: {
244       se::DeviceMemory<complex64> input_data(
245           buffer_allocations.GetDeviceAddress(input_buffer_));
246       se::DeviceMemory<float> output_data(
247           buffer_allocations.GetDeviceAddress(output_buffer_));
248       launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
249       if (launch_ok) {
250         launch_ok = stream
251                         .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
252                                       scale_factor_, &output_data, 1)
253                         .ok();
254       }
255       break;
256     }
257     case se::fft::Type::kZ2D: {
258       se::DeviceMemory<complex128> input_data(
259           buffer_allocations.GetDeviceAddress(input_buffer_));
260       se::DeviceMemory<double> output_data(
261           buffer_allocations.GetDeviceAddress(output_buffer_));
262       launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
263       if (launch_ok) {
264         launch_ok = stream
265                         .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
266                                       scale_factor_, &output_data, 1)
267                         .ok();
268       }
269       break;
270     }
271     default:
272       LOG(FATAL) << "unsupported fft type";
273   }
274   if (launch_ok) {
275     return Status::OK();
276   }
277   return InternalError("Unable to launch fft for thunk %p with type %s", this,
278                        FftTypeToString(fft_type_));
279 }
280 
281 }  // namespace gpu
282 }  // namespace xla
283