1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
17
18 #include "absl/memory/memory.h"
19 #include "tensorflow/compiler/xla/layout_util.h"
20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
21 #include "tensorflow/compiler/xla/util.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/gtl/cleanup.h"
24 #include "tensorflow/core/lib/io/path.h"
25 #include "tensorflow/core/platform/cuda_libdevice_path.h"
26 #include "tensorflow/core/platform/regexp.h"
27 #include "tensorflow/core/platform/subprocess.h"
28 #include "tensorflow/core/platform/tracing.h"
29 #include "tensorflow/core/profiler/lib/traceme.h"
30 #include "tensorflow/core/util/determinism.h"
31 #include "tensorflow/core/util/env_var.h"
32 #include "tensorflow/core/util/proto/proto_utils.h"
33 #include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
34 #include "tensorflow/stream_executor/kernel_spec.h"
35
36 namespace xla {
37 namespace gpu {
38
39 namespace {
40
41 using se::dnn::DataLayout;
42 using se::dnn::DataLayoutString;
43 using se::dnn::FilterLayout;
44 using se::dnn::FilterLayoutString;
45 using tensorflow::AutotuneResult;
46
47 // Returns the smallest integer >= 0 that's not in the given set of numbers.
48 //
49 // For example, FindMissingDnum({1, 0, 3, 4}) returns 2.
50 //
51 // This is useful for handling DataLayout::kBatchDepthYX4, which repesents a
52 // layout [N, C/k, H, W, k] for some constant k, usually 4 or 32.
53 // ConvolutionDimensionNumbers doesn't explicitly say which dimension is `k`,
54 // but we can infer it by finding the first dnum that isn't otherwise mentioned
55 // in the dnums.
FindMissingDnum(absl::Span<const int64> vals)56 int64 FindMissingDnum(absl::Span<const int64> vals) {
57 for (int i = 0; i < vals.size(); i++) {
58 if (!absl::c_linear_search(vals, i)) {
59 return i;
60 }
61 }
62 return vals.size();
63 }
64
65 } // anonymous namespace
66
// Translates a StreamExecutor (input, filter, output) convolution layout
// triple into the equivalent XLA Layouts, using `dnums` to map the logical
// roles (batch, feature, spatial) onto concrete dimension numbers.
//
// Each switch below lists dimension numbers from most-major to most-minor;
// the lists are converted to XLA Layouts at the end of the function.
// Returns an InternalError for any layout enum value not handled here.
StatusOr<std::tuple<Layout, Layout, Layout>>
StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
                                      DataLayout input, FilterLayout filter,
                                      DataLayout output) {
  std::vector<int64> input_layout;
  switch (input) {
    case DataLayout::kBatchDepthYX:  // NCHW
      input_layout.push_back(dnums.input_batch_dimension());
      input_layout.push_back(dnums.input_feature_dimension());
      input_layout.insert(input_layout.end(),
                          dnums.input_spatial_dimensions().begin(),
                          dnums.input_spatial_dimensions().end());
      break;
    case DataLayout::kBatchDepthYX4:   // NCHW_VECT_C
    case DataLayout::kBatchDepthYX32:  // NCHW_VECT_C
      // Same as NCHW plus a trailing (most-minor) vectorized-feature dim.
      // dnums doesn't name that dim explicitly, so infer it as the smallest
      // dimension number not otherwise mentioned (see FindMissingDnum).
      input_layout.push_back(dnums.input_batch_dimension());
      input_layout.push_back(dnums.input_feature_dimension());
      input_layout.insert(input_layout.end(),
                          dnums.input_spatial_dimensions().begin(),
                          dnums.input_spatial_dimensions().end());
      input_layout.push_back(FindMissingDnum(input_layout));
      break;
    case DataLayout::kBatchYXDepth:  // NHWC
      input_layout.push_back(dnums.input_batch_dimension());
      input_layout.insert(input_layout.end(),
                          dnums.input_spatial_dimensions().begin(),
                          dnums.input_spatial_dimensions().end());
      input_layout.push_back(dnums.input_feature_dimension());
      break;
    default:
      return InternalError("Invalid input layout %s for conv with dnums %s",
                           DataLayoutString(input),
                           ConvolutionDimensionNumbersToString(dnums));
  }

  std::vector<int64> filter_layout;
  switch (filter) {
    case FilterLayout::kOutputInputYX:  // OIHW
      filter_layout.push_back(dnums.kernel_output_feature_dimension());
      filter_layout.push_back(dnums.kernel_input_feature_dimension());
      filter_layout.insert(filter_layout.end(),
                          dnums.kernel_spatial_dimensions().begin(),
                          dnums.kernel_spatial_dimensions().end());
      break;
    case FilterLayout::kOutputInputYX4:   // OIHW_VECT_C
    case FilterLayout::kOutputInputYX32:  // OIHW_VECT_C
      // As above: append the inferred vectorized-feature dim as most-minor.
      filter_layout.push_back(dnums.kernel_output_feature_dimension());
      filter_layout.push_back(dnums.kernel_input_feature_dimension());
      filter_layout.insert(filter_layout.end(),
                          dnums.kernel_spatial_dimensions().begin(),
                          dnums.kernel_spatial_dimensions().end());
      filter_layout.push_back(FindMissingDnum(filter_layout));
      break;
    case FilterLayout::kOutputYXInput:  // OHWI
      filter_layout.push_back(dnums.kernel_output_feature_dimension());
      filter_layout.insert(filter_layout.end(),
                          dnums.kernel_spatial_dimensions().begin(),
                          dnums.kernel_spatial_dimensions().end());
      filter_layout.push_back(dnums.kernel_input_feature_dimension());
      break;
    default:
      return InternalError("Invalid filter layout %s for conv with dnums %s",
                           FilterLayoutString(filter),
                           ConvolutionDimensionNumbersToString(dnums));
  }

  std::vector<int64> output_layout;
  switch (output) {
    case DataLayout::kBatchDepthYX:  // NCHW
      output_layout.push_back(dnums.output_batch_dimension());
      output_layout.push_back(dnums.output_feature_dimension());
      output_layout.insert(output_layout.end(),
                          dnums.output_spatial_dimensions().begin(),
                          dnums.output_spatial_dimensions().end());
      break;
    case DataLayout::kBatchDepthYX4:   // NCHW_VECT_C
    case DataLayout::kBatchDepthYX32:  // NCHW_VECT_C
      // As above: append the inferred vectorized-feature dim as most-minor.
      output_layout.push_back(dnums.output_batch_dimension());
      output_layout.push_back(dnums.output_feature_dimension());
      output_layout.insert(output_layout.end(),
                          dnums.output_spatial_dimensions().begin(),
                          dnums.output_spatial_dimensions().end());
      output_layout.push_back(FindMissingDnum(output_layout));
      break;
    case DataLayout::kBatchYXDepth:  // NHWC
      output_layout.push_back(dnums.output_batch_dimension());
      output_layout.insert(output_layout.end(),
                          dnums.output_spatial_dimensions().begin(),
                          dnums.output_spatial_dimensions().end());
      output_layout.push_back(dnums.output_feature_dimension());
      break;
    default:
      return InternalError("Invalid output layout %s for conv with dnums %s",
                           DataLayoutString(output),
                           ConvolutionDimensionNumbersToString(dnums));
  }

  return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout),
                         LayoutUtil::MakeLayoutFromMajorToMinor(filter_layout),
                         LayoutUtil::MakeLayoutFromMajorToMinor(output_layout));
}
168
// Inverse of StreamExecutorConvLayoutsToXlaLayouts: given XLA shapes (which
// must carry layouts), determines which StreamExecutor DNN layout each one
// corresponds to. Returns an InternalError if a layout matches none of the
// candidates, or if a vectorized shape's minor-most dim isn't 4 or 32.
StatusOr<std::tuple<DataLayout, FilterLayout, DataLayout>>
XlaConvShapesToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
                                     const Shape& input, const Shape& filter,
                                     const Shape& output) {
  CHECK(input.has_layout());
  CHECK(filter.has_layout());
  CHECK(output.has_layout());

  // Build the three candidate XLA layout triples (NCHW, NCHW_VECT_C, NHWC) to
  // compare against. ConsumeValueOrDie is safe here: these calls only fail
  // for unsupported enum values, and all enums passed below are supported.
  Layout nchw_input, nchw_filter, nchw_output;
  std::tie(nchw_input, nchw_filter, nchw_output) =
      StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX,
                                            FilterLayout::kOutputInputYX,
                                            DataLayout::kBatchDepthYX)
          .ConsumeValueOrDie();

  // NCHW4 and NCHW32 have the same Layout; we disambiguate them below.
  Layout nchw_vect_input, nchw_vect_filter, nchw_vect_output;
  std::tie(nchw_vect_input, nchw_vect_filter, nchw_vect_output) =
      StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX4,
                                            FilterLayout::kOutputInputYX4,
                                            DataLayout::kBatchDepthYX4)
          .ConsumeValueOrDie();

  Layout nhwc_input, nhwc_filter, nhwc_output;
  std::tie(nhwc_input, nhwc_filter, nhwc_output) =
      StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchYXDepth,
                                            FilterLayout::kOutputYXInput,
                                            DataLayout::kBatchYXDepth)
          .ConsumeValueOrDie();

  DataLayout input_layout;
  if (LayoutUtil::Equal(input.layout(), nchw_input)) {
    input_layout = DataLayout::kBatchDepthYX;
  } else if (LayoutUtil::Equal(input.layout(), nchw_vect_input)) {
    // Differentiate between VECT_4 and VECT_32 by looking at the input shape:
    // the vectorized feature chunk is the minor-most (fastest-varying) dim.
    int64_t vect_size = input.dimensions(input.layout().minor_to_major(0));
    if (vect_size == 4) {
      input_layout = DataLayout::kBatchDepthYX4;
    } else if (vect_size == 32) {
      input_layout = DataLayout::kBatchDepthYX32;
    } else {
      return InternalError(
          "Invalid input shape %s for conv with dnums %s. Most-minor dim "
          "should be 4 or 32, but was %d.",
          ShapeUtil::HumanStringWithLayout(input),
          ConvolutionDimensionNumbersToString(dnums), vect_size);
    }
  } else if (LayoutUtil::Equal(input.layout(), nhwc_input)) {
    input_layout = DataLayout::kBatchYXDepth;
  } else {
    return InternalError("Invalid input layout %s for conv with dnums %s",
                         LayoutUtil::HumanString(input.layout()),
                         ConvolutionDimensionNumbersToString(dnums));
  }

  FilterLayout filter_layout;
  if (LayoutUtil::Equal(filter.layout(), nchw_filter)) {
    filter_layout = FilterLayout::kOutputInputYX;
  } else if (LayoutUtil::Equal(filter.layout(), nchw_vect_filter)) {
    // Same VECT_4 / VECT_32 disambiguation as for the input.
    int64_t vect_size = filter.dimensions(filter.layout().minor_to_major(0));
    if (vect_size == 4) {
      filter_layout = FilterLayout::kOutputInputYX4;
    } else if (vect_size == 32) {
      filter_layout = FilterLayout::kOutputInputYX32;
    } else {
      return InternalError(
          "Invalid filter shape %s for conv with dnums %s. Most-minor dim "
          "should be 4 or 32, but was %d.",
          ShapeUtil::HumanStringWithLayout(filter),
          ConvolutionDimensionNumbersToString(dnums), vect_size);
    }
  } else if (LayoutUtil::Equal(filter.layout(), nhwc_filter)) {
    filter_layout = FilterLayout::kOutputYXInput;
  } else {
    return InternalError("Invalid filter layout %s for conv with dnums %s",
                         LayoutUtil::HumanString(filter.layout()),
                         ConvolutionDimensionNumbersToString(dnums));
  }

  DataLayout output_layout;
  if (LayoutUtil::Equal(output.layout(), nchw_output)) {
    output_layout = DataLayout::kBatchDepthYX;
  } else if (LayoutUtil::Equal(output.layout(), nchw_vect_output)) {
    // Same VECT_4 / VECT_32 disambiguation as for the input.
    int64_t vect_size = output.dimensions(output.layout().minor_to_major(0));
    if (vect_size == 4) {
      output_layout = DataLayout::kBatchDepthYX4;
    } else if (vect_size == 32) {
      output_layout = DataLayout::kBatchDepthYX32;
    } else {
      return InternalError(
          "Invalid output shape %s for conv with dnums %s. Most-minor dim "
          "should be 4 or 32, but was %d.",
          ShapeUtil::HumanStringWithLayout(output),
          ConvolutionDimensionNumbersToString(dnums), vect_size);
    }
  } else if (LayoutUtil::Equal(output.layout(), nhwc_output)) {
    output_layout = DataLayout::kBatchYXDepth;
  } else {
    return InternalError("Invalid output layout %s for conv with dnums %s",
                         LayoutUtil::HumanString(output.layout()),
                         ConvolutionDimensionNumbersToString(dnums));
  }

  return std::make_tuple(input_layout, filter_layout, output_layout);
}
274
275 // Given unique integers D = {d0, d1, ds...}, finds the first integer less than
276 // `rank` which is not in D. If there is no such number (because all the values
277 // in [0, rank) appear), returns nullopt.
278 //
279 // When D is the set of dimensions in a ConvolutionDimensionNumbers, this finds
280 // the dimension number that corresponds to the vectorized-features dimension in
281 // the convolution.
FindVectorizedDim(int64_t rank,int64_t d0,int64_t d1,absl::Span<const int64> ds)282 static absl::optional<int64> FindVectorizedDim(int64_t rank, int64_t d0,
283 int64_t d1,
284 absl::Span<const int64> ds) {
285 for (int64_t i = 0; i < rank; i++) {
286 if (i == d0 || i == d1 || absl::c_linear_search(ds, i)) {
287 continue;
288 }
289 return i;
290 }
291 return absl::nullopt;
292 }
293
294 std::tuple<absl::optional<int64>, absl::optional<int64>, absl::optional<int64>>
FindVectorizedFeatureDims(const ConvolutionDimensionNumbers & dnums,const Shape & input,const Shape & filter,const Shape & output)295 FindVectorizedFeatureDims(const ConvolutionDimensionNumbers& dnums,
296 const Shape& input, const Shape& filter,
297 const Shape& output) {
298 return {
299 FindVectorizedDim(input.dimensions_size(), dnums.input_batch_dimension(),
300 dnums.input_feature_dimension(),
301 dnums.input_spatial_dimensions()),
302 FindVectorizedDim(filter.dimensions_size(),
303 dnums.kernel_input_feature_dimension(),
304 dnums.kernel_output_feature_dimension(),
305 dnums.kernel_spatial_dimensions()),
306 FindVectorizedDim(
307 output.dimensions_size(), dnums.output_batch_dimension(),
308 dnums.output_feature_dimension(), dnums.output_spatial_dimensions()),
309 };
310 }
311
LockGpu(const se::StreamExecutor * stream_exec)312 tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
313 static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
314 // se::Platform*s are global singletons guaranteed to live forever.
315 static auto* mutexes =
316 new std::map<std::pair<const se::Platform*, /*device_ordinal*/ int64>,
317 tensorflow::mutex>();
318
319 tensorflow::mutex_lock global_lock(mu);
320 auto it = mutexes
321 ->emplace(std::piecewise_construct,
322 std::make_tuple(stream_exec->platform(),
323 stream_exec->device_ordinal()),
324 std::make_tuple())
325 .first;
326 return tensorflow::mutex_lock{it->second};
327 }
328
CreateKernel(absl::string_view kernel_name,uint64 num_args,absl::string_view ptx,absl::Span<const uint8> cubin_data,se::StreamExecutor * stream_exec)329 StatusOr<std::unique_ptr<se::KernelBase>> CreateKernel(
330 absl::string_view kernel_name, uint64 num_args, absl::string_view ptx,
331 absl::Span<const uint8> cubin_data, se::StreamExecutor* stream_exec) {
332 se::MultiKernelLoaderSpec loader_spec(num_args);
333 loader_spec.AddCudaPtxInMemory(ptx, kernel_name);
334
335 if (!cubin_data.empty()) {
336 loader_spec.AddCudaCubinInMemory(
337 reinterpret_cast<const char*>(cubin_data.data()), kernel_name);
338 }
339
340 auto kernel_base = absl::make_unique<se::KernelBase>(stream_exec);
341 TF_RETURN_IF_ERROR(stream_exec->GetKernel(loader_spec, kernel_base.get()));
342 return std::move(kernel_base);
343 }
344
ExecuteKernelOnStream(const se::KernelBase & kernel,absl::Span<const se::DeviceMemoryBase> args,const LaunchDimensions & dims,se::Stream * stream)345 Status ExecuteKernelOnStream(const se::KernelBase& kernel,
346 absl::Span<const se::DeviceMemoryBase> args,
347 const LaunchDimensions& dims, se::Stream* stream) {
348 static constexpr int kKernelArgsLimit = 1024;
349 auto kernel_args = absl::make_unique<se::KernelArgsArray<kKernelArgsLimit>>();
350 for (const se::DeviceMemoryBase& buf : args) {
351 kernel_args->add_device_memory_argument(buf);
352 }
353 LaunchDimensions::Dim3D thread_counts = dims.thread_counts_per_block();
354 LaunchDimensions::Dim3D block_counts = dims.block_counts();
355 return stream->parent()->Launch(
356 stream, se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z),
357 se::BlockDim(block_counts.x, block_counts.y, block_counts.z), kernel,
358 *kernel_args);
359 }
360
PtxOptsFromConfig(const HloModuleConfig & hlo_module_config)361 se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config) {
362 string extra_string =
363 hlo_module_config.debug_options().xla_gpu_asm_extra_flags();
364 std::vector<std::string> extra_flags;
365 extra_flags = absl::StrSplit(extra_string, ',', absl::SkipEmpty());
366 return se::GpuAsmOpts(
367 hlo_module_config.debug_options().xla_gpu_disable_gpuasm_optimizations(),
368 hlo_module_config.debug_options().xla_gpu_cuda_data_dir(), extra_flags);
369 }
370
// Returns a value drawn uniformly from [lhs, rhs] using *gen.
//
// Unimplemented for integers yet: the integral overload is deleted, so any
// attempt to draw integer values directly fails at compile time. Callers
// that need integer-ish data draw a float instead (see InitializeTypedBuffer).
template <typename T, typename Generator>
typename std::enable_if<std::is_integral<T>::value,
                        T>::type static UniformDistribution(T lhs, T rhs,
                                                            Generator* gen) =
    delete;

