1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_format.h"
21 #include "absl/time/time.h"
22 #include "absl/types/optional.h"
23 #include "tensorflow/compiler/xla/literal_util.h"
24 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
25 #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
26 #include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
27 #include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.h"
28 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
29 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
30 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
31 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/util.h"
34 #include "tensorflow/core/lib/strings/numbers.h"
35 #include "tensorflow/core/platform/logger.h"
36 #include "tensorflow/core/platform/mutex.h"
37 #include "tensorflow/core/util/env_var.h"
38 #include "tensorflow/core/util/proto/proto_utils.h"
39
40 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
41 #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
42 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
43 #endif
44
45 namespace xla {
46 namespace gpu {
47 namespace {
48
49 using absl::optional;
50 using se::DeviceMemoryBase;
51 using se::dnn::AlgorithmDesc;
52 using tensorflow::AutotuneResult;
53
54 class ScratchAllocator : public se::ScratchAllocator {
55 public:
ScratchAllocator(int device_ordinal,se::DeviceMemoryAllocator * memory_allocator)56 ScratchAllocator(int device_ordinal,
57 se::DeviceMemoryAllocator* memory_allocator)
58 : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
59
GetMemoryLimitInBytes()60 int64 GetMemoryLimitInBytes() override {
61 return 1LL << 32; // 4GB. TODO(jlebar): Tune this?
62 }
TotalAllocatedBytes()63 int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
64
65 StatusOr<se::DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override;
66
67 template <typename T>
Allocate(int64 num_elements)68 StatusOr<se::DeviceMemory<T>> Allocate(int64 num_elements) {
69 TF_ASSIGN_OR_RETURN(se::DeviceMemory<uint8> bytes,
70 AllocateBytes(num_elements * sizeof(T)));
71 return se::DeviceMemory<T>(bytes);
72 }
73
74 private:
75 const int device_ordinal_;
76 se::DeviceMemoryAllocator* memory_allocator_;
77 std::vector<se::OwningDeviceMemory> allocated_buffers_;
78 int64 total_allocated_bytes_ = 0;
79 };
80
AllocateBytes(int64 byte_size)81 StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
82 int64 byte_size) {
83 CHECK_GE(byte_size, 0) << "byte_size must be positive.";
84 if (byte_size > GetMemoryLimitInBytes()) {
85 return se::port::Status(
86 se::port::error::RESOURCE_EXHAUSTED,
87 absl::StrFormat(
88 "Allocating %d bytes exceeds the memory limit of %d bytes.",
89 byte_size, GetMemoryLimitInBytes()));
90 }
91
92 TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer,
93 memory_allocator_->Allocate(device_ordinal_, byte_size,
94 /*retry_on_failure=*/false));
95 total_allocated_bytes_ += byte_size;
96
97 se::DeviceMemoryBase buffer_addr = *allocated_buffer;
98 allocated_buffers_.push_back(std::move(allocated_buffer));
99 return se::DeviceMemory<uint8>(buffer_addr);
100 }
101
GetAlgorithms(CudnnConvKind kind,se::StreamExecutor * stream_exec)102 std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
103 se::StreamExecutor* stream_exec) {
104 std::vector<AlgorithmDesc> algorithms;
105 bool succ = false;
106 switch (kind) {
107 case CudnnConvKind::kBackwardFilter:
108 succ =
109 stream_exec->GetConvolveBackwardFilterAlgorithms(true, &algorithms);
110 break;
111 case CudnnConvKind::kBackwardInput:
112 succ = stream_exec->GetConvolveBackwardDataAlgorithms(true, &algorithms);
113 break;
114 case CudnnConvKind::kForward:
115 case CudnnConvKind::kForwardActivation:
116 succ = stream_exec->GetConvolveAlgorithms(true, &algorithms);
117 break;
118 }
119 DCHECK(succ);
120
121 return algorithms;
122 }
123
GetMIOpenAlgorithms(const HloCustomCallInstruction * instr,absl::Span<se::DeviceMemoryBase> operand_buffers,se::DeviceMemoryBase result_buffer,se::StreamExecutor * stream_exec,ScratchAllocator * scratch_allocator,se::Stream * stream)124 StatusOr<std::vector<se::dnn::ProfileResult>> GetMIOpenAlgorithms(
125 const HloCustomCallInstruction* instr,
126 absl::Span<se::DeviceMemoryBase> operand_buffers,
127 se::DeviceMemoryBase result_buffer, se::StreamExecutor* stream_exec,
128 ScratchAllocator* scratch_allocator, se::Stream* stream) {
129 std::vector<se::dnn::ProfileResult> algorithms;
130
131 TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));
132
133 TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
134 GetDNNConvKindFromCudnnConvKind(config.kind));
135
136 TF_ASSIGN_OR_RETURN(se::dnn::DataType dtype,
137 GetDNNDataTypeFromPrimitiveType(config.output_type));
138
139 TF_ASSIGN_OR_RETURN(GpuConvParams params,
140 GetGpuConvParams(config, operand_buffers, result_buffer));
141
142 bool succ = stream_exec->GetMIOpenConvolveAlgorithms(
143 kind, dtype, stream, params.config.input_descriptor, params.input_buf,
144 params.config.filter_descriptor, params.filter_buf,
145 params.config.output_descriptor, params.output_buf,
146 params.config.conv_desc, scratch_allocator, &algorithms);
147 DCHECK(succ);
148
149 return algorithms;
150 }
151
AlgorithmToString(const AlgorithmDesc & algo)152 string AlgorithmToString(const AlgorithmDesc& algo) {
153 if (algo.tensor_ops_enabled()) {
154 return absl::StrCat(algo.algo_id(), "+TC");
155 }
156 return absl::StrCat(algo.algo_id());
157 }
158
NumBytesToString(int64 bytes)159 string NumBytesToString(int64 bytes) {
160 return absl::StrCat(tensorflow::strings::HumanReadableNumBytes(bytes), " (",
161 bytes, "B)");
162 }
163
GetCudnnVersion(se::StreamExecutor * stream_executor)164 tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
165 tensorflow::CudnnVersion cudnn_version;
166 if (auto* dnn = stream_executor->AsDnn()) {
167 StatusOr<se::dnn::VersionInfo> version_or = dnn->GetVersion();
168 if (version_or.ok()) {
169 const auto& version = version_or.ValueOrDie();
170 cudnn_version.set_major(version.major_version());
171 cudnn_version.set_minor(version.minor_version());
172 cudnn_version.set_patch(version.patch());
173 }
174 }
175 return cudnn_version;
176 }
177
GetComputeCapability(se::StreamExecutor * stream_executor)178 tensorflow::ComputeCapability GetComputeCapability(
179 se::StreamExecutor* stream_executor) {
180 tensorflow::ComputeCapability cc;
181 int cc_major, cc_minor;
182 stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
183 &cc_minor);
184 cc.set_major(cc_major);
185 cc.set_minor(cc_minor);
186 return cc;
187 }
188
PrintPlatformInfo(const se::Stream * stream)189 void PrintPlatformInfo(const se::Stream* stream) {
190 auto* se = stream->parent();
191 const auto& desc = se->GetDeviceDescription();
192 LOG(ERROR) << "Device: " << desc.name();
193 LOG(ERROR) << "Platform: " << desc.platform_version();
194 LOG(ERROR) << "Driver: " << desc.driver_version();
195 LOG(ERROR) << "Runtime: " << desc.runtime_version();
196
197 auto* dnn = se->AsDnn();
198 if (dnn) {
199 auto dnn_version = dnn->GetVersion();
200 if (dnn_version.ok()) {
201 auto v = dnn_version.ValueOrDie();
202 LOG(ERROR) << "cudnn version: " << v.major_version() << "."
203 << v.minor_version() << "." << v.patch();
204 }
205 }
206 }
207
208 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
209 // Returns true if the redzones in `allocator`'s allocations are unmodified.
210 //
211 // If the redzones are modified, logs an error, sets the appropriate failure
212 // bits on `result`, and returns false.
213 //
214 // Returns a status if an unexpected error has occurred, and the stream
215 // has been poisoned.
216 //
217 // `name` is a user-friendly name for the set of redzones being checked, e.g.
218 // "input/output" or "scratch".
CheckRedzones(const se::RedzoneAllocator & allocator,se::Stream * stream,absl::string_view name,const HloInstruction * instr,AutotuneResult * result)219 StatusOr<bool> CheckRedzones(const se::RedzoneAllocator& allocator,
220 se::Stream* stream, absl::string_view name,
221 const HloInstruction* instr,
222 AutotuneResult* result) {
223 XLA_SCOPED_LOGGING_TIMER_LEVEL("CudnnConvAlgorithmPicker checking redzones",
224 2);
225 using RedzoneCheckStatus = se::RedzoneAllocator::RedzoneCheckStatus;
226 TF_ASSIGN_OR_RETURN(RedzoneCheckStatus redzone_check,
227 allocator.CheckRedzones());
228 if (redzone_check.ok()) {
229 return true;
230 }
231
232 auto* fail = result->mutable_failure();
233 fail->set_kind(AutotuneResult::REDZONE_MODIFIED);
234 *fail->mutable_msg() = redzone_check.RedzoneFailureMsg();
235 fail->set_buffer_address(
236 reinterpret_cast<uint64>(redzone_check.user_buffer_address));
237
238 LOG(ERROR) << absl::StreamFormat(
239 "Detected cudnn out-of-bounds write in conv %s buffer! This is likely a "
240 "cudnn bug. We will skip this algorithm in the future, but your GPU "
241 "state may already be corrupted, leading to incorrect results. Within "
242 "Google, no action is needed on your part. Outside of Google, please "
243 "ensure you're running the latest version of cudnn. If that doesn't fix "
244 "the problem, please file a bug with this full error message and we'll "
245 "contact nvidia.",
246 name);
247 LOG(ERROR) << redzone_check.RedzoneFailureMsg();
248 LOG(ERROR) << "HloInstruction " << instr->ToString();
249 PrintPlatformInfo(stream);
250 return false;
251 }
252 #endif
253
// Key of the autotuning cache: the executor the convolution runs on, paired
// with the string form of the convolution instruction (see
// AutotuneCacheKeyfromInstruction for how the string is produced).
using ConvCacheKey =
    std::tuple<se::StreamExecutor*,
               /* conv->ToString(HloPrintOptions::Canonical()) */ std::string>;
