/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"

#include <map>
#include <random>

#include "absl/memory/memory.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/cuda_libdevice_path.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/subprocess.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
#include "tensorflow/stream_executor/kernel_spec.h"
namespace xla {
namespace gpu {

namespace {

using se::dnn::DataLayout;
using se::dnn::DataLayoutString;
using se::dnn::FilterLayout;
using se::dnn::FilterLayoutString;
using tensorflow::AutotuneResult;
// Returns the smallest integer >= 0 that's not in the given set of numbers.
//
// For example, FindMissingDnum({1, 0, 3, 4}) returns 2.
//
// This is useful for handling DataLayout::kBatchDepthYX4, which represents a
// layout [N, C/k, H, W, k] for some constant k, usually 4 or 32.
// ConvolutionDimensionNumbers doesn't explicitly say which dimension is `k`,
// but we can infer it by finding the first dnum that isn't otherwise mentioned
// in the dnums.
int64 FindMissingDnum(absl::Span<const int64> vals) {
  for (int i = 0; i < vals.size(); i++) {
    if (!absl::c_linear_search(vals, i)) {
      return i;
    }
  }
  return vals.size();
}

}  // anonymous namespace

StatusOr<std::tuple<Layout, Layout, Layout>>
StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
                                      DataLayout input, FilterLayout filter,
                                      DataLayout output) {
  std::vector<int64> input_layout;
  switch (input) {
    case DataLayout::kBatchDepthYX:  // NCHW
      input_layout.push_back(dnums.input_batch_dimension());
      input_layout.push_back(dnums.input_feature_dimension());
      input_layout.insert(input_layout.end(),
                          dnums.input_spatial_dimensions().begin(),
                          dnums.input_spatial_dimensions().end());
      break;
    case DataLayout::kBatchDepthYX4:   // NCHW_VECT_C
    case DataLayout::kBatchDepthYX32:  // NCHW_VECT_C
      input_layout.push_back(dnums.input_batch_dimension());
      input_layout.push_back(dnums.input_feature_dimension());
      input_layout.insert(input_layout.end(),
                          dnums.input_spatial_dimensions().begin(),
                          dnums.input_spatial_dimensions().end());
      input_layout.push_back(FindMissingDnum(input_layout));
      break;
    case DataLayout::kBatchYXDepth:  // NHWC
      input_layout.push_back(dnums.input_batch_dimension());
      input_layout.insert(input_layout.end(),
                          dnums.input_spatial_dimensions().begin(),
                          dnums.input_spatial_dimensions().end());
      input_layout.push_back(dnums.input_feature_dimension());
      break;
    default:
      return InternalError("Invalid input layout %s for conv with dnums %s",
                           DataLayoutString(input),
                           ConvolutionDimensionNumbersToString(dnums));
  }

  std::vector<int64> filter_layout;
  switch (filter) {
    case FilterLayout::kOutputInputYX:  // OIHW
      filter_layout.push_back(dnums.kernel_output_feature_dimension());
      filter_layout.push_back(dnums.kernel_input_feature_dimension());
      filter_layout.insert(filter_layout.end(),
                           dnums.kernel_spatial_dimensions().begin(),
                           dnums.kernel_spatial_dimensions().end());
      break;
    case FilterLayout::kOutputInputYX4:   // OIHW_VECT_C
    case FilterLayout::kOutputInputYX32:  // OIHW_VECT_C
      filter_layout.push_back(dnums.kernel_output_feature_dimension());
      filter_layout.push_back(dnums.kernel_input_feature_dimension());
      filter_layout.insert(filter_layout.end(),
                           dnums.kernel_spatial_dimensions().begin(),
                           dnums.kernel_spatial_dimensions().end());
      filter_layout.push_back(FindMissingDnum(filter_layout));
      break;
    case FilterLayout::kOutputYXInput:  // OHWI
      filter_layout.push_back(dnums.kernel_output_feature_dimension());
      filter_layout.insert(filter_layout.end(),
                           dnums.kernel_spatial_dimensions().begin(),
                           dnums.kernel_spatial_dimensions().end());
      filter_layout.push_back(dnums.kernel_input_feature_dimension());
      break;
    default:
      return InternalError("Invalid filter layout %s for conv with dnums %s",
                           FilterLayoutString(filter),
                           ConvolutionDimensionNumbersToString(dnums));
  }

  std::vector<int64> output_layout;
  switch (output) {
    case DataLayout::kBatchDepthYX:  // NCHW
      output_layout.push_back(dnums.output_batch_dimension());
      output_layout.push_back(dnums.output_feature_dimension());
      output_layout.insert(output_layout.end(),
                           dnums.output_spatial_dimensions().begin(),
                           dnums.output_spatial_dimensions().end());
      break;
    case DataLayout::kBatchDepthYX4:   // NCHW_VECT_C
    case DataLayout::kBatchDepthYX32:  // NCHW_VECT_C
      output_layout.push_back(dnums.output_batch_dimension());
      output_layout.push_back(dnums.output_feature_dimension());
      output_layout.insert(output_layout.end(),
                           dnums.output_spatial_dimensions().begin(),
                           dnums.output_spatial_dimensions().end());
      output_layout.push_back(FindMissingDnum(output_layout));
      break;
    case DataLayout::kBatchYXDepth:  // NHWC
      output_layout.push_back(dnums.output_batch_dimension());
      output_layout.insert(output_layout.end(),
                           dnums.output_spatial_dimensions().begin(),
                           dnums.output_spatial_dimensions().end());
      output_layout.push_back(dnums.output_feature_dimension());
      break;
    default:
      return InternalError("Invalid output layout %s for conv with dnums %s",
                           DataLayoutString(output),
                           ConvolutionDimensionNumbersToString(dnums));
  }

  return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout),
                         LayoutUtil::MakeLayoutFromMajorToMinor(filter_layout),
                         LayoutUtil::MakeLayoutFromMajorToMinor(output_layout));
}
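
// Illustrative example (not part of the original file): with canonical dnums
// {batch=0, feature=1, spatial={2,3}}, the major-to-minor order built above
// for DataLayout::kBatchDepthYX is [0, 1, 2, 3] (NCHW, W most minor), while
// DataLayout::kBatchYXDepth yields [0, 2, 3, 1] (NHWC, the feature dimension
// most minor):
//
//   TF_ASSIGN_OR_RETURN(
//       auto layouts, StreamExecutorConvLayoutsToXlaLayouts(
//                         dnums, DataLayout::kBatchDepthYX,
//                         FilterLayout::kOutputInputYX,
//                         DataLayout::kBatchDepthYX));
//   const Layout& input_layout = std::get<0>(layouts);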

StatusOr<std::tuple<DataLayout, FilterLayout, DataLayout>>
XlaConvShapesToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
                                     const Shape& input, const Shape& filter,
                                     const Shape& output) {
  CHECK(input.has_layout());
  CHECK(filter.has_layout());
  CHECK(output.has_layout());

  Layout nchw_input, nchw_filter, nchw_output;
  std::tie(nchw_input, nchw_filter, nchw_output) =
      StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX,
                                            FilterLayout::kOutputInputYX,
                                            DataLayout::kBatchDepthYX)
          .ConsumeValueOrDie();

  // NCHW4 and NCHW32 have the same Layout; we disambiguate them below.
  Layout nchw_vect_input, nchw_vect_filter, nchw_vect_output;
  std::tie(nchw_vect_input, nchw_vect_filter, nchw_vect_output) =
      StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX4,
                                            FilterLayout::kOutputInputYX4,
                                            DataLayout::kBatchDepthYX4)
          .ConsumeValueOrDie();

  Layout nhwc_input, nhwc_filter, nhwc_output;
  std::tie(nhwc_input, nhwc_filter, nhwc_output) =
      StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchYXDepth,
                                            FilterLayout::kOutputYXInput,
                                            DataLayout::kBatchYXDepth)
          .ConsumeValueOrDie();

  DataLayout input_layout;
  if (LayoutUtil::Equal(input.layout(), nchw_input)) {
    input_layout = DataLayout::kBatchDepthYX;
  } else if (LayoutUtil::Equal(input.layout(), nchw_vect_input)) {
    // Differentiate between VECT_4 and VECT_32 by looking at the input shape.
    int64_t vect_size = input.dimensions(input.layout().minor_to_major(0));
    if (vect_size == 4) {
      input_layout = DataLayout::kBatchDepthYX4;
    } else if (vect_size == 32) {
      input_layout = DataLayout::kBatchDepthYX32;
    } else {
      return InternalError(
          "Invalid input shape %s for conv with dnums %s.  Most-minor dim "
          "should be 4 or 32, but was %d.",
          ShapeUtil::HumanStringWithLayout(input),
          ConvolutionDimensionNumbersToString(dnums), vect_size);
    }
  } else if (LayoutUtil::Equal(input.layout(), nhwc_input)) {
    input_layout = DataLayout::kBatchYXDepth;
  } else {
    return InternalError("Invalid input layout %s for conv with dnums %s",
                         LayoutUtil::HumanString(input.layout()),
                         ConvolutionDimensionNumbersToString(dnums));
  }

  FilterLayout filter_layout;
  if (LayoutUtil::Equal(filter.layout(), nchw_filter)) {
    filter_layout = FilterLayout::kOutputInputYX;
  } else if (LayoutUtil::Equal(filter.layout(), nchw_vect_filter)) {
    int64_t vect_size = filter.dimensions(filter.layout().minor_to_major(0));
    if (vect_size == 4) {
      filter_layout = FilterLayout::kOutputInputYX4;
    } else if (vect_size == 32) {
      filter_layout = FilterLayout::kOutputInputYX32;
    } else {
      return InternalError(
          "Invalid filter shape %s for conv with dnums %s.  Most-minor dim "
          "should be 4 or 32, but was %d.",
          ShapeUtil::HumanStringWithLayout(filter),
          ConvolutionDimensionNumbersToString(dnums), vect_size);
    }
  } else if (LayoutUtil::Equal(filter.layout(), nhwc_filter)) {
    filter_layout = FilterLayout::kOutputYXInput;
  } else {
    return InternalError("Invalid filter layout %s for conv with dnums %s",
                         LayoutUtil::HumanString(filter.layout()),
                         ConvolutionDimensionNumbersToString(dnums));
  }

  DataLayout output_layout;
  if (LayoutUtil::Equal(output.layout(), nchw_output)) {
    output_layout = DataLayout::kBatchDepthYX;
  } else if (LayoutUtil::Equal(output.layout(), nchw_vect_output)) {
    int64_t vect_size = output.dimensions(output.layout().minor_to_major(0));
    if (vect_size == 4) {
      output_layout = DataLayout::kBatchDepthYX4;
    } else if (vect_size == 32) {
      output_layout = DataLayout::kBatchDepthYX32;
    } else {
      return InternalError(
          "Invalid output shape %s for conv with dnums %s.  Most-minor dim "
          "should be 4 or 32, but was %d.",
          ShapeUtil::HumanStringWithLayout(output),
          ConvolutionDimensionNumbersToString(dnums), vect_size);
    }
  } else if (LayoutUtil::Equal(output.layout(), nhwc_output)) {
    output_layout = DataLayout::kBatchYXDepth;
  } else {
    return InternalError("Invalid output layout %s for conv with dnums %s",
                         LayoutUtil::HumanString(output.layout()),
                         ConvolutionDimensionNumbersToString(dnums));
  }

  return std::make_tuple(input_layout, filter_layout, output_layout);
}
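
// Illustrative example (not part of the original file): assuming canonical
// dnums {batch=0, feature=1, spatial={2,3}} and an NCHW_VECT_C input of shape
// s8[32,8,28,28,4] with layout {4,3,2,1,0}, the most-minor dimension has size
// 4, so the code above resolves the input to DataLayout::kBatchDepthYX4; a
// trailing dimension of size 32 would resolve to DataLayout::kBatchDepthYX32.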

// Given unique integers D = {d0, d1, ds...}, finds the first integer less than
// `rank` which is not in D.  If there is no such number (because all the values
// in [0, rank) appear), returns nullopt.
//
// When D is the set of dimensions in a ConvolutionDimensionNumbers, this finds
// the dimension number that corresponds to the vectorized-features dimension in
// the convolution.
static absl::optional<int64> FindVectorizedDim(int64_t rank, int64_t d0,
                                               int64_t d1,
                                               absl::Span<const int64> ds) {
  for (int64_t i = 0; i < rank; i++) {
    if (i == d0 || i == d1 || absl::c_linear_search(ds, i)) {
      continue;
    }
    return i;
  }
  return absl::nullopt;
}
293 
294 std::tuple<absl::optional<int64>, absl::optional<int64>, absl::optional<int64>>
FindVectorizedFeatureDims(const ConvolutionDimensionNumbers & dnums,const Shape & input,const Shape & filter,const Shape & output)295 FindVectorizedFeatureDims(const ConvolutionDimensionNumbers& dnums,
296                           const Shape& input, const Shape& filter,
297                           const Shape& output) {
298   return {
299       FindVectorizedDim(input.dimensions_size(), dnums.input_batch_dimension(),
300                         dnums.input_feature_dimension(),
301                         dnums.input_spatial_dimensions()),
302       FindVectorizedDim(filter.dimensions_size(),
303                         dnums.kernel_input_feature_dimension(),
304                         dnums.kernel_output_feature_dimension(),
305                         dnums.kernel_spatial_dimensions()),
306       FindVectorizedDim(
307           output.dimensions_size(), dnums.output_batch_dimension(),
308           dnums.output_feature_dimension(), dnums.output_spatial_dimensions()),
309   };
310 }
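
// Illustrative example (not part of the original file): for a rank-5
// NCHW_VECT_C input with dnums {batch=0, feature=1, spatial={2,3}}, the only
// dimension not mentioned in the dnums is 4, so the input entry of the tuple
// is absl::optional<int64>(4); for an unvectorized rank-4 shape it is
// absl::nullopt.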

tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
  static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
  // se::Platform*s are global singletons guaranteed to live forever.
  static auto* mutexes =
      new std::map<std::pair<const se::Platform*, /*device_ordinal*/ int64>,
                   tensorflow::mutex>();

  tensorflow::mutex_lock global_lock(mu);
  auto it = mutexes
                ->emplace(std::piecewise_construct,
                          std::make_tuple(stream_exec->platform(),
                                          stream_exec->device_ordinal()),
                          std::make_tuple())
                .first;
  return tensorflow::mutex_lock{it->second};
}
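
// Usage sketch (illustrative, not part of the original file):
//
//   {
//     tensorflow::mutex_lock lock = LockGpu(stream_exec);
//     // This thread now has exclusive access to the (platform,
//     // device_ordinal) pair, e.g. for autotuning; the lock is released when
//     // `lock` goes out of scope.
//   }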

StatusOr<std::unique_ptr<se::KernelBase>> CreateKernel(
    absl::string_view kernel_name, uint64 num_args, absl::string_view ptx,
    absl::Span<const uint8> cubin_data, se::StreamExecutor* stream_exec) {
  se::MultiKernelLoaderSpec loader_spec(num_args);
  loader_spec.AddCudaPtxInMemory(ptx, kernel_name);

  if (!cubin_data.empty()) {
    loader_spec.AddCudaCubinInMemory(
        reinterpret_cast<const char*>(cubin_data.data()), kernel_name);
  }

  auto kernel_base = absl::make_unique<se::KernelBase>(stream_exec);
  TF_RETURN_IF_ERROR(stream_exec->GetKernel(loader_spec, kernel_base.get()));
  return std::move(kernel_base);
}

Status ExecuteKernelOnStream(const se::KernelBase& kernel,
                             absl::Span<const se::DeviceMemoryBase> args,
                             const LaunchDimensions& dims, se::Stream* stream) {
  static constexpr int kKernelArgsLimit = 1024;
  auto kernel_args = absl::make_unique<se::KernelArgsArray<kKernelArgsLimit>>();
  for (const se::DeviceMemoryBase& buf : args) {
    kernel_args->add_device_memory_argument(buf);
  }
  LaunchDimensions::Dim3D thread_counts = dims.thread_counts_per_block();
  LaunchDimensions::Dim3D block_counts = dims.block_counts();
  return stream->parent()->Launch(
      stream, se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z),
      se::BlockDim(block_counts.x, block_counts.y, block_counts.z), kernel,
      *kernel_args);
}
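
// Usage sketch (illustrative; `ptx`, `args`, and `launch_dims` are assumed to
// be provided by the caller):
//
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<se::KernelBase> kernel,
//       CreateKernel(/*kernel_name=*/"add_one", /*num_args=*/args.size(), ptx,
//                    /*cubin_data=*/{}, stream->parent()));
//   TF_RETURN_IF_ERROR(ExecuteKernelOnStream(*kernel, args, launch_dims,
//                                            stream));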

se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config) {
  string extra_string =
      hlo_module_config.debug_options().xla_gpu_asm_extra_flags();
  std::vector<std::string> extra_flags =
      absl::StrSplit(extra_string, ',', absl::SkipEmpty());
  return se::GpuAsmOpts(
      hlo_module_config.debug_options().xla_gpu_disable_gpuasm_optimizations(),
      hlo_module_config.debug_options().xla_gpu_cuda_data_dir(), extra_flags);
}
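
// For example (illustrative, with hypothetical flags): if
// xla_gpu_asm_extra_flags is "-v,--foo=bar", `extra_flags` above becomes
// {"-v", "--foo=bar"}; an empty string yields an empty vector because of
// absl::SkipEmpty().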

// Not yet implemented for integral types.
template <typename T, typename Generator>
typename std::enable_if<std::is_integral<T>::value,
                        T>::type static UniformDistribution(T lhs, T rhs,
                                                            Generator* gen) =
    delete;

template <typename T, typename Generator>
typename std::enable_if<std::is_floating_point<T>::value,
                        T>::type static UniformDistribution(T lhs, T rhs,
                                                            Generator* gen) {
  return std::uniform_real_distribution<T>(lhs, rhs)(*gen);
}

template <typename T>
static void InitializeTypedBuffer(se::Stream* stream,
                                  se::DeviceMemoryBase buffer,
                                  int64* rng_state) {
  // Accesses to static variables are not locked, since the caller is already
  // in a critical section.
  static std::vector<T>* host_buffer = [] {
    // Use a large prime number to fragment the accesses.
    auto* ret = new std::vector<T>(10069);
    // Default-seeded random numbers.
    std::mt19937 gen;
    for (auto& element : *ret) {
      // Only double gets random values in double; other data types get random
      // values in float and then cast them to the target data type.
      using RandomFloatingPointType =
          typename std::conditional<std::is_same<T, Eigen::half>::value, float,
                                    T>::type;
      using RandomType =
          typename std::conditional<std::is_integral<T>::value, float,
                                    RandomFloatingPointType>::type;
      // Scale down the values for fp16 to reduce the number of overflows.
      auto upper_bound =
          RandomType(std::is_same<T, Eigen::half>::value ? 0.1 : 1.0);
      auto rand_val = UniformDistribution(RandomType(0), upper_bound, &gen);
      // For float or double, the value is in [0, 1].
      // For fp16, it is in [0, 0.1].
      // For integer types, the element is either 0 or 1, which reduces
      // overflows, especially for int8.
      element = T(std::is_integral<T>::value ? rand_val + 0.5 : rand_val);
    }
    return ret;
  }();

  int64& host_index = *rng_state;

  char* current_addr = static_cast<char*>(buffer.opaque());
  CHECK_EQ(0, buffer.size() % sizeof(T));
  int64_t elements_left = buffer.size() / sizeof(T);
  while (elements_left > 0) {
    CHECK_LE(host_index, host_buffer->size());
    if (host_buffer->size() == host_index) {
      host_index = 0;
    }
    int64_t elements_copied =
        std::min<int64>(host_buffer->size() - host_index, elements_left);
    se::DeviceMemoryBase mem(current_addr, elements_copied * sizeof(T));
    stream->ThenMemcpy(&mem, host_buffer->data() + host_index,
                       elements_copied * sizeof(T));
    current_addr += elements_copied * sizeof(T);
    elements_left -= elements_copied;
    host_index += elements_copied;
  }
}
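
// Illustrative example (not part of the original file): for a float buffer of
// 16384 elements and *rng_state == 0, the loop above issues two copies,
// elements [0, 10069) and then elements [0, 6315) of the host buffer, and
// leaves *rng_state == 6315 so the next buffer starts at a different offset.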

void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type,
                      int64* rng_state, se::DeviceMemoryBase buffer) {
  switch (buffer_type) {
    case xla::F16:
    case xla::BF16:
      // Using F16 for BF16 initialization is fine: we only need some random
      // numbers there, and the random generator doesn't work for BF16 (not
      // all required overloads are present).
      return InitializeTypedBuffer<Eigen::half>(stream, buffer, rng_state);
    case xla::F32:
    case xla::C64:
      return InitializeTypedBuffer<float>(stream, buffer, rng_state);
    case xla::F64:
    case xla::C128:
      return InitializeTypedBuffer<double>(stream, buffer, rng_state);
    case xla::S8:
      return InitializeTypedBuffer<int8>(stream, buffer, rng_state);
    case xla::S32:
      return InitializeTypedBuffer<int32>(stream, buffer, rng_state);
    default:
      LOG(FATAL) << "Unexpected type: "
                 << primitive_util::LowercasePrimitiveTypeName(buffer_type);
  }
}

StatusOr<se::dnn::ConvolutionKind> GetDNNConvKindFromCudnnConvKind(
    CudnnConvKind kind) {
  switch (kind) {
    case CudnnConvKind::kBackwardFilter:
      return se::dnn::BACKWARD_FILTER;
    case CudnnConvKind::kBackwardInput:
      return se::dnn::BACKWARD_DATA;
    case CudnnConvKind::kForward:
      return se::dnn::FORWARD;
    default:
      break;
  }
  return InternalError("Unexpected convolution kind");
}

StatusOr<se::dnn::DataType> GetDNNDataTypeFromPrimitiveType(
    PrimitiveType type) {
  switch (type) {
    case F16:
      return se::dnn::ToDataType<Eigen::half>::value;
    case F32:
      return se::dnn::ToDataType<float>::value;
    case F64:
      return se::dnn::ToDataType<double>::value;
    default:
      break;
  }
  return InternalError("Unsupported convolution datatype");
}

bool RequireDeterminism(const HloModuleConfig& config) {
  static bool require_cudnn_determinism = [] {
    // TODO(reedwm): Remove the TF_CUDNN_DETERMINISTIC env var.
    bool cudnn_deterministic = false;
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
                                               /*default_val=*/false,
                                               &cudnn_deterministic));
    return cudnn_deterministic;
  }();
  return tensorflow::OpDeterminismRequired() || require_cudnn_determinism ||
         config.debug_options().xla_gpu_deterministic_ops();
}
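
// For example, running with the environment variable TF_CUDNN_DETERMINISTIC=1
// (or with the xla_gpu_deterministic_ops debug option set) makes
// RequireDeterminism return true, which disables timing-based algorithm
// selection in PickBestResult below.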

StatusOr<AutotuneResult> PickBestResult(
    absl::Span<AutotuneResult const> profile_results,
    const HloInstruction& instr) {
  std::vector<AutotuneResult> filtered_results;

  // For now, we ignore WRONG_RESULT failures because false positives are
  // possible (e.g. perhaps the reference algorithm is the one that's
  // incorrect!).  But we don't ignore REDZONE_MODIFIED failures because
  // they're quite severe and can be detected with high accuracy.
  absl::c_copy_if(
      profile_results, std::back_inserter(filtered_results),
      [](const AutotuneResult& r) {
        return !(r.has_failure() &&
                 r.failure().kind() != AutotuneResult::WRONG_RESULT);
      });

  if (filtered_results.empty()) {
    return InternalError(
        "All algorithms tried for %s failed. Falling back to "
        "default algorithm. ",
        instr.ToString());
  }

  auto selected_result = filtered_results.begin();
  if (!RequireDeterminism(instr.parent()->parent()->config())) {
    selected_result = absl::c_min_element(
        filtered_results,
        [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
          return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) <
                 tensorflow::proto_utils::FromDurationProto(rhs.run_time());
        });
  }
  return *selected_result;
}
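
// Usage sketch (illustrative; `profile_results` holds one AutotuneResult per
// candidate algorithm profiled for `instr`):
//
//   TF_ASSIGN_OR_RETURN(AutotuneResult best,
//                       PickBestResult(profile_results, *instr));
//   // `best` is the fastest non-failing candidate, or simply the first
//   // non-failing one when determinism is required.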

}  // namespace gpu
}  // namespace xla