// Floating-point overload: forwards to std::uniform_real_distribution.
template <typename T, typename Generator>
typename std::enable_if<std::is_floating_point<T>::value,
                        T>::type static UniformDistribution(T lhs, T rhs,
                                                            Generator* gen) {
  return std::uniform_real_distribution<T>(lhs, rhs)(*gen);
}
384
// Fills `buffer` (interpreted as an array of T) with reproducible
// pseudo-random data by repeatedly copying slices of a lazily-initialized,
// process-wide host pattern of 10069 elements to the device via `stream`.
//
// *rng_state is a cursor into that pattern: successive calls resume where the
// previous one stopped, so different buffers get different (but
// deterministic) contents.
template <typename T>
static void InitializeTypedBuffer(se::Stream* stream,
                                  se::DeviceMemoryBase buffer,
                                  int64* rng_state) {
  // Accesses to static variables are not locked, since the caller is already
  // in a critical section.
  static std::vector<T>* host_buffer = [] {
    // Use a large prime number to fragment the accesses.
    auto* ret = new std::vector<T>(10069);
    // Default-seeded random numbers.
    std::mt19937 gen;
    for (auto& element : *ret) {
      // Only double gets random values in double. Other data types get random
      // values in float then cast them to the target data types.
      using RandomFloatingPointType =
          typename std::conditional<std::is_same<T, Eigen::half>::value, float,
                                    T>::type;
      using RandomType =
          typename std::conditional<std::is_integral<T>::value, float,
                                    RandomFloatingPointType>::type;
      // Scale down the values for fp16 to have less overflows.
      auto upper_bound =
          RandomType(std::is_same<T, Eigen::half>::value ? 0.1 : 1.0);
      auto rand_val = UniformDistribution(RandomType(0), upper_bound, &gen);
      // For float or double, it is between [0,1].
      // For fp16, it ranges between [0, 0.1].
      // For integer types, element is either 0 or 1 for less overflows
      // especially for int8.
      element = T(std::is_integral<T>::value ? rand_val + 0.5 : rand_val);
    }
    return ret;
  }();

  // Cursor into host_buffer; persists across calls via *rng_state.
  int64& host_index = *rng_state;

  char* current_addr = static_cast<char*>(buffer.opaque());
  // The buffer must hold a whole number of T elements.
  CHECK_EQ(0, buffer.size() % sizeof(T));
  int64_t elements_left = buffer.size() / sizeof(T);
  while (elements_left > 0) {
    CHECK_LE(host_index, host_buffer->size());
    // Wrap the cursor once the pattern is exhausted.
    if (host_buffer->size() == host_index) {
      host_index = 0;
    }
    // Copy as much of the remaining pattern as fits in the remaining buffer.
    int64_t elements_copied =
        std::min<int64>(host_buffer->size() - host_index, elements_left);
    se::DeviceMemoryBase mem(current_addr, elements_copied * sizeof(T));
    stream->ThenMemcpy(&mem, host_buffer->data() + host_index,
                       elements_copied * sizeof(T));
    current_addr += elements_copied * sizeof(T);
    elements_left -= elements_copied;
    host_index += elements_copied;
  }
}
438
InitializeBuffer(se::Stream * stream,PrimitiveType buffer_type,int64 * rng_state,se::DeviceMemoryBase buffer)439 void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type,
440 int64* rng_state, se::DeviceMemoryBase buffer) {
441 switch (buffer_type) {
442 case xla::F16:
443 case xla::BF16:
444 // Using F16 for BF16 initialization: it's fine since we only need some
445 // random number there, and random generator is not working for BF16 (not
446 // all required overloads are there).
447 return InitializeTypedBuffer<Eigen::half>(stream, buffer, rng_state);
448 case xla::F32:
449 case xla::C64:
450 return InitializeTypedBuffer<float>(stream, buffer, rng_state);
451 case xla::F64:
452 case xla::C128:
453 return InitializeTypedBuffer<double>(stream, buffer, rng_state);
454 case xla::S8:
455 return InitializeTypedBuffer<int8>(stream, buffer, rng_state);
456 case xla::S32:
457 return InitializeTypedBuffer<int32>(stream, buffer, rng_state);
458 default:
459 LOG(FATAL) << "Unexpected type: "
460 << primitive_util::LowercasePrimitiveTypeName(buffer_type);
461 }
462 }
463
GetDNNConvKindFromCudnnConvKind(CudnnConvKind kind)464 StatusOr<se::dnn::ConvolutionKind> GetDNNConvKindFromCudnnConvKind(
465 CudnnConvKind kind) {
466 switch (kind) {
467 case CudnnConvKind::kBackwardFilter:
468 return se::dnn::BACKWARD_FILTER;
469 case CudnnConvKind::kBackwardInput:
470 return se::dnn::BACKWARD_DATA;
471 case CudnnConvKind::kForward:
472 return se::dnn::FORWARD;
473 default:
474 break;
475 }
476 return InternalError("Unexpected convolution kind");
477 }
478
GetDNNDataTypeFromPrimitiveType(PrimitiveType type)479 StatusOr<se::dnn::DataType> GetDNNDataTypeFromPrimitiveType(
480 PrimitiveType type) {
481 switch (type) {
482 case F16:
483 return se::dnn::ToDataType<Eigen::half>::value;
484 case F32:
485 return se::dnn::ToDataType<float>::value;
486 case F64:
487 return se::dnn::ToDataType<double>::value;
488 default:
489 break;
490 }
491 return InternalError("Unsupported convolution datatype");
492 }
493
RequireDeterminism(const HloModuleConfig & config)494 bool RequireDeterminism(const HloModuleConfig& config) {
495 static bool require_cudnn_determinism = [] {
496 // TODO(reedwm): Remove the TF_CUDNN_DETERMINISTIC env var.
497 bool cudnn_deterministic = false;
498 TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
499 /*default_val=*/false,
500 &cudnn_deterministic));
501 return cudnn_deterministic;
502 }();
503 return tensorflow::OpDeterminismRequired() || require_cudnn_determinism ||
504 config.debug_options().xla_gpu_deterministic_ops();
505 }
506
PickBestResult(absl::Span<AutotuneResult const> profile_results,const HloInstruction & instr)507 StatusOr<AutotuneResult> PickBestResult(
508 absl::Span<AutotuneResult const> profile_results,
509 const HloInstruction& instr) {
510 std::vector<AutotuneResult> filtered_results;
511
512 // For now, we ignore WRONG_RESULT failures because false-positives are
513 // possible (e.g. perhaps the reference algorithm is the one that's
514 // incorrect!). But we don't ignore REDZONE_MODIFIED failures because they're
515 // quite severe and can be detected with high accuracy.
516 absl::c_copy_if(
517 profile_results, std::back_inserter(filtered_results),
518 [](const AutotuneResult& r) {
519 return !(r.has_failure() &&
520 r.failure().kind() != AutotuneResult::WRONG_RESULT);
521 });
522
523 if (filtered_results.empty()) {
524 return InternalError(
525 "All algorithms tried for %s failed. Falling back to "
526 "default algorithm. ",
527 instr.ToString());
528 }
529
530 auto selected_result = filtered_results.begin();
531 if (!RequireDeterminism(instr.parent()->parent()->config())) {
532 selected_result = absl::c_min_element(
533 filtered_results,
534 [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
535 return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) <
536 tensorflow::proto_utils::FromDurationProto(rhs.run_time());
537 });
538 }
539 return *selected_result;
540 }
541
542 } // namespace gpu
543 } // namespace xla
544