257
258 struct ConvCacheStats {
259 int64 cache_hits = 0;
260 int64 cache_misses = 0;
261
LogStatsxla::gpu::__anon6437d7ed0111::ConvCacheStats262 void LogStats() {
263 VLOG(2) << "Cache hits: " << cache_hits;
264 VLOG(2) << "Cache misses: " << cache_misses;
265 }
266 };
267
AutotuneCacheKeyfromInstruction(const HloCustomCallInstruction * conv,se::StreamExecutor * se)268 ConvCacheKey AutotuneCacheKeyfromInstruction(
269 const HloCustomCallInstruction* conv, se::StreamExecutor* se) {
270 auto options = HloPrintOptions::Canonical();
271 options.set_print_backend_config(true);
272 return std::make_tuple(se, conv->ToString(options));
273 }
274
// Guards the autotuning cache and its statistics declared below.
tensorflow::mutex autotune_cache_lock(tensorflow::LINKER_INITIALIZED);
// Maps a ConvCacheKey to the AutotuneResult picked for it. The map is created
// with `new` and bound to a reference; no matching delete appears in this
// file.
auto& autotune_cache TF_GUARDED_BY(autotune_cache_lock) =
    *new absl::flat_hash_map<ConvCacheKey, AutotuneResult>();
// Hit/miss counters for `autotune_cache`, allocated the same way.
auto& autotune_cache_stats TF_GUARDED_BY(autotune_cache_lock) =
    *new ConvCacheStats();
