• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/gpu_utils.h"
17 
18 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
19 
20 #include <iterator>
21 
22 #include "google/protobuf/any.pb.h"
23 #include "absl/algorithm/container.h"
24 #include "absl/base/call_once.h"
25 #include "tensorflow/core/platform/logger.h"
26 #include "tensorflow/core/protobuf/autotuning.pb.h"
27 #include "tensorflow/core/protobuf/conv_autotuning.pb.h"
28 #include "tensorflow/core/util/env_var.h"
29 #include "tensorflow/core/util/proto/proto_utils.h"
30 #include "tensorflow/stream_executor/gpu/asm_compiler.h"
31 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
32 
33 namespace tensorflow {
34 
// Reports whether redzone checking has been switched off through the
// environment. Only the exact value "1" in TF_DISABLE_RZ_CHECK disables the
// check; an unset variable or any other value leaves it enabled. The
// environment is consulted on every call.
bool RedzoneCheckDisabled() {
  const char* env_value = std::getenv("TF_DISABLE_RZ_CHECK");
  if (env_value == nullptr) {
    return false;
  }
  return std::strcmp(env_value, "1") == 0;
}
39 
WrapRedzoneBestEffort(se::RedzoneAllocator * rz_allocator,se::DeviceMemoryBase buffer)40 se::DeviceMemoryBase WrapRedzoneBestEffort(se::RedzoneAllocator* rz_allocator,
41                                            se::DeviceMemoryBase buffer) {
42   if (RedzoneCheckDisabled()) {
43     return buffer;
44   }
45   auto output_rz_or = rz_allocator->AllocateBytes(buffer.size());
46   if (!output_rz_or.ok()) {
47     static absl::once_flag rz_allocation_failure_logged;
48     absl::call_once(rz_allocation_failure_logged, []() {
49       LOG(WARNING) << "Failed to allocate memory for convolution redzone "
50                    << "checking; skipping this check. This is benign and only "
51                    << "means that we won't check cudnn for out-of-bounds reads "
52                    << "and writes. This message will only be printed once.";
53     });
54     return buffer;
55   }
56   return se::DeviceMemoryBase(output_rz_or.ValueOrDie());
57 }
58 
CheckRedzones(const se::RedzoneAllocator & rz_allocator,tensorflow::AutotuneResult * autotune_result)59 void CheckRedzones(const se::RedzoneAllocator& rz_allocator,
60                    tensorflow::AutotuneResult* autotune_result) {
61   if (RedzoneCheckDisabled()) {
62     return;
63   }
64   se::port::StatusOr<se::RedzoneAllocator::RedzoneCheckStatus> rz_status =
65       rz_allocator.CheckRedzones();
66   if (!rz_status.ok()) {
67     static absl::once_flag failure_logged;
68     absl::call_once(failure_logged, [&]() {
69       LOG(WARNING) << "Failed to check cudnn convolutions for out-of-bounds "
70                    << "reads and writes with an error message: '"
71                    << rz_status.status().error_message()
72                    << "'; skipping this check. This only means that we won't "
73                    << "check cudnn for out-of-bounds reads and writes. This "
74                    << "message will only be printed once.";
75     });
76     return;
77   }
78   auto rz_check_status = rz_status.ValueOrDie();
79   if (!rz_check_status.ok()) {
80     auto* fail = autotune_result->mutable_failure();
81     fail->set_msg(rz_check_status.RedzoneFailureMsg());
82     fail->set_kind(AutotuneResult::REDZONE_MODIFIED);
83     fail->set_buffer_address(
84         reinterpret_cast<uint64>(rz_check_status.user_buffer_address));
85     LOG(ERROR)
86         << "Detected cudnn out-of-bounds write in convolution buffer! This is "
87            "likely a cudnn bug. We will skip this algorithm in the future, but "
88            "your GPU state may already be corrupted, leading to incorrect "
89            "results. Within Google, no action is needed on your part. Outside "
90            "of Google, please ensure you're running the latest version of "
91            "cudnn. If that doesn't fix the problem, please file a bug with "
92            "this full error message and we'll contact nvidia.";
93     LOG(ERROR) << rz_check_status.RedzoneFailureMsg();
94   }
95 }
96 
97 namespace {
98 
GetCudnnVersion(se::StreamExecutor * stream_executor)99 tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
100   tensorflow::CudnnVersion cudnn_version;
101   if (auto* dnn = stream_executor->AsDnn()) {
102     se::port::StatusOr<se::dnn::VersionInfo> version_or = dnn->GetVersion();
103     if (version_or.ok()) {
104       const auto& version = version_or.ValueOrDie();
105       cudnn_version.set_major(version.major_version());
106       cudnn_version.set_minor(version.minor_version());
107       cudnn_version.set_patch(version.patch());
108     }
109   }
110   return cudnn_version;
111 }
112 
GetComputeCapability(se::StreamExecutor * stream_executor)113 tensorflow::ComputeCapability GetComputeCapability(
114     se::StreamExecutor* stream_executor) {
115   tensorflow::ComputeCapability cc;
116   int cc_major, cc_minor;
117   stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
118                                                                   &cc_minor);
119   cc.set_major(cc_major);
120   cc.set_minor(cc_minor);
121   return cc;
122 }
123 
124 }  // namespace
125 
LogConvAutotuneResults(se::dnn::ConvolutionKind kind,se::dnn::DataType element_type,se::DeviceMemoryBase input_buffer,se::DeviceMemoryBase filter_buffer,se::DeviceMemoryBase output_buffer,const se::dnn::BatchDescriptor & input_desc,const se::dnn::FilterDescriptor & filter_desc,const se::dnn::BatchDescriptor & output_desc,const se::dnn::ConvolutionDescriptor & conv_desc,se::StreamExecutor * stream_exec,absl::Span<const AutotuneResult> results)126 void LogConvAutotuneResults(se::dnn::ConvolutionKind kind,
127                             se::dnn::DataType element_type,
128                             se::DeviceMemoryBase input_buffer,
129                             se::DeviceMemoryBase filter_buffer,
130                             se::DeviceMemoryBase output_buffer,
131                             const se::dnn::BatchDescriptor& input_desc,
132                             const se::dnn::FilterDescriptor& filter_desc,
133                             const se::dnn::BatchDescriptor& output_desc,
134                             const se::dnn::ConvolutionDescriptor& conv_desc,
135                             se::StreamExecutor* stream_exec,
136                             absl::Span<const AutotuneResult> results) {
137   AutotuningLog log;
138   {
139     ConvolutionProto instr;
140     instr.set_kind(kind);
141     *instr.mutable_input() = input_desc.ToProto(element_type);
142     *instr.mutable_filter() = filter_desc.ToProto(element_type);
143     *instr.mutable_output() = output_desc.ToProto(element_type);
144     *instr.mutable_conv_desc() = conv_desc.ToProto();
145     instr.set_conv_scale(1);
146     instr.set_side_value_scale(0);
147     instr.set_input_address(reinterpret_cast<uint64>(input_buffer.opaque()));
148     instr.set_filter_address(reinterpret_cast<uint64>(filter_buffer.opaque()));
149     instr.set_output_address(reinterpret_cast<uint64>(output_buffer.opaque()));
150     log.mutable_instr()->PackFrom(std::move(instr));
151   }
152   *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
153   *log.mutable_compute_capability() = GetComputeCapability(stream_exec);
154   log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id());
155   {
156     string blas_version;
157     if (auto* blas = stream_exec->AsBlas()) {
158       if (blas->GetVersion(&blas_version).ok()) {
159         log.set_blas_version(blas_version);
160       }
161     }
162   }
163   for (const auto& result : results) {
164     *log.add_results() = result;
165   }
166   Logger::GetSingleton()->LogProto(log);
167 }
168 
LogFusedConvForwardAutotuneResults(se::dnn::DataType element_type,se::DeviceMemoryBase input_buffer,se::DeviceMemoryBase filter_buffer,se::DeviceMemoryBase output_buffer,se::DeviceMemoryBase bias_buffer,se::DeviceMemoryBase side_input_buffer,const se::dnn::BatchDescriptor & input_desc,const se::dnn::FilterDescriptor & filter_desc,const se::dnn::BatchDescriptor & output_desc,const se::dnn::ConvolutionDescriptor & conv_desc,double conv_scale,double side_value_scale,se::dnn::ActivationMode activation_mode,se::StreamExecutor * stream_exec,absl::Span<const AutotuneResult> results)169 void LogFusedConvForwardAutotuneResults(
170     se::dnn::DataType element_type, se::DeviceMemoryBase input_buffer,
171     se::DeviceMemoryBase filter_buffer, se::DeviceMemoryBase output_buffer,
172     se::DeviceMemoryBase bias_buffer, se::DeviceMemoryBase side_input_buffer,
173     const se::dnn::BatchDescriptor& input_desc,
174     const se::dnn::FilterDescriptor& filter_desc,
175     const se::dnn::BatchDescriptor& output_desc,
176     const se::dnn::ConvolutionDescriptor& conv_desc, double conv_scale,
177     double side_value_scale, se::dnn::ActivationMode activation_mode,
178     se::StreamExecutor* stream_exec, absl::Span<const AutotuneResult> results) {
179   AutotuningLog log;
180   {
181     ConvolutionProto instr;
182     instr.set_kind(se::dnn::ConvolutionKind::FORWARD_BIAS_ACTIVATION);
183     *instr.mutable_input() = input_desc.ToProto(element_type);
184     *instr.mutable_filter() = filter_desc.ToProto(element_type);
185     *instr.mutable_output() = output_desc.ToProto(element_type);
186     *instr.mutable_conv_desc() = conv_desc.ToProto();
187     instr.set_conv_scale(conv_scale);
188     instr.set_side_value_scale(side_value_scale);
189     instr.set_activation(activation_mode);
190     instr.set_input_address(reinterpret_cast<uint64>(input_buffer.opaque()));
191     instr.set_filter_address(reinterpret_cast<uint64>(filter_buffer.opaque()));
192     instr.set_output_address(reinterpret_cast<uint64>(output_buffer.opaque()));
193     instr.set_bias_address(reinterpret_cast<uint64>(bias_buffer.opaque()));
194     instr.set_side_input_address(
195         reinterpret_cast<uint64>(side_input_buffer.opaque()));
196     log.mutable_instr()->PackFrom(std::move(instr));
197   }
198   *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
199   *log.mutable_compute_capability() = GetComputeCapability(stream_exec);
200   log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id());
201   {
202     string blas_version;
203     if (auto* blas = stream_exec->AsBlas()) {
204       if (blas->GetVersion(&blas_version).ok()) {
205         log.set_blas_version(blas_version);
206       }
207     }
208   }
209   for (const auto& result : results) {
210     *log.add_results() = result;
211   }
212   Logger::GetSingleton()->LogProto(log);
213 }
214 
215 // The following function allows deterministic ops to be implemented relatively
216 // quickly using environment variables. It is intended to be temporary. The
217 // longer-term intention is to enable deterministic ops via tf.config and
218 // appropriate plumbing. See the discussion on PR 34951 for more information:
219 // https://github.com/tensorflow/tensorflow/pull/34951#discussion_r355682316
220 // This function and associated comment are replicated in the following three
221 // places:
222 //   1. tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
223 //   2. tensorflow/core/kernels/gpu_utils.cc
224 //   3. tensorflow/stream_executor/cuda/cuda_dnn.cc
225 // When implementing the plumbing, you should also search for the use of
226 // TF_DETERMINISTIC_OPS on its own.
227 // TODO(duncanriach): move to an API that uses tf.config and implement the first
228 //                    phase of plumbing.
RequireCudnnDeterminism()229 bool RequireCudnnDeterminism() {
230   static bool require_cudnn_determinism = [] {
231     bool deterministic_ops = false;
232     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
233                                                /*default_val=*/false,
234                                                &deterministic_ops));
235     bool cudnn_deterministic = false;
236     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC",
237                                                /*default_val=*/false,
238                                                &cudnn_deterministic));
239     return deterministic_ops || cudnn_deterministic;
240   }();
241   return require_cudnn_determinism;
242 }
243 
BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,se::dnn::AlgorithmConfig * algo)244 Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
245                               se::dnn::AlgorithmConfig* algo) {
246   std::vector<AutotuneResult> filtered_results;
247   absl::c_copy_if(
248       results, std::back_inserter(filtered_results),
249       [](const AutotuneResult& result) { return !result.has_failure(); });
250   if (filtered_results.empty()) {
251     return errors::NotFound("No algorithm worked!");
252   }
253   std::vector<AutotuneResult> filtered_results_no_scratch;
254   absl::c_copy_if(
255       filtered_results, std::back_inserter(filtered_results_no_scratch),
256       [](const AutotuneResult& result) { return result.scratch_bytes() == 0; });
257 
258   auto selected_result = filtered_results.begin();
259   auto selected_result_no_scratch = filtered_results_no_scratch.begin();
260   if (!RequireCudnnDeterminism()) {
261     auto compare_run_times = [](const AutotuneResult& lhs,
262                                 const AutotuneResult& rhs) {
263       return proto_utils::FromDurationProto(lhs.run_time()) <
264              proto_utils::FromDurationProto(rhs.run_time());
265     };
266     selected_result = absl::c_min_element(filtered_results, compare_run_times);
267     selected_result_no_scratch =
268         absl::c_min_element(filtered_results_no_scratch, compare_run_times);
269   }
270 
271   algo->set_algorithm({selected_result->conv().algorithm(),
272                        selected_result->conv().tensor_ops_enabled()});
273   if (selected_result_no_scratch != filtered_results_no_scratch.end()) {
274     algo->set_algorithm_no_scratch(
275         {selected_result_no_scratch->conv().algorithm(),
276          selected_result_no_scratch->conv().tensor_ops_enabled()});
277   }
278 
279   return Status::OK();
280 }
281 
282 }  // namespace tensorflow
283 
284 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
285