1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
17
18 #include <string>
19
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_format.h"
22 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
23 #include "tensorflow/compiler/xla/types.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
27
28 namespace xla {
29 namespace gpu {
30
FftScratchAllocator(int device_ordinal,se::DeviceMemoryAllocator * memory_allocator)31 FftScratchAllocator::FftScratchAllocator(
32 int device_ordinal, se::DeviceMemoryAllocator* memory_allocator)
33 : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
34
GetMemoryLimitInBytes()35 int64 FftScratchAllocator::GetMemoryLimitInBytes() {
36 constexpr int64 kFftScratchSize = 1LL << 32; // 4GB by default.
37 return kFftScratchSize;
38 }
39
AllocateBytes(int64 byte_size)40 StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes(
41 int64 byte_size) {
42 CHECK_GE(byte_size, 0) << "byte_size must be positive.";
43 if (byte_size > GetMemoryLimitInBytes()) {
44 return se::port::Status(
45 se::port::error::RESOURCE_EXHAUSTED,
46 absl::StrFormat(
47 "Allocating %d bytes exceeds the memory limit of %d bytes.",
48 byte_size, GetMemoryLimitInBytes()));
49 }
50
51 TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer,
52 memory_allocator_->Allocate(device_ordinal_, byte_size,
53 /*retry_on_failure=*/false));
54 total_allocated_bytes_ += byte_size;
55
56 se::DeviceMemoryBase buffer_addr = *allocated_buffer;
57 allocated_buffers_.push_back(std::move(allocated_buffer));
58 return se::DeviceMemory<uint8>(buffer_addr);
59 }
60
61 namespace {
62
FftTypeToSeType(FftType type,bool double_precision)63 se::fft::Type FftTypeToSeType(FftType type, bool double_precision) {
64 switch (type) {
65 case FftType::FFT:
66 return double_precision ? se::fft::Type::kZ2ZForward
67 : se::fft::Type::kC2CForward;
68 case FftType::IFFT:
69 return double_precision ? se::fft::Type::kZ2ZInverse
70 : se::fft::Type::kC2CInverse;
71 case FftType::IRFFT:
72 return double_precision ? se::fft::Type::kZ2D : se::fft::Type::kC2R;
73 case FftType::RFFT:
74 return double_precision ? se::fft::Type::kD2Z : se::fft::Type::kR2C;
75 default:
76 LOG(FATAL) << "unsupported fft type";
77 }
78 }
79
FftTypeToString(se::fft::Type type)80 string FftTypeToString(se::fft::Type type) {
81 switch (type) {
82 case se::fft::Type::kC2CForward:
83 case se::fft::Type::kZ2ZForward:
84 return "FFT";
85 case se::fft::Type::kC2CInverse:
86 case se::fft::Type::kZ2ZInverse:
87 return "IFFT";
88 case se::fft::Type::kC2R:
89 case se::fft::Type::kZ2D:
90 return "IRFFT";
91 case se::fft::Type::kR2C:
92 case se::fft::Type::kD2Z:
93 return "RFFT";
94 default:
95 LOG(FATAL) << "unknown fft type";
96 }
97 }
98
99 } // namespace
100
FftThunk(ThunkInfo thunk_info,FftType fft_type,absl::Span<const int64> fft_length,const BufferAllocation::Slice & input_buffer,const BufferAllocation::Slice & output_buffer,const Shape & input_shape,const Shape & output_shape)101 FftThunk::FftThunk(ThunkInfo thunk_info, FftType fft_type,
102 absl::Span<const int64> fft_length,
103 const BufferAllocation::Slice& input_buffer,
104 const BufferAllocation::Slice& output_buffer,
105 const Shape& input_shape, const Shape& output_shape)
106 : Thunk(Kind::kFft, thunk_info),
107 fft_type_(
108 FftTypeToSeType(fft_type, input_shape.element_type() == F64 ||
109 input_shape.element_type() == C128)),
110 fft_length_(fft_length.begin(), fft_length.end()),
111 scale_factor_(1.0f),
112 input_buffer_(input_buffer),
113 output_buffer_(output_buffer),
114 input_shape_(input_shape),
115 output_shape_(output_shape) {}
116
ExecuteOnStream(const ExecuteParams & params)117 Status FftThunk::ExecuteOnStream(const ExecuteParams& params) {
118 auto& stream = *params.stream;
119 auto& buffer_allocations = *params.buffer_allocations;
120
121 VLOG(3) << "FFT type: " << FftTypeToString(fft_type_);
122 VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_);
123 VLOG(3) << "Output shape: "
124 << ShapeUtil::HumanStringWithLayout(output_shape_);
125
126 FftScratchAllocator scratch_allocator(buffer_allocations.device_ordinal(),
127 buffer_allocations.memory_allocator());
128
129 auto op_profiler =
130 params.profiler->MakeScopedInstructionProfiler(profile_index());
131 FftPlan* fft_plan_ptr;
132 {
133 absl::MutexLock lock(&mu_);
134 std::unique_ptr<FftPlan>& plan =
135 fft_plans_[buffer_allocations.device_ordinal()];
136 if (!plan) {
137 plan = std::make_unique<FftPlan>();
138 }
139 fft_plan_ptr = plan.get();
140 }
141 // CuFFT thread-safety requires that separate host threads not share plans;
142 // protect each plan with a mutex.
143 absl::MutexLock lock(&fft_plan_ptr->mu);
144 std::unique_ptr<se::fft::Plan>& fft_plan = fft_plan_ptr->plan;
145 if (fft_plan == nullptr) {
146 const int64 fft_rank = fft_length_.size();
147 CHECK_LE(fft_rank, 3);
148 int batch_size = 1;
149 for (int i = 0; i < input_shape_.dimensions_size() - fft_rank; ++i) {
150 batch_size *= input_shape_.dimensions(i);
151 }
152 uint64 fft_length[3];
153 uint64 input_embed[3];
154 const uint64 input_stride = 1;
155 uint64 input_distance = 1;
156 uint64 output_embed[3];
157 const uint64 output_stride = 1;
158 uint64 output_distance = 1;
159
160 for (int i = 0; i < fft_rank; ++i) {
161 auto dim_offset = input_shape_.dimensions_size() - fft_rank + i;
162 fft_length[i] = static_cast<uint64>(fft_length_[i]);
163 input_embed[i] = input_shape_.dimensions(dim_offset);
164 input_distance *= input_shape_.dimensions(dim_offset);
165 output_embed[i] = output_shape_.dimensions(dim_offset);
166 output_distance *= output_shape_.dimensions(dim_offset);
167 }
168
169 constexpr bool kInPlaceFft = false;
170 fft_plan = stream.parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
171 &stream, fft_rank, fft_length, input_embed, input_stride,
172 input_distance, output_embed, output_stride, output_distance, fft_type_,
173 kInPlaceFft, batch_size, &scratch_allocator);
174 scale_factor_ = 1.0f / output_distance;
175 } else {
176 stream.parent()->AsFft()->UpdatePlanWithScratchAllocator(
177 &stream, fft_plan.get(), &scratch_allocator);
178 }
179
180 bool launch_ok;
181 switch (fft_type_) {
182 case se::fft::Type::kC2CForward: {
183 se::DeviceMemory<complex64> input_data(
184 buffer_allocations.GetDeviceAddress(input_buffer_));
185 se::DeviceMemory<complex64> output_data(
186 buffer_allocations.GetDeviceAddress(output_buffer_));
187 launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
188 break;
189 }
190 case se::fft::Type::kZ2ZForward: {
191 se::DeviceMemory<complex128> input_data(
192 buffer_allocations.GetDeviceAddress(input_buffer_));
193 se::DeviceMemory<complex128> output_data(
194 buffer_allocations.GetDeviceAddress(output_buffer_));
195 launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
196 break;
197 }
198 case se::fft::Type::kC2CInverse: {
199 se::DeviceMemory<complex64> input_data(
200 buffer_allocations.GetDeviceAddress(input_buffer_));
201 se::DeviceMemory<complex64> output_data(
202 buffer_allocations.GetDeviceAddress(output_buffer_));
203 launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
204 if (launch_ok) {
205 launch_ok = stream
206 .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
207 complex64(scale_factor_), &output_data, 1)
208 .ok();
209 }
210 break;
211 }
212 case se::fft::Type::kZ2ZInverse: {
213 se::DeviceMemory<complex128> input_data(
214 buffer_allocations.GetDeviceAddress(input_buffer_));
215 se::DeviceMemory<complex128> output_data(
216 buffer_allocations.GetDeviceAddress(output_buffer_));
217 launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
218 if (launch_ok) {
219 launch_ok =
220 stream
221 .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
222 complex128(scale_factor_), &output_data, 1)
223 .ok();
224 }
225 break;
226 }
227 case se::fft::Type::kR2C: {
228 se::DeviceMemory<float> input_data(
229 buffer_allocations.GetDeviceAddress(input_buffer_));
230 se::DeviceMemory<complex64> output_data(
231 buffer_allocations.GetDeviceAddress(output_buffer_));
232 launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
233 break;
234 }
235 case se::fft::Type::kD2Z: {
236 se::DeviceMemory<double> input_data(
237 buffer_allocations.GetDeviceAddress(input_buffer_));
238 se::DeviceMemory<complex128> output_data(
239 buffer_allocations.GetDeviceAddress(output_buffer_));
240 launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
241 break;
242 }
243 case se::fft::Type::kC2R: {
244 se::DeviceMemory<complex64> input_data(
245 buffer_allocations.GetDeviceAddress(input_buffer_));
246 se::DeviceMemory<float> output_data(
247 buffer_allocations.GetDeviceAddress(output_buffer_));
248 launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
249 if (launch_ok) {
250 launch_ok = stream
251 .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
252 scale_factor_, &output_data, 1)
253 .ok();
254 }
255 break;
256 }
257 case se::fft::Type::kZ2D: {
258 se::DeviceMemory<complex128> input_data(
259 buffer_allocations.GetDeviceAddress(input_buffer_));
260 se::DeviceMemory<double> output_data(
261 buffer_allocations.GetDeviceAddress(output_buffer_));
262 launch_ok = stream.ThenFft(fft_plan.get(), input_data, &output_data).ok();
263 if (launch_ok) {
264 launch_ok = stream
265 .ThenBlasScal(ShapeUtil::ElementsIn(output_shape_),
266 scale_factor_, &output_data, 1)
267 .ok();
268 }
269 break;
270 }
271 default:
272 LOG(FATAL) << "unsupported fft type";
273 }
274 if (launch_ok) {
275 return Status::OK();
276 }
277 return InternalError("Unable to launch fft for thunk %p with type %s", this,
278 FftTypeToString(fft_type_));
279 }
280
281 } // namespace gpu
282 } // namespace xla
283