/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"

#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logger.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/proto/proto_utils.h"

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"
#endif

namespace xla {
namespace gpu {
namespace {

using absl::optional;
using se::DeviceMemoryBase;
using se::dnn::AlgorithmDesc;
using tensorflow::AutotuneResult;

class ScratchAllocator : public se::ScratchAllocator {
 public:
  ScratchAllocator(int device_ordinal,
                   se::DeviceMemoryAllocator* memory_allocator)
      : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}

  int64 GetMemoryLimitInBytes() override {
    return 1LL << 32;  // 4GB.  TODO(jlebar): Tune this?
  }
  int64 TotalAllocatedBytes() { return total_allocated_bytes_; }

  StatusOr<se::DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override;

  template <typename T>
  StatusOr<se::DeviceMemory<T>> Allocate(int64 num_elements) {
    TF_ASSIGN_OR_RETURN(se::DeviceMemory<uint8> bytes,
                        AllocateBytes(num_elements * sizeof(T)));
    return se::DeviceMemory<T>(bytes);
  }

 private:
  const int device_ordinal_;
  se::DeviceMemoryAllocator* memory_allocator_;
  std::vector<se::OwningDeviceMemory> allocated_buffers_;
  int64 total_allocated_bytes_ = 0;
};

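// Allocates `byte_size` bytes of device memory and retains ownership of the
// buffer, so it is released when this ScratchAllocator is destroyed.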
StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
    int64 byte_size) {
  CHECK_GE(byte_size, 0) << "byte_size must be non-negative.";
  if (byte_size > GetMemoryLimitInBytes()) {
    return se::port::Status(
        se::port::error::RESOURCE_EXHAUSTED,
        absl::StrFormat(
            "Allocating %d bytes exceeds the memory limit of %d bytes.",
            byte_size, GetMemoryLimitInBytes()));
  }

  TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer,
                      memory_allocator_->Allocate(device_ordinal_, byte_size,
                                                  /*retry_on_failure=*/false));
  total_allocated_bytes_ += byte_size;

  se::DeviceMemoryBase buffer_addr = *allocated_buffer;
  allocated_buffers_.push_back(std::move(allocated_buffer));
  return se::DeviceMemory<uint8>(buffer_addr);
}

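// Returns the convolution algorithms advertised by the StreamExecutor's DNN
// backend for the given convolution kind.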
std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
                                         se::StreamExecutor* stream_exec) {
  std::vector<AlgorithmDesc> algorithms;
  bool succ = false;
  switch (kind) {
    case CudnnConvKind::kBackwardFilter:
      succ =
          stream_exec->GetConvolveBackwardFilterAlgorithms(true, &algorithms);
      break;
    case CudnnConvKind::kBackwardInput:
      succ = stream_exec->GetConvolveBackwardDataAlgorithms(true, &algorithms);
      break;
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      succ = stream_exec->GetConvolveAlgorithms(true, &algorithms);
      break;
  }
  DCHECK(succ);

  return algorithms;
}

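// Queries the MIOpen backend for profiled candidate algorithms for `instr`,
// using the given operand/result buffers and scratch allocator.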
StatusOr<std::vector<se::dnn::ProfileResult>> GetMIOpenAlgorithms(
    const HloCustomCallInstruction* instr,
    absl::Span<se::DeviceMemoryBase> operand_buffers,
    se::DeviceMemoryBase result_buffer, se::StreamExecutor* stream_exec,
    ScratchAllocator* scratch_allocator, se::Stream* stream) {
  std::vector<se::dnn::ProfileResult> algorithms;

  TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));

  TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
                      GetDNNConvKindFromCudnnConvKind(config.kind));

  TF_ASSIGN_OR_RETURN(se::dnn::DataType dtype,
                      GetDNNDataTypeFromPrimitiveType(config.output_type));

  TF_ASSIGN_OR_RETURN(GpuConvParams params,
                      GetGpuConvParams(config, operand_buffers, result_buffer));

  bool succ = stream_exec->GetMIOpenConvolveAlgorithms(
      kind, dtype, stream, params.config.input_descriptor, params.input_buf,
      params.config.filter_descriptor, params.filter_buf,
      params.config.output_descriptor, params.output_buf,
      params.config.conv_desc, scratch_allocator, &algorithms);
  DCHECK(succ);

  return algorithms;
}

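// Renders an algorithm as, e.g., "7" or "7+TC" (tensor ops enabled) for logs
// and error messages.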
string AlgorithmToString(const AlgorithmDesc& algo) {
  if (algo.tensor_ops_enabled()) {
    return absl::StrCat(algo.algo_id(), "+TC");
  }
  return absl::StrCat(algo.algo_id());
}

string NumBytesToString(int64 bytes) {
  return absl::StrCat(tensorflow::strings::HumanReadableNumBytes(bytes), " (",
                      bytes, "B)");
}

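// Reads the cuDNN version from the StreamExecutor's DNN backend.  Returns a
// default-initialized proto if the backend or its version is unavailable.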
tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
  tensorflow::CudnnVersion cudnn_version;
  if (auto* dnn = stream_executor->AsDnn()) {
    StatusOr<se::dnn::VersionInfo> version_or = dnn->GetVersion();
    if (version_or.ok()) {
      const auto& version = version_or.ValueOrDie();
      cudnn_version.set_major(version.major_version());
      cudnn_version.set_minor(version.minor_version());
      cudnn_version.set_patch(version.patch());
    }
  }
  return cudnn_version;
}

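// Reads the CUDA compute capability (major.minor) of the device backing
// `stream_executor`.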
tensorflow::ComputeCapability GetComputeCapability(
    se::StreamExecutor* stream_executor) {
  tensorflow::ComputeCapability cc;
  int cc_major, cc_minor;
  stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
                                                                  &cc_minor);
  cc.set_major(cc_major);
  cc.set_minor(cc_minor);
  return cc;
}

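// Logs device, platform, driver, runtime, and cuDNN version information to
// help diagnose autotuning failures.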
void PrintPlatformInfo(const se::Stream* stream) {
  auto* se = stream->parent();
  const auto& desc = se->GetDeviceDescription();
  LOG(ERROR) << "Device: " << desc.name();
  LOG(ERROR) << "Platform: " << desc.platform_version();
  LOG(ERROR) << "Driver: " << desc.driver_version();
  LOG(ERROR) << "Runtime: " << desc.runtime_version();

  auto* dnn = se->AsDnn();
  if (dnn) {
    auto dnn_version = dnn->GetVersion();
    if (dnn_version.ok()) {
      auto v = dnn_version.ValueOrDie();
      LOG(ERROR) << "cudnn version: " << v.major_version() << "."
                 << v.minor_version() << "." << v.patch();
    }
  }
}

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
// Returns true if the redzones in `allocator`'s allocations are unmodified.
//
// If the redzones are modified, logs an error, sets the appropriate failure
// bits on `result`, and returns false.
//
// Returns an error status if an unexpected error occurred while checking; in
// that case the stream has been poisoned.
//
// `name` is a user-friendly name for the set of redzones being checked, e.g.
// "input/output" or "scratch".
StatusOr<bool> CheckRedzones(const se::RedzoneAllocator& allocator,
                             se::Stream* stream, absl::string_view name,
                             const HloInstruction* instr,
                             AutotuneResult* result) {
  XLA_SCOPED_LOGGING_TIMER_LEVEL("CudnnConvAlgorithmPicker checking redzones",
                                 2);
  using RedzoneCheckStatus = se::RedzoneAllocator::RedzoneCheckStatus;
  TF_ASSIGN_OR_RETURN(RedzoneCheckStatus redzone_check,
                      allocator.CheckRedzones());
  if (redzone_check.ok()) {
    return true;
  }

  auto* fail = result->mutable_failure();
  fail->set_kind(AutotuneResult::REDZONE_MODIFIED);
  *fail->mutable_msg() = redzone_check.RedzoneFailureMsg();
  fail->set_buffer_address(
      reinterpret_cast<uint64>(redzone_check.user_buffer_address));

  LOG(ERROR) << absl::StreamFormat(
      "Detected cudnn out-of-bounds write in conv %s buffer! This is likely a "
      "cudnn bug. We will skip this algorithm in the future, but your GPU "
      "state may already be corrupted, leading to incorrect results. Within "
      "Google, no action is needed on your part. Outside of Google, please "
      "ensure you're running the latest version of cudnn. If that doesn't fix "
      "the problem, please file a bug with this full error message and we'll "
      "contact nvidia.",
      name);
  LOG(ERROR) << redzone_check.RedzoneFailureMsg();
  LOG(ERROR) << "HloInstruction " << instr->ToString();
  PrintPlatformInfo(stream);
  return false;
}
#endif

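// Autotuning results are cached per (StreamExecutor, canonicalized conv HLO)
// so that identical convolutions are only tuned once per process.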
using ConvCacheKey =
    std::tuple<se::StreamExecutor*,
               /* conv->ToString(HloPrintOptions::Canonical()) */ std::string>;

struct ConvCacheStats {
  int64 cache_hits = 0;
  int64 cache_misses = 0;

  void LogStats() {
    VLOG(2) << "Cache hits: " << cache_hits;
    VLOG(2) << "Cache misses: " << cache_misses;
  }
};

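// Builds the cache key for `conv`: the executor plus the instruction's
// canonical HLO string, printed with its backend config included.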
ConvCacheKey AutotuneCacheKeyfromInstruction(
    const HloCustomCallInstruction* conv, se::StreamExecutor* se) {
  auto options = HloPrintOptions::Canonical();
  options.set_print_backend_config(true);
  return std::make_tuple(se, conv->ToString(options));
}

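// Process-wide autotune cache and statistics, guarded by autotune_cache_lock.
// The containers are heap-allocated and never freed, so they remain valid for
// the lifetime of the process.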
tensorflow::mutex autotune_cache_lock(tensorflow::LINKER_INITIALIZED);
auto& autotune_cache TF_GUARDED_BY(autotune_cache_lock) =
    *new absl::flat_hash_map<ConvCacheKey, AutotuneResult>();
auto& autotune_cache_stats TF_GUARDED_BY(autotune_cache_lock) =
    *new ConvCacheStats();
}  // anonymous namespace

StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
    const HloCustomCallInstruction* instr) {
  // Don't run this function concurrently on the same GPU.
  //
  // This is a bit of a hack and doesn't protect us against arbitrary concurrent
  // use of a GPU, but it's sufficient to let us compile two HLO modules
  // concurrently and then run them sequentially.
  //
  // Putting the lock here rather than in the per-platform
  // PickBestAlgorithmNoCache* helpers lets us avoid ever doing duplicate work.
  // On a cache miss, only one thread will run PickBestAlgorithmNoCache* for a
  // particular device.
  tensorflow::mutex_lock lock = LockGpu(stream_exec_);

  // We cache the autotuning results to avoid doing duplicate work, which can
  // greatly improve both stability (deterministic numeric results within a
  // process for a given input) and performance (2x speedup on some models).
  ConvCacheKey key = AutotuneCacheKeyfromInstruction(instr, stream_exec_);
  {
    tensorflow::mutex_lock lock(autotune_cache_lock);
    auto it = autotune_cache.find(key);
    if (it != autotune_cache.end()) {
      autotune_cache_stats.cache_hits++;
      return it->second;
    }
    autotune_cache_stats.cache_misses++;
  }

  // Make sure any previous activity on this executor is done. We don't want to
  // interfere with programs that are still running on the GPU.
  if (!stream_exec_->SynchronizeAllActivity()) {
    return InternalError("Failed to synchronize GPU for autotuning.");
  }

  // allocator either points to this->allocator_ or, if that's null, to a
  // se::StreamExecutorMemoryAllocator for stream_exec_.
  se::DeviceMemoryAllocator* allocator;
  optional<se::StreamExecutorMemoryAllocator> se_allocator;
  if (allocator_ != nullptr) {
    allocator = allocator_;
  } else {
    se_allocator.emplace(stream_exec_);
    allocator = &*se_allocator;
  }

  TF_ASSIGN_OR_RETURN(se::Stream* const stream,
                      allocator->GetStream(stream_exec_->device_ordinal()));
  StatusOr<AutotuneResult> result_or(InternalError("Unknown platform."));
  // Check which platform the StreamExecutor is on.  The ROCm and CUDA
  // implementations have diverged; in particular, the redzone-allocator
  // utilities must not be used in the ROCm path.
  if (stream_exec_->platform_kind() == se::PlatformKind::kROCm) {
    result_or = PickBestAlgorithmNoCacheRocm(instr, allocator, stream);
  } else if (stream_exec_->platform_kind() == se::PlatformKind::kCuda) {
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
    result_or = PickBestAlgorithmNoCacheCuda(instr, allocator, stream);
#endif
  }

  if (result_or.ok()) {
    tensorflow::mutex_lock lock(autotune_cache_lock);
    CHECK(autotune_cache.insert({key, result_or.ValueOrDie()}).second);
  }
  return result_or;
}

// The following function allows deterministic ops to be implemented relatively
// quickly using environment variables. It is intended to be temporary. The
// longer-term intention is to enable deterministic ops via tf.config and
// appropriate plumbing. See the discussion on PR 34951 for more information:
// https://github.com/tensorflow/tensorflow/pull/34951#discussion_r355682316
// This function and associated comment are replicated in the following three
// places:
//   1. tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
//   2. tensorflow/core/kernels/gpu_utils.cc
//   3. tensorflow/stream_executor/cuda/cuda_dnn.cc
// When implementing the plumbing, you should also search for the use of
// TF_DETERMINISTIC_OPS on its own.
// TODO(duncanriach): move to an API that uses tf.config and implement the first
//                    phase of plumbing.
static bool RequireCudnnDeterminism() {
  static bool require_cudnn_determinism = [] {
    bool deterministic_ops = false;
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
                                               /*default_val=*/false,
                                               &deterministic_ops));
    bool cudnn_deterministic = false;
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
                                               /*default_val=*/false,
                                               &cudnn_deterministic));
    return deterministic_ops || cudnn_deterministic;
  }();
  return require_cudnn_determinism;
}

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
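// Autotunes `instr` on the CUDA platform: profiles each candidate cuDNN
// algorithm on redzone-guarded buffers and returns the fastest result whose
// redzone checks pass (or the first eligible result when deterministic
// behavior is requested).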
StatusOr<tensorflow::AutotuneResult>
GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
    const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator,
    se::Stream* stream) {
  // Right now the redzone allocator is available only for the CUDA target.
  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
      "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString()));

  const Shape& result_shape = instr->shape().tuple_shapes(0);
  int64 rng_state = 0;

  const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
  const int32 conv_autotune_level =
      hlo_module_config.debug_options().xla_gpu_autotune_level();
  const bool init_conv_data = conv_autotune_level > 1;
  const bool check_conv = conv_autotune_level > 3;
  const auto initialize_buffer = [init_conv_data, &stream, &rng_state](
                                     DeviceMemoryBase buffer,
                                     const Shape& buffer_shape) {
    if (init_conv_data) {
      InitializeBuffer(stream, buffer_shape.element_type(), &rng_state, buffer);
    }
  };

  // Allocate space for the input, filter, and output of the convolution.
  se::RedzoneAllocator input_output_allocator(
      stream, allocator, PtxOptsFromConfig(hlo_module_config));
  std::vector<se::DeviceMemoryBase> operand_buffers;
  for (const auto* operand : instr->operands()) {
    TF_ASSIGN_OR_RETURN(auto buffer,
                        input_output_allocator.AllocateBytes(
                            ShapeUtil::ByteSizeOf(operand->shape())));
    initialize_buffer(buffer, operand->shape());
    operand_buffers.push_back(buffer);
  }
  TF_ASSIGN_OR_RETURN(auto result_buffer,
                      input_output_allocator.AllocateBytes(
                          ShapeUtil::ByteSizeOf(result_shape)));
  initialize_buffer(result_buffer, result_shape);

  TF_ASSIGN_OR_RETURN(auto backend_config,
                      instr->backend_config<CudnnConvBackendConfig>());

  optional<BufferComparator> comparator;
  // Use the output of the first successful algorithm as the reference against
  // which later algorithms are compared.  Any algorithm would do; being the
  // reference does not mean the algorithm is assumed to be correct.
  se::DeviceMemoryBase reference_result_buffer;
  AlgorithmDesc first_algorithm;

  TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr));
  std::vector<AutotuneResult> profile_results;

  const DebugOptions& debug_options =
      instr->GetModule()->config().debug_options();

  const bool crash_on_checking_failure =
      debug_options.xla_gpu_crash_on_verification_failures();

  const auto canonical_hlo =
      std::get<1>(AutotuneCacheKeyfromInstruction(instr, stream_exec_));

  string blas_version;
  if (auto* blas = stream_exec_->AsBlas()) {
    (void)blas->GetVersion(&blas_version);
  }

  absl::Span<const AlgorithmDesc> disabled_algos = GetDisabledConvAlgorithms(
      GetComputeCapability(stream_exec_), GetCudnnVersion(stream_exec_),
      blas_version, canonical_hlo);

  TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));

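  // Profile each candidate algorithm: run the convolution, record its scratch
  // usage and runtime, and (at higher autotune levels) check redzones and
  // compare its output against the reference algorithm's output.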
  for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) {
    XLA_SCOPED_LOGGING_TIMER_LEVEL(
        absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
                     AlgorithmToString(alg)),
        2);

    if (absl::c_linear_search(disabled_algos, alg)) {
      LOG(INFO) << "Omitted potentially buggy algorithm "
                << AlgorithmToString(alg) << " for conv " << instr->ToString();
      continue;
    }

    se::RedzoneAllocator scratch_allocator(
        stream, allocator, PtxOptsFromConfig(hlo_module_config));
    se::dnn::ProfileResult profile_result;
    VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
            << instr->ToString();

    // Use assignment instead of brace-list to make GCC 4.9 happy.
    RunConvOptions options;
    options.profile_result = &profile_result;
    options.algo_override = alg;
    Status launch_status =
        RunGpuConv(config, absl::MakeSpan(operand_buffers), result_buffer,
                   &scratch_allocator, stream, options);

    if (!launch_status.ok()) {
      continue;
    }

    if (!profile_result.is_valid()) {
      continue;
    }

    profile_results.emplace_back();
    AutotuneResult& result = profile_results.back();
    result.mutable_conv()->set_algorithm(alg.algo_id());
    result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled());

    int64 scratch_bytes_used =
        scratch_allocator.TotalAllocatedBytesExcludingRedzones();
    result.set_scratch_bytes(scratch_bytes_used);
    *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
        absl::Milliseconds(profile_result.elapsed_time_in_ms()));

    if (!check_conv) {
      continue;
    }

    // Check for writes to redzones.
    TF_ASSIGN_OR_RETURN(bool input_output_allocator_redzone_clear,
                        CheckRedzones(input_output_allocator, stream,
                                      "input/output", instr, &result));

    TF_ASSIGN_OR_RETURN(
        bool scratch_allocator_redzone_clear,
        CheckRedzones(scratch_allocator, stream, "scratch", instr, &result));

    if (!input_output_allocator_redzone_clear ||
        !scratch_allocator_redzone_clear) {
      AlgorithmDenylist proto;
      auto entry = proto.add_entries();
      entry->set_hlo(canonical_hlo);
      *entry->mutable_cc() = GetComputeCapability(stream_exec_);
      *entry->mutable_cudnn_version() = GetCudnnVersion(stream_exec_);
      entry->set_blas_version(blas_version);
      auto algo = entry->add_algos();
      algo->set_id(alg.algo_id());
      algo->set_tensor_ops(alg.tensor_ops_enabled());

      LOG(ERROR) << "To denylist this algorithm for this convolution, "
                    "copy-paste the following "
                    "proto to the denylist file pointed to by XLA_FLAGS "
                    "--xla_gpu_algorithm_denylist_path="
                 << GetDebugOptionsFromFlags().xla_gpu_algorithm_denylist_path()
                 << " : " << proto.ShortDebugString();
      continue;
    }

    if (comparator.has_value()) {
      XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2);
      StatusOr<bool> compare_result = comparator->CompareEqual(
          stream, reference_result_buffer, result_buffer);
      if (!compare_result.ok()) {
        LOG(ERROR) << "Unable to compare " << AlgorithmToString(first_algorithm)
                   << " against " << AlgorithmToString(alg) << " for "
                   << instr->ToString() << ": " << compare_result.status();
        if (compare_result.status().code() ==
            tensorflow::error::RESOURCE_EXHAUSTED) {
          // Possibly OOM. Propagate the error.
          return compare_result.status();
        }
        CHECK(!crash_on_checking_failure);
      } else if (!compare_result.ValueOrDie()) {
        LOG(ERROR)
            << "Results mismatch between different convolution algorithms. "
               "This is likely a bug/unexpected loss of precision in cudnn.\n"
            << instr->ToString() << " for "
            << AlgorithmToString(first_algorithm) << " vs "
            << AlgorithmToString(alg);
        PrintPlatformInfo(stream);
        VLOG(1) << "Full module on failure: \n"
                << instr->GetModule()->ToString();
        auto* fail = result.mutable_failure();
        fail->set_kind(AutotuneResult::WRONG_RESULT);
        fail->set_buffer_address(
            reinterpret_cast<uint64>(result_buffer.opaque()));
        auto* reference_conv = fail->mutable_reference_conv();
        reference_conv->set_algorithm(first_algorithm.algo_id());
        reference_conv->set_tensor_ops_enabled(
            first_algorithm.tensor_ops_enabled());
      }
    } else {
      XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::Create", 2);
      comparator.emplace(result_shape, hlo_module_config);
      TF_ASSIGN_OR_RETURN(
          reference_result_buffer,
          input_output_allocator.AllocateBytes(result_buffer.size()));
      stream->ThenMemcpy(&reference_result_buffer, result_buffer,
                         result_buffer.size());
      first_algorithm = alg;
    }
  }

  // Log the autotuning result.
  {
    tensorflow::AutotuningLog log;
    {
      ConvInstructionLog instr_log;
      *instr_log.mutable_instruction() = instr->ToProto();
      for (int i = 0; i < instr->operand_count(); i++) {
        *instr_log.add_operand_shapes() = instr->operand(i)->shape().ToProto();
        instr_log.add_operand_addresses(
            reinterpret_cast<uint64>(operand_buffers[i].opaque()));
      }
      instr_log.set_result_address(
          reinterpret_cast<uint64>(result_buffer.opaque()));
      log.mutable_instr()->PackFrom(instr_log);
    }
    for (const auto& profile : profile_results) {
      *log.add_results() = profile;
    }
    *log.mutable_compute_capability() = GetComputeCapability(stream_exec_);
    *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec_);
    log.set_device_pci_bus_id(
        stream_exec_->GetDeviceDescription().pci_bus_id());
    log.set_blas_version(blas_version);
    VLOG(1) << "Autotuning result: " << log.ShortDebugString();
    // If we crash on checking failure, we are in a testing/benchmark mode, so
    // we skip logging through the logger.
    if (!crash_on_checking_failure) {
      tensorflow::Logger::GetSingleton()->LogProto(log);
    }
  }

  // Crash on miscompares and redzone violations if desired.  Do this after
  // logging the autotuning results, otherwise we won't get any data!
  for (const auto& result : profile_results) {
    if (result.has_failure()) {
      CHECK(!crash_on_checking_failure);
    }
  }

  // Choose the fastest convolution that doesn't produce a REDZONE_MODIFIED
  // error.
  //
  // TODO(jlebar): We ought to be able to detect redzone reads by noticing NaNs
  // in the output of the conv and skip those.
  //
  // For now, we ignore WRONG_RESULT failures because false-positives are
  // possible (e.g. perhaps the reference algorithm is the one that's
  // incorrect!).  But we don't ignore REDZONE_MODIFIED failures because they're
  // quite severe and can be detected with high accuracy.
  std::vector<AutotuneResult> filtered_results;
  absl::c_copy_if(
      profile_results, std::back_inserter(filtered_results),
      [](const AutotuneResult& r) {
        return !(r.has_failure() &&
                 r.failure().kind() != AutotuneResult::WRONG_RESULT);
      });
  if (filtered_results.empty()) {
    return InternalError(
        "All algorithms tried for convolution %s failed. Falling back to "
        "default algorithm.",
        instr->ToString());
  }

  auto selected_result = filtered_results.begin();
  if (!RequireCudnnDeterminism() &&
      !hlo_module_config.debug_options().xla_gpu_deterministic_ops()) {
    selected_result = absl::c_min_element(
        filtered_results,
        [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
          return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) <
                 tensorflow::proto_utils::FromDurationProto(rhs.run_time());
        });
  }

  return *selected_result;
}
#endif

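// Autotunes `instr` on the ROCm platform: profiles the algorithms reported by
// MIOpen and returns the one with the lowest measured runtime.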
StatusOr<tensorflow::AutotuneResult>
GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
    const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator,
    se::Stream* stream) {
  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
      "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString()));

  const auto device_ordinal = stream_exec_->device_ordinal();
  std::vector<se::DeviceMemoryBase> operand_buffers;

  ScratchAllocator input_output_allocator(device_ordinal, allocator);
  const auto initialize_buffer = [stream](DeviceMemoryBase buffer) {
    // Although we don't have evidence this matters, zero out the buffers
    // before autotuning.  It's conceivable that using uninitialized memory as
    // the inputs might affect performance if e.g. the inputs contain
    // denormals, and this is easy enough.
    stream->ThenMemZero(&buffer, buffer.size());
  };

  // Allocate space for the input, filter, and output of the convolution.  We
  // use a ScratchAllocator for this instead of calling allocator_ directly so
  // that our allocations don't leak.
  for (const auto* operand : instr->operands()) {
    TF_ASSIGN_OR_RETURN(auto buffer,
                        input_output_allocator.AllocateBytes(
                            ShapeUtil::ByteSizeOf(operand->shape())));
    initialize_buffer(buffer);
    operand_buffers.push_back(buffer);
  }

  TF_ASSIGN_OR_RETURN(
      auto result_buffer,
      input_output_allocator.AllocateBytes(
          ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0))));
  initialize_buffer(result_buffer);

  ScratchAllocator scratch_allocator(device_ordinal, allocator);

  TF_ASSIGN_OR_RETURN(
      std::vector<se::dnn::ProfileResult> algorithms,
      GetMIOpenAlgorithms(instr, absl::MakeSpan(operand_buffers), result_buffer,
                          stream_exec_, &scratch_allocator, stream));

  std::vector<AutotuneResult> profile_results;

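  // If MIOpen returned only one algorithm, use its profiling result directly;
  // otherwise re-run and time each candidate ourselves.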
  if (algorithms.size() == 1) {
    auto profile_result = algorithms[0];
    profile_results.emplace_back();
    auto& result = profile_results.back();
    result.mutable_conv()->set_algorithm(profile_result.algorithm().algo_id());
    result.mutable_conv()->set_tensor_ops_enabled(
        profile_result.algorithm().tensor_ops_enabled());

    result.set_scratch_bytes(profile_result.scratch_size());
    *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
        absl::Milliseconds(profile_result.elapsed_time_in_ms()));
  } else {
    TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(instr));
    for (const auto& miopen_alg : algorithms) {
      const auto& alg = miopen_alg.algorithm();
      XLA_SCOPED_LOGGING_TIMER_LEVEL(
          absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
                       AlgorithmToString(alg)),
          2);

      se::dnn::ProfileResult profile_result;
      VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
              << instr->ToString();

      // Use assignment instead of brace-list to make GCC 4.9 happy.
      RunConvOptions options;
      options.profile_result = &profile_result;
      options.algo_override = alg;
      options.scratch_size_override = miopen_alg.scratch_size();
      Status launch_status =
          RunGpuConv(config, absl::MakeSpan(operand_buffers), result_buffer,
                     &scratch_allocator, stream, options);

      if (!launch_status.ok()) {
        continue;
      }

      if (!profile_result.is_valid()) {
        continue;
      }

      profile_results.emplace_back();
      AutotuneResult& result = profile_results.back();
      result.mutable_conv()->set_algorithm(alg.algo_id());
      result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled());

      int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
      result.set_scratch_bytes(scratch_bytes_used);
      *result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
          absl::Milliseconds(profile_result.elapsed_time_in_ms()));
    }
  }
  const auto& best_result = absl::c_min_element(
      profile_results,
      [&](const AutotuneResult& lhs, const AutotuneResult& rhs) {
        return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) <
               tensorflow::proto_utils::FromDurationProto(rhs.run_time());
      });

  if (best_result != profile_results.end()) {
    return *best_result;
  }

  return InternalError(
      "All algorithms tried for convolution %s failed.  Falling back to "
      "default algorithm.",
      instr->ToString());
}

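// Picks the best algorithm for one conv custom-call and rewrites the call to
// carry the chosen algorithm and the exact amount of scratch space it needs.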
StatusOr<bool> GpuConvAlgorithmPicker::RunOnInstruction(HloInstruction* instr) {
  CHECK(IsCustomCallToDnnConvolution(*instr));

  StatusOr<AutotuneResult> best_algo_or =
      PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr));
  if (!best_algo_or.ok()) {
    LOG(WARNING) << "Failed to determine best cudnn convolution algorithm: "
                 << best_algo_or.status()
                 << "\n\nConvolution performance may be suboptimal.";
    return false;
  }

  auto best_algo = std::move(best_algo_or).ValueOrDie();
  VLOG(2) << "Setting cudnn conv to use algorithm "
          << best_algo.conv().algorithm() << " and "
          << NumBytesToString(best_algo.scratch_bytes())
          << " of scratch memory: " << instr->ToString()
          << " tensor_ops_enabled: " << best_algo.conv().tensor_ops_enabled();

  // Replace instr with a new CustomCall which has the correct algorithm, and
  // whose output shape has the appropriate amount of scratch memory.
  HloComputation* computation = instr->parent();
  Shape new_call_shape = ShapeUtil::MakeTupleShape(
      {instr->shape().tuple_shapes(0),
       ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes()})});

  TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
                      instr->backend_config<CudnnConvBackendConfig>());
  backend_config.set_algorithm(best_algo.conv().algorithm());
  backend_config.set_tensor_ops_enabled(best_algo.conv().tensor_ops_enabled());

  HloInstruction* new_call = computation->AddInstruction(
      instr->CloneWithNewOperands(new_call_shape, instr->operands()));

  VLOG(2) << "Replacing convolution " << instr->ToString() << " with "
          << new_call->ToString();

  TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config));

  // Repackage new_call so it has the same shape as the original call, namely
  // (conv_result, u8[0]).
  HloInstruction* new_tuple =
      computation->AddInstruction(HloInstruction::CreateTuple(
          {computation->AddInstruction(HloInstruction::CreateGetTupleElement(
               new_call_shape.tuple_shapes(0), new_call, 0)),
           computation->AddInstruction(HloInstruction::CreateConstant(
               LiteralUtil::CreateR1<uint8>({})))}));

  TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple));
  return true;
}

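// Runs the algorithm picker on every conv custom-call in `computation`.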
StatusOr<bool> GpuConvAlgorithmPicker::RunOnComputation(
    HloComputation* computation) {
  std::vector<HloInstruction*> convs;
  for (auto* instr : computation->instructions()) {
    if (IsCustomCallToDnnConvolution(*instr)) {
      convs.push_back(instr);
    }
  }

  bool changed = false;
  for (auto* instr : convs) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr));
    changed |= result;
  }
  return changed;
}

StatusOr<bool> GpuConvAlgorithmPicker::Run(HloModule* module) {
  XLA_SCOPED_LOGGING_TIMER("GpuConvAlgorithmPicker");

  if (module->config().debug_options().xla_gpu_autotune_level() == 0) {
    VLOG(2) << "Convolution auto-tuning disabled, GpuConvAlgorithmPicker "
               "returning early.";
    return false;
  }

  bool changed = false;
  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }

  {
    tensorflow::mutex_lock lock(autotune_cache_lock);
    autotune_cache_stats.LogStats();
  }

  return changed;
}

}  // namespace gpu
}  // namespace xla