280 } // anonymous namespace
281
PickBestAlgorithm(const HloCustomCallInstruction * instr)282 StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
283 const HloCustomCallInstruction* instr) {
284 // Don't run this function concurrently on the same GPU.
285 //
286 // This is a bit of a hack and doesn't protect us against arbitrary concurrent
287 // use of a GPU, but it's sufficient to let us compile two HLO modules
288 // concurrently and then run them sequentially.
289 //
290 // Putting the lock in here rather than in PickBestAlgorithmNoCache lets us
291 // avoid ever doing duplicate work. If we have a cache miss, only one thread
292 // will run PickBestAlgorithmImpl for a particular device.
293 tensorflow::mutex_lock lock = LockGpu(stream_exec_);
294
295 // We cache the autotuning results to avoid doing the duplicate work,
296 // which can greatly improve both stability (deterministic numeric results
297 // within a process for a given input) and performance (2x speedup on some
298 // models).
299 ConvCacheKey key = AutotuneCacheKeyfromInstruction(instr, stream_exec_);
300 {
301 tensorflow::mutex_lock lock(autotune_cache_lock);
302 auto it = autotune_cache.find(key);
303 if (it != autotune_cache.end()) {
304 autotune_cache_stats.cache_hits++;
305 return it->second;
306 }
307 autotune_cache_stats.cache_misses++;
308 }
309
310 // Make sure any previous activity on this executor is done. We don't want to
311 // interfere with programs that are still running on the GPU.
312 if (!stream_exec_->SynchronizeAllActivity()) {
313 return InternalError("Failed to synchronize GPU for autotuning.");
314 }
315
316 // allocator either points to this->allocator_ or, if that's null, to a
317 // se::StreamExecutorMemoryAllocator for stream_exec_.
318 se::DeviceMemoryAllocator* allocator;
319 optional<se::StreamExecutorMemoryAllocator> se_allocator;
320 if (allocator_ != nullptr) {
321 allocator = allocator_;
322 } else {
323 se_allocator.emplace(stream_exec_);
324 allocator = &*se_allocator;
325 }
326
327 TF_ASSIGN_OR_RETURN(se::Stream* const stream,
328 allocator->GetStream(stream_exec_->device_ordinal()));
329 StatusOr<AutotuneResult> result_or(InternalError("Unknown platform."));
330 // Check StreamExecutor on which platform it is. ROCm and Cuda implementation
331 // have diverged. Specifically, we need to make sure redzone allocator related
332 // utilities are not used in ROCm routine
333 if (stream_exec_->platform_kind() == se::PlatformKind::kROCm) {
334 result_or = PickBestAlgorithmNoCacheRocm(instr, allocator, stream);
335 } else if (stream_exec_->platform_kind() == se::PlatformKind::kCuda) {
336 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
337 result_or = PickBestAlgorithmNoCacheCuda(instr, allocator, stream);
338 #endif
339 }
340
341 if (result_or.ok()) {
342 tensorflow::mutex_lock lock(autotune_cache_lock);
343 CHECK(autotune_cache.insert({key, result_or.ValueOrDie()}).second);
344 }
345 return result_or;
346 }
347
348 // The following function allows deterministic ops to be implemented relatively
349 // quickly using environment variables. It is intended to be temporary. The
350 // longer-term intention is to enable deterministic ops via tf.config and
351 // appropriate plumbing. See the discussion on PR 34951 for more information:
352 // https://github.com/tensorflow/tensorflow/pull/34951#discussion_r355682316
353 // This function and associated comment are replicated in the following three
354 // places:
355 // 1. tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
356 // 2. tensorflow/core/kernels/gpu_utils.cc
357 // 3. tensorflow/stream_executor/cuda/cuda_dnn.cc
358 // When implementing the plumbing, you should also search for the use of
359 // TF_DETERMINISTIC_OPS on its own.
360 // TODO(duncanriach): move to an API that uses tf.config and implement the first
361 // phase of plumbing.
RequireCudnnDeterminism()362 static bool RequireCudnnDeterminism() {
363 static bool require_cudnn_determinism = [] {
364 bool deterministic_ops = false;
365 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
366 /*default_val=*/false,
367 &deterministic_ops));
368 bool cudnn_deterministic = false;
369 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
370 /*default_val=*/false,
371 &cudnn_deterministic));
372 return deterministic_ops || cudnn_deterministic;
373 }();
374 return require_cudnn_determinism;
375 }
376
377 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
// Autotunes `instr` on the CUDA platform without consulting the cache.
//
// Outline:
//   1. Allocate (and, at autotune level > 1, initialize) operand and result
//      buffers from a redzone allocator.
//   2. Run every algorithm returned by GetAlgorithms() that is not on the
//      denylist, recording its run time and scratch usage.
//   3. At autotune level > 3, check redzones after each run and compare the
//      output against that of the first algorithm that got this far.
//   4. Log all results, then return the fastest result that has no failure
//      other than WRONG_RESULT (or the first such result when determinism is
//      required).
//
// Returns an InternalError if no algorithm survives step 4, and propagates
// allocation/config errors encountered along the way.
StatusOr<tensorflow::AutotuneResult>
GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
    const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator,
    se::Stream* stream) {
  // Right now Redzone allocator is available in Cuda target only
  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
      "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString()));

  // Element 0 of the custom call's tuple shape is the convolution result.
  const Shape& result_shape = instr->shape().tuple_shapes(0);
  int64 rng_state = 0;

  // The autotune level decides how much extra work is done per algorithm:
  // > 1 initializes the buffers, > 3 additionally verifies the results.
  const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
  const int32 conv_autotune_level =
      hlo_module_config.debug_options().xla_gpu_autotune_level();
  const bool init_conv_data = conv_autotune_level > 1;
  const bool check_conv = conv_autotune_level > 3;
  const auto initialize_buffer = [init_conv_data, &stream, &rng_state](
                                     DeviceMemoryBase buffer,
                                     const Shape& buffer_shape) {
    if (init_conv_data) {
      InitializeBuffer(stream, buffer_shape.element_type(), &rng_state, buffer);
    }
  };

  // Allocate space for the input, filter, and output of the convolution.
  se::RedzoneAllocator input_output_allocator(
      stream, allocator, PtxOptsFromConfig(hlo_module_config));
  std::vector<se::DeviceMemoryBase> operand_buffers;
  for (const auto* operand : instr->operands()) {
    TF_ASSIGN_OR_RETURN(auto buffer,
                        input_output_allocator.AllocateBytes(
                            ShapeUtil::ByteSizeOf(operand->shape())));
    initialize_buffer(buffer, operand->shape());
    operand_buffers.push_back(buffer);
  }
  TF_ASSIGN_OR_RETURN(auto result_buffer,
                      input_output_allocator.AllocateBytes(
                          ShapeUtil::ByteSizeOf(result_shape)));
  initialize_buffer(result_buffer, result_shape);

  // NOTE(review): `backend_config` is not read below; the statement still
  // returns early if the instruction's backend config cannot be obtained.
  TF_ASSIGN_OR_RETURN(auto backend_config,
                      instr->backend_config<CudnnConvBackendConfig>());

  optional<BufferComparator> comparator;
  // Use the first algorithm that's supported as reference. There isn't a
  // particular reason to use it, as any algorithm suffices. It doesn't make
  // this algorithm considered correct, though.
  se::DeviceMemoryBase reference_result_buffer;
  AlgorithmDesc first_algorithm;

  TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr));
  std::vector<AutotuneResult> profile_results;

  const DebugOptions& debug_options =
      instr->GetModule()->config().debug_options();

  const bool crash_on_checking_failure =
      debug_options.xla_gpu_crash_on_verification_failures();

  // The string half of the cache key doubles as the HLO text used for
  // denylist lookups and denylist suggestions.
  const auto canonical_hlo =
      std::get<1>(AutotuneCacheKeyfromInstruction(instr, stream_exec_));

  // Best effort: the blas version stays empty if it cannot be queried.
  string blas_version;
  if (auto* blas = stream_exec_->AsBlas()) {
    (void)blas->GetVersion(&blas_version);
  }

  absl::Span<const AlgorithmDesc> disabled_algos = GetDisabledConvAlgorithms(
      GetComputeCapability(stream_exec_), GetCudnnVersion(stream_exec_),
      blas_version, canonical_hlo);

  TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));

  for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) {
    XLA_SCOPED_LOGGING_TIMER_LEVEL(
        absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
                     AlgorithmToString(alg)),
        2);

    if (absl::c_linear_search(disabled_algos, alg)) {
      LOG(INFO) << "Omitted potentially buggy algorithm "
                << AlgorithmToString(alg) << " for conv " << instr->ToString();
      continue;
    }

    // A fresh scratch allocator per algorithm, so that scratch usage and
    // redzone checks below refer to this run only.
    se::RedzoneAllocator scratch_allocator(
        stream, allocator, PtxOptsFromConfig(hlo_module_config));
    se::dnn::ProfileResult profile_result;
    VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
            << instr->ToString();

    // Use assignment instead of brace-list to make GCC 4.9 happy.
    RunConvOptions options;
    options.profile_result = &profile_result;
    options.algo_override = alg;
    Status launch_status =
        RunGpuConv(config, absl::MakeSpan(operand_buffers), result_buffer,
                   &scratch_allocator, stream, options);

    // Algorithms that fail to launch or produce no valid profile are skipped
    // without being recorded.
    if (!launch_status.ok()) {
      continue;
    }

    if (!profile_result.is_valid()) {
      continue;
    }

    profile_results.emplace_back();
    AutotuneResult& result = profile_results.back();
    result.mutable_conv()->set_algorithm(alg.algo_id());
    result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled());

    int64 scratch_bytes_used =
        scratch_allocator.TotalAllocatedBytesExcludingRedzones();
    result.set_scratch_bytes(scratch_bytes_used);
    *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
        absl::Milliseconds(profile_result.elapsed_time_in_ms()));

    // Verification (redzones and result comparison) is optional.
    if (!check_conv) {
      continue;
    }

    // Check for writes to redzones.
    TF_ASSIGN_OR_RETURN(bool input_output_allocator_redzone_clear,
                        CheckRedzones(input_output_allocator, stream,
                                      "input/output", instr, &result));

    TF_ASSIGN_OR_RETURN(
        bool scratch_allocator_redzone_clear,
        CheckRedzones(scratch_allocator, stream, "scratch", instr, &result));

    if (!input_output_allocator_redzone_clear ||
        !scratch_allocator_redzone_clear) {
      // Print a ready-to-paste denylist entry for the offending algorithm.
      AlgorithmDenylist proto;
      auto entry = proto.add_entries();
      entry->set_hlo(canonical_hlo);
      *entry->mutable_cc() = GetComputeCapability(stream_exec_);
      *entry->mutable_cudnn_version() = GetCudnnVersion(stream_exec_);
      entry->set_blas_version(blas_version);
      auto algo = entry->add_algos();
      algo->set_id(alg.algo_id());
      algo->set_tensor_ops(alg.tensor_ops_enabled());

      LOG(ERROR) << "To denylist this algorithm for this convolution, "
                    "copy-paste the following "
                    "proto to the denylist file pointed by XLA_FLAGS "
                    "--xla_gpu_algorithm_denylist_path="
                 << GetDebugOptionsFromFlags().xla_gpu_algorithm_denylist_path()
                 << " : " << proto.ShortDebugString();
      continue;
    }

    if (comparator.has_value()) {
      // A reference output exists: compare this algorithm's output to it.
      XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2);
      StatusOr<bool> compare_result = comparator->CompareEqual(
          stream, reference_result_buffer, result_buffer);
      if (!compare_result.ok()) {
        LOG(ERROR) << "Unable to compare " << AlgorithmToString(first_algorithm)
                   << " against " << AlgorithmToString(alg) << " for "
                   << instr->ToString() << ": " << compare_result.status();
        if (compare_result.status().code() ==
            tensorflow::error::RESOURCE_EXHAUSTED) {
          // Possibly OOM. Propagate the error.
          return compare_result.status();
        }
        CHECK(!crash_on_checking_failure);
      } else if (!compare_result.ValueOrDie()) {
        LOG(ERROR)
            << "Results mismatch between different convolution algorithms. "
               "This is likely a bug/unexpected loss of precision in cudnn.\n"
            << instr->ToString() << " for "
            << AlgorithmToString(first_algorithm) << " vs "
            << AlgorithmToString(alg);
        PrintPlatformInfo(stream);
        VLOG(1) << "Full module on failure: \n"
                << instr->GetModule()->ToString();
        auto* fail = result.mutable_failure();
        fail->set_kind(AutotuneResult::WRONG_RESULT);
        fail->set_buffer_address(
            reinterpret_cast<uint64>(result_buffer.opaque()));
        auto* reference_conv = fail->mutable_reference_conv();
        reference_conv->set_algorithm(first_algorithm.algo_id());
        reference_conv->set_tensor_ops_enabled(
            first_algorithm.tensor_ops_enabled());
      }
    } else {
      // First algorithm to pass the redzone checks: keep a copy of its output
      // as the reference for all later comparisons.
      XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::Create", 2);
      comparator.emplace(result_shape, hlo_module_config);
      TF_ASSIGN_OR_RETURN(
          reference_result_buffer,
          input_output_allocator.AllocateBytes(result_buffer.size()));
      stream->ThenMemcpy(&reference_result_buffer, result_buffer,
                         result_buffer.size());
      first_algorithm = alg;
    }
  }

  // Log the autotuning result.
  {
    tensorflow::AutotuningLog log;
    {
      ConvInstructionLog instr_log;
      *instr_log.mutable_instruction() = instr->ToProto();
      for (int i = 0; i < instr->operand_count(); i++) {
        *instr_log.add_operand_shapes() = instr->operand(i)->shape().ToProto();
        instr_log.add_operand_addresses(
            reinterpret_cast<uint64>(operand_buffers[i].opaque()));
      }
      instr_log.set_result_address(
          reinterpret_cast<uint64>(result_buffer.opaque()));
      log.mutable_instr()->PackFrom(instr_log);
    }
    for (const auto& profile : profile_results) {
      *log.add_results() = profile;
    }
    *log.mutable_compute_capability() = GetComputeCapability(stream_exec_);
    *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec_);
    log.set_device_pci_bus_id(
        stream_exec_->GetDeviceDescription().pci_bus_id());
    log.set_blas_version(blas_version);
    VLOG(1) << "Autotuning result: " << log.ShortDebugString();
    // If we crash on checking failure, we are in a testing/benchmark mode, thus
    // omitting logging through the logger.
    if (!crash_on_checking_failure) {
      tensorflow::Logger::GetSingleton()->LogProto(log);
    }
  }

  // Crash on miscompares and redzone violations if desired.  Do this after
  // logging the autotuning results, otherwise we won't get any data!
  for (const auto& result : profile_results) {
    if (result.has_failure()) {
      CHECK(!crash_on_checking_failure);
    }
  }

  // Choose the fastest convolution that doesn't produce a REDZONE_MODIFIED
  // error.
  //
  // TODO(jlebar): We ought to be able to detect redzone reads by noticing NaNs
  // in the output of the conv and skip those.
  //
  // For now, we ignore WRONG_RESULT failures because false-positives are
  // possible (e.g. perhaps the reference algorithm is the one that's
  // incorrect!).  But we don't ignore REDZONE_MODIFIED failures because they're
  // quite severe and can be detected with high accuracy.
  std::vector<AutotuneResult> filtered_results;
  absl::c_copy_if(
      profile_results, std::back_inserter(filtered_results),
      [](const AutotuneResult& r) {
        return !(r.has_failure() &&
                 r.failure().kind() != AutotuneResult::WRONG_RESULT);
      });
  if (filtered_results.empty()) {
    return InternalError(
        "All algorithms tried for convolution %s failed.  Falling back to "
        "default algorithm. ",
        instr->ToString());
  }

  // When determinism is required, the first surviving result is used instead
  // of the fastest one.
  auto selected_result = filtered_results.begin();
  if (!RequireCudnnDeterminism() &&
      !hlo_module_config.debug_options().xla_gpu_deterministic_ops()) {
    selected_result = absl::c_min_element(
        filtered_results,
        [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
          return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) <
                 tensorflow::proto_utils::FromDurationProto(rhs.run_time());
        });
  }

  return *selected_result;
}
651 #endif
652
653 StatusOr<tensorflow::AutotuneResult>
PickBestAlgorithmNoCacheRocm(const HloCustomCallInstruction * instr,se::DeviceMemoryAllocator * allocator,se::Stream * stream)654 GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
655 const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator,
656 se::Stream* stream) {
657 XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
658 "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString()));
659
660 const auto device_ordinal = stream_exec_->device_ordinal();
661 std::vector<se::DeviceMemoryBase> operand_buffers;
662
663 ScratchAllocator input_output_allocator(device_ordinal, allocator);
664 const auto initialize_buffer = [stream](DeviceMemoryBase buffer) {
665 // Although we don't have evidence this matters, zero out the buffers
666 // before autotuning. It's conceivable that using uninitialized memory as
667 // the inputs might affect performance if e.g. the inputs contain
668 // denormals, and this is easy enough.
669 stream->ThenMemZero(&buffer, buffer.size());
670 };
671
672 // Allocate space for the input, filter, and output of the convolution. We
673 // use a ScratchAllocator for this instead of calling allocator_ directly so
674 // that our allocations don't leak.
675 for (const auto* operand : instr->operands()) {
676 TF_ASSIGN_OR_RETURN(auto buffer,
677 input_output_allocator.AllocateBytes(
678 ShapeUtil::ByteSizeOf(operand->shape())));
679 initialize_buffer(buffer);
680 operand_buffers.push_back(buffer);
681 }
682
683 TF_ASSIGN_OR_RETURN(
684 auto result_buffer,
685 input_output_allocator.AllocateBytes(
686 ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0))));
687 initialize_buffer(result_buffer);
688
689 ScratchAllocator scratch_allocator(device_ordinal, allocator);
690
691 TF_ASSIGN_OR_RETURN(
692 std::vector<se::dnn::ProfileResult> algorithms,
693 GetMIOpenAlgorithms(instr, absl::MakeSpan(operand_buffers), result_buffer,
694 stream_exec_, &scratch_allocator, stream));
695
696 std::vector<AutotuneResult> profile_results;
697
698 if (algorithms.size() == 1) {
699 auto profile_result = algorithms[0];
700 profile_results.emplace_back();
701 auto& result = profile_results.back();
702 result.mutable_conv()->set_algorithm(profile_result.algorithm().algo_id());
703 result.mutable_conv()->set_tensor_ops_enabled(
704 profile_result.algorithm().tensor_ops_enabled());
705
706 result.set_scratch_bytes(profile_result.scratch_size());
707 *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
708 absl::Milliseconds(profile_result.elapsed_time_in_ms()));
709 } else {
710 TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));
711 for (const auto& miopen_alg : algorithms) {
712 const auto& alg = miopen_alg.algorithm();
713 XLA_SCOPED_LOGGING_TIMER_LEVEL(
714 absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
715 AlgorithmToString(alg)),
716 2);
717
718 se::dnn::ProfileResult profile_result;
719 VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
720 << instr->ToString();
721
722 // Use assignment instead of brace-list to make GCC 4.9 happy.
723 RunConvOptions options;
724 options.profile_result = &profile_result;
725 options.algo_override = alg;
726 options.scratch_size_override = miopen_alg.scratch_size();
727 Status launch_status =
728 RunGpuConv(config, absl::MakeSpan(operand_buffers), result_buffer,
729 &scratch_allocator, stream, options);
730
731 if (!launch_status.ok()) {
732 continue;
733 }
734
735 if (!profile_result.is_valid()) {
736 continue;
737 }
738
739 profile_results.emplace_back();
740 AutotuneResult& result = profile_results.back();
741 result.mutable_conv()->set_algorithm(alg.algo_id());
742 result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled());
743
744 int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
745 result.set_scratch_bytes(scratch_bytes_used);
746 *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
747 absl::Milliseconds(profile_result.elapsed_time_in_ms()));
748 }
749 }
750 const auto& best_result = absl::c_min_element(
751 profile_results,
752 [&](const AutotuneResult& lhs, const AutotuneResult& rhs) {
753 return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) <
754 tensorflow::proto_utils::FromDurationProto(rhs.run_time());
755 });
756
757 if (best_result != profile_results.end()) {
758 return *best_result;
759 }
760
761 return InternalError(
762 "All algorithms tried for convolution %s failed. Falling back to "
763 "default algorithm.",
764 instr->ToString());
765 }
766
RunOnInstruction(HloInstruction * instr)767 StatusOr<bool> GpuConvAlgorithmPicker::RunOnInstruction(HloInstruction* instr) {
768 CHECK(IsCustomCallToDnnConvolution(*instr));
769
770 StatusOr<AutotuneResult> best_algo_or =
771 PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr));
772 if (!best_algo_or.ok()) {
773 LOG(WARNING) << "Failed to determine best cudnn convolution algorithm: "
774 << best_algo_or.status()
775 << "\n\nConvolution performance may be suboptimal.";
776 return false;
777 }
778
779 auto best_algo = std::move(best_algo_or).ValueOrDie();
780 VLOG(2) << "Setting cudnn conv to use algorithm "
781 << best_algo.conv().algorithm() << " and "
782 << NumBytesToString(best_algo.scratch_bytes())
783 << " of scratch memory: " << instr->ToString()
784 << " tensor_ops_enabled: " << best_algo.conv().tensor_ops_enabled();
785
786 // Replace instr with a new CustomCall which has the correct algorithm, and
787 // whose output shape has the appropriate amount of scratch memory.
788 HloComputation* computation = instr->parent();
789 Shape new_call_shape = ShapeUtil::MakeTupleShape(
790 {instr->shape().tuple_shapes(0),
791 ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes()})});
792
793 TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
794 instr->backend_config<CudnnConvBackendConfig>());
795 backend_config.set_algorithm(best_algo.conv().algorithm());
796 backend_config.set_tensor_ops_enabled(best_algo.conv().tensor_ops_enabled());
797
798 HloInstruction* new_call = computation->AddInstruction(
799 instr->CloneWithNewOperands(new_call_shape, instr->operands()));
800
801 VLOG(2) << "Replacing convolution " << instr->ToString() << " with "
802 << new_call->ToString();
803
804 TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config));
805
806 // Repackage new_call so it has the same shape as the original call, namely
807 // (conv_result, u8[0]).
808 HloInstruction* new_tuple =
809 computation->AddInstruction(HloInstruction::CreateTuple(
810 {computation->AddInstruction(HloInstruction::CreateGetTupleElement(
811 new_call_shape.tuple_shapes(0), new_call, 0)),
812 computation->AddInstruction(HloInstruction::CreateConstant(
813 LiteralUtil::CreateR1<uint8>({})))}));
814
815 TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple));
816 return true;
817 }
818
RunOnComputation(HloComputation * computation)819 StatusOr<bool> GpuConvAlgorithmPicker::RunOnComputation(
820 HloComputation* computation) {
821 std::vector<HloInstruction*> convs;
822 for (auto* instr : computation->instructions()) {
823 if (IsCustomCallToDnnConvolution(*instr)) {
824 convs.push_back(instr);
825 }
826 }
827
828 bool changed = false;
829 for (auto* instr : convs) {
830 TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr));
831 changed |= result;
832 }
833 return changed;
834 }
835
Run(HloModule * module)836 StatusOr<bool> GpuConvAlgorithmPicker::Run(HloModule* module) {
837 XLA_SCOPED_LOGGING_TIMER("GpuConvAlgorithmPicker");
838
839 if (module->config().debug_options().xla_gpu_autotune_level() == 0) {
840 VLOG(2) << "Convolution auto-tuning disabled, GpuConvAlgorithmPicker "
841 "returning early.";
842 return false;
843 }
844
845 bool changed = false;
846 for (HloComputation* computation : module->MakeNonfusionComputations()) {
847 TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
848 changed |= result;
849 }
850
851 {
852 tensorflow::mutex_lock lock(autotune_cache_lock);
853 autotune_cache_stats.LogStats();
854 }
855
856 return changed;
857 }
858
859 } // namespace gpu
860 } // namespace xla
861