/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h"

#include <limits>

#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logger.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"

namespace xla {
namespace gpu {

using tensorflow::AutotuneResult;

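// Autotuning results are cached per stream executor: the key combines the
// executor, the lhs/rhs/output shapes, and the serialized GemmBackendConfig.
// A cached value of absl::nullopt means autotuning completed without
// selecting a specific algorithm, so the generic cuBLAS API is used instead.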
using GemmCacheKey =
    std::tuple<se::StreamExecutor*, Shape, Shape, Shape, std::string>;

static tensorflow::mutex autotune_cache_mu(tensorflow::LINKER_INITIALIZED);
static auto& autotune_cache TF_GUARDED_BY(autotune_cache_mu) =
    *new absl::flat_hash_map<GemmCacheKey,
                             absl::optional<se::blas::AlgorithmType>>();
static int64 cache_hits TF_GUARDED_BY(autotune_cache_mu) = 0;
static int64 cache_misses TF_GUARDED_BY(autotune_cache_mu) = 0;

// Experimentally tries to pick the best algorithm for the given gemm.
//
// This may fail under perfectly normal circumstances. In particular, it will
// fail if the program was built with < CUDA 8 or if we're using a gpu older
// than sm_50 -- in both cases, cublas doesn't support gemm-with-algorithm at
// all.
static StatusOr<absl::optional<se::blas::AlgorithmType>> DoUncachedGemmAutotune(
    const HloInstruction* gemm, se::Stream* stream,
    se::DeviceMemoryAllocator* allocator) {
  if (!stream->parent()->SynchronizeAllActivity()) {
    return InternalError("Failed to synchronize GPU for autotuning.");
  }

  const HloModuleConfig& hlo_module_config = gemm->GetModule()->config();
  const bool init_cublas_data =
      hlo_module_config.debug_options().xla_gpu_autotune_level() > 1;
  se::RedzoneAllocator input_output_allocator(
      stream, allocator, PtxOptsFromConfig(hlo_module_config),
      /*memory_limit=*/std::numeric_limits<int64>::max());

  BufferComparator comparator(gemm->shape(), hlo_module_config);

  int64 rng_state = 0;
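  // Allocates a redzone-guarded buffer large enough for `op`'s shape and, at
  // autotune level > 1, fills it with pseudo-random data so every candidate
  // algorithm sees identical inputs.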
  auto get_initialized_buffer =
      [&](const HloInstruction* op) -> StatusOr<se::DeviceMemoryBase> {
    TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
                        input_output_allocator.AllocateBytes(
                            ShapeUtil::ByteSizeOf(op->shape())));
    if (init_cublas_data) {
      InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer);
    }
    return buffer;
  };

  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_buffer,
                      get_initialized_buffer(gemm->operand(0)));
  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_buffer,
                      get_initialized_buffer(gemm->operand(1)));
  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase output_buffer,
                      get_initialized_buffer(gemm));
  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase reference_result_buffer,
                      get_initialized_buffer(gemm));

  const DebugOptions& debug_options =
      gemm->GetModule()->config().debug_options();

  const bool crash_on_checking_failure =
      debug_options.xla_gpu_crash_on_verification_failures();

  GemmBackendConfig backend_config =
      gemm->backend_config<GemmBackendConfig>().ValueOrDie();
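  // xla_gpu_autotune_level thresholds used here: > 1 initializes the input
  // buffers (see above), > 2 re-initializes the output buffer between runs,
  // and > 3 additionally enables redzone and correctness checking.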
  const int32 cublas_autotune_level =
      gemm->GetModule()->config().debug_options().xla_gpu_autotune_level();
  const bool reinit_cublas_data = cublas_autotune_level > 2;
  const bool check_cublas = cublas_autotune_level > 3;

  VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString();

  std::vector<se::blas::AlgorithmType> algorithms;
  CHECK(stream->parent()->GetBlasGemmAlgorithms(&algorithms));

  absl::optional<se::blas::AlgorithmType> first_algorithm;
  std::vector<AutotuneResult> profile_results;

  GpuGemmConfig config = GetGpuGemmConfig(gemm);

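  // Run every algorithm cuBLAS reports once, record its wall time in an
  // AutotuneResult, and (at the highest autotune level) verify that it
  // neither writes out of bounds nor diverges numerically from the first
  // successful algorithm, which serves as the reference.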
  for (se::blas::AlgorithmType algorithm : algorithms) {
    // Make sure the output buffer always has the same value if we use
    // the bias parameter.
    if (reinit_cublas_data && backend_config.beta() != 0) {
      int64 rng_state = 0;
      InitializeBuffer(stream, gemm->shape().element_type(), &rng_state,
                       output_buffer);
    }
    se::blas::ProfileResult profile_result;

    // We expect GemmWithAlgorithm to fail sometimes -- in fact, it will fail
    // for all algorithms if we're targeting < sm_50. But because we pass a
    // non-null ProfileResult, DoGemmWithAlgorithm should always return true,
    // and the actual success-ness is returned in ProfileResult::is_valid.
    CHECK(RunGemm(config, lhs_buffer, rhs_buffer, output_buffer, stream,
                  /*implements_whole_instruction=*/true,
                  /*profile_index=*/-1,
                  /*profiler=*/nullptr,
                  /*profile_result=*/&profile_result, algorithm)
              .ok());

    if (!profile_result.is_valid()) {
      // Unsupported algorithm.
      continue;
    }

    profile_results.emplace_back();
    AutotuneResult& result = profile_results.back();
    result.mutable_gemm()->set_algorithm(algorithm);

    VLOG(2) << "cublas gemm algorithm " << algorithm << " took "
            << profile_result.elapsed_time_in_ms() << "ms" << std::endl;

    *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
        absl::Milliseconds(profile_result.elapsed_time_in_ms()));

    if (!check_cublas) {
      continue;
    }

    TF_ASSIGN_OR_RETURN(
        se::RedzoneAllocator::RedzoneCheckStatus rz_check_status,
        input_output_allocator.CheckRedzones());
    if (!rz_check_status.ok()) {
      result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED);
      *result.mutable_failure()->mutable_msg() =
          rz_check_status.RedzoneFailureMsg();
      LOG(ERROR) << "Detected cuBLAS out-of-bounds write in gemm buffer";
      CHECK(!crash_on_checking_failure);
      continue;
    }

    if (!first_algorithm) {
      // First run: set the reference result buffer.
      CHECK(reference_result_buffer.size() == output_buffer.size());
      stream->ThenMemcpy(&reference_result_buffer, output_buffer,
                         output_buffer.size());
      first_algorithm.emplace(algorithm);
    } else {
      // Perform the comparison.
      TF_ASSIGN_OR_RETURN(bool compare_result,
                          comparator.CompareEqual(stream, output_buffer,
                                                  reference_result_buffer));
      if (!compare_result) {
        LOG(ERROR) << "Results mismatch between different GEMM algorithms. "
                   << "This is likely a bug/unexpected loss of precision "
                   << "in cuBLAS.";
        CHECK(!crash_on_checking_failure);

        result.mutable_failure()->set_kind(AutotuneResult::WRONG_RESULT);
        result.mutable_failure()->mutable_reference_gemm()->set_algorithm(
            *first_algorithm);
      }
    }
  }

  tensorflow::AutotuningLog log;
  for (const AutotuneResult& profile : profile_results) {
    *log.add_results() = profile;
  }
  if (!crash_on_checking_failure) {
    tensorflow::Logger::GetSingleton()->LogProto(log);
  }

  // Choose the fastest correct GEMM, but still allow WRONG_RESULT failures,
  // since the reference result is itself chosen arbitrarily.
  auto has_failure = [](const AutotuneResult& r) {
    return r.has_failure() &&
           r.failure().kind() != AutotuneResult::WRONG_RESULT;
  };

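  // Sort key: results carrying a disqualifying failure order last; among the
  // rest, the shortest measured run time wins.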
  auto result_comparison_key = [&has_failure](const AutotuneResult& r) {
    return std::make_tuple(
        has_failure(r),
        tensorflow::proto_utils::FromDurationProto(r.run_time()));
  };
  const auto& best_result = absl::c_min_element(
      profile_results,
      [&](const AutotuneResult& lhs, const AutotuneResult& rhs) {
        return result_comparison_key(lhs) < result_comparison_key(rhs);
      });

  if (best_result != profile_results.end() && !has_failure(*best_result)) {
    return {best_result->gemm().algorithm()};
  }

  VLOG(1) << "Unable to autotune cuBLAS gemm on stream " << stream
          << ": none of the " << algorithms.size()
          << " algorithms ran successfully";
  return {absl::nullopt};
}

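// Returns the best algorithm for `instr`, consulting (and updating) the
// per-process autotune cache. Autotuning is serialized per GPU via LockGpu,
// and batched gemms currently fall back to the generic algorithm.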
static StatusOr<absl::optional<se::blas::AlgorithmType>> DoGemmAutotune(
    const HloInstruction* instr, const GemmBackendConfig& gemm_config,
    se::DeviceMemoryAllocator* allocator, se::Stream* stream) {
  const HloInstruction* lhs = instr->operand(0);
  const HloInstruction* rhs = instr->operand(1);

  // Don't run autotuning concurrently on the same GPU.
  tensorflow::mutex_lock gpu_lock = LockGpu(stream->parent());

  GemmCacheKey key =
      std::make_tuple(stream->parent(), lhs->shape(), rhs->shape(),
                      instr->shape(), gemm_config.SerializeAsString());

  tensorflow::mutex_lock cache_lock(autotune_cache_mu);
  auto it = autotune_cache.find(key);
  int64 autotuning_requests = cache_hits + cache_misses;
  if (autotuning_requests && autotuning_requests % 10 == 0) {
    VLOG(2) << "Autotuning cache hits/(hits + misses): " << cache_hits << "/"
            << autotuning_requests;
  }

  if (it != autotune_cache.end()) {
    cache_hits++;
    VLOG(4) << "Autotuning cache hit, using algorithm: "
            << (it->second.has_value() ? absl::StrCat(*(it->second))
                                       : "<generic>");
    return it->second;
  }
  cache_misses++;
  VLOG(4) << "Autotuning cache miss";

  int64 batch_size = gemm_config.batch_size();
  absl::optional<se::blas::AlgorithmType> result;
  if (batch_size != 1) {
    // TODO(b/112111608): Implement auto tune for batched gemm.
    VLOG(2) << "Batch size is not 1, using generic algorithm";
    result = absl::nullopt;
  } else {
    TF_ASSIGN_OR_RETURN(result,
                        DoUncachedGemmAutotune(instr, stream, allocator));
  }

  CHECK(autotune_cache.emplace(key, result).second);
  return result;
}

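// Autotunes a single cublas gemm custom call and records the selected
// algorithm in its GemmBackendConfig. Returns true iff the backend config
// changed.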
static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
                                       se::StreamExecutor* executor,
                                       se::DeviceMemoryAllocator* allocator) {
  if (allocator == nullptr) {
    allocator = executor->GetAllocator();
  }
  TF_ASSIGN_OR_RETURN(se::Stream* const stream,
                      allocator->GetStream(executor->device_ordinal()));

  GemmBackendConfig gemm_config =
      instr->backend_config<GemmBackendConfig>().ValueOrDie();

  TF_ASSIGN_OR_RETURN(absl::optional<se::blas::AlgorithmType> gemm_algorithm,
                      DoGemmAutotune(instr, gemm_config, allocator, stream));

  // We update instruction->backend_config(); if no algorithms are supported,
  // a different API is used, which does not require specifying an algorithm.
  GemmBackendConfig updated_config = gemm_config;
  if (gemm_algorithm) {
    updated_config.set_selected_algorithm(*gemm_algorithm);
  }
  TF_RETURN_IF_ERROR(instr->set_backend_config(updated_config));
  return updated_config.SerializeAsString() != gemm_config.SerializeAsString();
}

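// Autotunes every cublas gemm custom call in `computation`.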
static StatusOr<bool> RunOnComputation(HloComputation* computation,
                                       se::StreamExecutor* se,
                                       se::DeviceMemoryAllocator* allocator) {
  bool changed = false;
  for (HloInstruction* instr : computation->instructions()) {
    if (IsCublasGemm(*instr)) {
      TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr, se, allocator));
      changed |= result;
    }
  }
  return changed;
}

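// Pass entry point: a no-op when xla_gpu_autotune_level is 0, otherwise
// autotunes gemms in all non-fusion computations of the module.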
StatusOr<bool> GemmAlgorithmPicker::Run(HloModule* module) {
  XLA_SCOPED_LOGGING_TIMER("GemmAlgorithmPicker");

  if (module->config().debug_options().xla_gpu_autotune_level() == 0) {
    VLOG(2) << "GEMM auto-tuning disabled, GemmAlgorithmPicker returning early";
    return false;
  }

  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    TF_ASSIGN_OR_RETURN(
        bool result, RunOnComputation(computation, stream_exec_, allocator_));
    changed |= result;
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla