/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h"

#include <stdlib.h>

#include <fstream>

#include "absl/base/call_once.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.h"
#include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/cuda_libdevice_path.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
#include "tensorflow/stream_executor/gpu/asm_compiler.h"
#include "tensorflow/stream_executor/gpu/gpu_driver.h"

namespace xla {
namespace gpu {

namespace {

namespace tracing = tensorflow::tracing;

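// Returns the list of candidate directories that may contain a CUDA
// installation, honoring the module's xla_gpu_cuda_data_dir debug option.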
static std::vector<std::string> CandidateCudaRoots(
    const HloModuleConfig& config) {
  return tensorflow::CandidateCudaRoots(
      config.debug_options().xla_gpu_cuda_data_dir());
}

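// Logs `msg` along with the directories that were searched for CUDA and a
// hint on how to override the search path.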
void PrintCantFindCudaMessage(absl::string_view msg,
                              const HloModuleConfig& hlo_module_config) {
  LOG(WARNING) << msg;
  LOG(WARNING) << "Searched for CUDA in the following directories:";

  for (const auto& dir : CandidateCudaRoots(hlo_module_config)) {
    LOG(WARNING) << "  " << dir;
  }
  LOG(WARNING)
      << "You can choose the search directory by setting xla_gpu_cuda_data_dir "
         "in HloModule's DebugOptions. For most apps, setting the environment "
         "variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.";
}

// Returns the directory containing nvvm libdevice files.
string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) {
  for (const string& cuda_root : CandidateCudaRoots(hlo_module_config)) {
    string libdevice_dir =
        tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice");
    VLOG(2) << "Looking for libdevice at " << libdevice_dir;
    if (tensorflow::Env::Default()->IsDirectory(libdevice_dir).ok()) {
      VLOG(2) << "Found libdevice dir " << libdevice_dir;
      return libdevice_dir;
    }
  }
  PrintCantFindCudaMessage(
      "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may "
      "result in compilation or runtime failures, if the program we try to run "
      "uses routines from libdevice.",
      hlo_module_config);

  // CandidateCudaRoots always includes ".", but if everything fails, we
  // return it anyway. Better than returning the empty string.
  return ".";
}

}  // namespace

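// Rewrites convolutions (and cuSolver calls) into custom calls and
// canonicalizes them before layout assignment.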
Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization(
    HloModule* hlo_module, se::StreamExecutor* stream_exec,
    se::DeviceMemoryAllocator* device_allocator) {
  // Convert convolutions into CustomCalls to cudnn, then canonicalize them
  // (GpuConvPaddingLegalization). Also expand cuSolver calls.
  HloPassPipeline pipeline("conv_canonicalization");
  pipeline.AddInvariantCheckerDebug<HloVerifier>(
      /*layout_sensitive=*/false,
      /*allow_mixed_precision=*/false);
  pipeline.AddPass<CusolverRewriter>();
  pipeline.AddPass<GpuConvRewriter>();
  pipeline.AddPass<CudnnFusedConvRewriter>();
  pipeline.AddPass<GpuConvPaddingLegalization>();
  pipeline.AddPass<CudnnPadForConvolutions>(IsVoltaOrLater(*stream_exec));
  // CudnnPadForConvolutions leaves behind unnecessary tuple/get-tuple-element
  // pairs that TupleSimplifier fixes.
  pipeline.AddPass<TupleSimplifier>();

  // The tf2xla bridge, DepthwiseConvolutionConverter, and GpuConvRewriter
  // introduce reshapes and transposes that can be eliminated using
  // AlgebraicSimplifier.
  {
    auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
        "algebraic_simplification_post_conv_rewriter");
    pass.AddInvariantCheckerDebug<HloVerifier>(/*layout_sensitive=*/false,
                                               /*allow_mixed_precision=*/false);

    AlgebraicSimplifierOptions options;
    // When transposes appear in a fusion node, we can easily adjust the
    // multi-dimensional index to create the one needed for the operand. This
    // is not as easy with bitcasts, because we don't readily know which
    // dimensions are permuted. In addition, if a transpose and a reshape
    // appear next to each other, they will both be replaced by a bitcast, and
    // bitcast(bitcast) is folded into a single bitcast. This forces us to
    // linearize and then delinearize the index.
    options.set_replace_transpose_with_bitcast(false);
    options.set_enable_conv_operand_swap(false);
    options.set_cudnn_batchnorm_forward_training_metadata(
        kCudnnBatchNormForwardTrainingCallTarget);
    pass.AddPass<AlgebraicSimplifier>(options);
  }

  // GpuConvRewriter, GpuConvPaddingLegalization and CudnnPadForConvolutions
  // may add instructions which can be simplified by constant folding.
  pipeline.AddPass<HloConstantFolding>();
  TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());

  return Status::OK();
}

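// Wraps the common GpuCompiler post-layout-assignment pipeline with
// NVPTX-specific passes: cuBLAS GEMM padding for tensor cores (Volta and
// later) before it, and GEMM algorithm selection after it.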
Status NVPTXCompiler::OptimizeHloPostLayoutAssignment(
    HloModule* hlo_module, se::StreamExecutor* stream_exec,
    se::DeviceMemoryAllocator* device_allocator) {
  HloPassPipeline pre_pipeline("nvptx post-layout_assignment part 1");
  // Pad the dimensions of matrices in dot operations to multiples of 8.
  // This needs to run before GemmRewriter, which is part of
  // OptimizeHloPostLayoutAssignment().
  if (IsVoltaOrLater(*stream_exec)) {
    pre_pipeline.AddPass<CublasGemmPadForTensorCores>();
  }
  TF_RETURN_IF_ERROR(pre_pipeline.Run(hlo_module).status());

  TF_RETURN_IF_ERROR(GpuCompiler::OptimizeHloPostLayoutAssignment(
      hlo_module, stream_exec, device_allocator));

  HloPassPipeline post_pipeline("nvptx post-layout_assignment part 2");

  // Find the fastest algorithm for GEMMs.
  post_pipeline.AddPass<GemmAlgorithmPicker>(stream_exec, device_allocator);
  TF_RETURN_IF_ERROR(post_pipeline.Run(hlo_module).status());

  return Status::OK();
}

namespace {
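// Buffer-sharing hint for HloDataflowAnalysis: the bias operand of a cuBLAS
// GEMM custom call may alias its output, and the operand of a cuSolver
// Cholesky custom call may alias the first element of its output tuple.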
absl::optional<bool> CanShareBufferHint(const HloInstruction* user,
                                        const HloInstruction* operand,
                                        const ShapeIndex& user_index) {
  // Share the bias buffer with the parent instruction.
  if (IsCublasGemm(*user)) {
    if (user->operand_count() == 3 && user->operand(2) == operand) {
      return true;
    }
  }
  // The operand of cholesky can be shared with the first output.
  if (user->opcode() == HloOpcode::kCustomCall &&
      user->custom_call_target() == kCusolverCholeskyCallTarget) {
    return user_index.size() == 1 && user_index[0] == 0;
  }
  return absl::nullopt;
}

// Try to load PTX from files defined in the FLAGS. If successful, return true.
bool MaybeLoadPtxFromFile(const HloModuleConfig& module_config,
                          const HloModule* module, std::string* ptx) {
  // If the xla_gpu_ptx_file option is set, be explicit about when a file is
  // used, and warn when no file is used, to make typos in the filename easier
  // to catch.
  std::string prefix = xla::FilenameFor(*module, "", *ptx);
  std::string matched_filename;
  for (const string& full_filename :
       module_config.debug_options().xla_gpu_ptx_file()) {
    // To make comparing many PTX versions easier, accept suffixes other than
    // the original filename.
    auto filename = tensorflow::io::Basename(full_filename);
    if (absl::StartsWith(filename, prefix)) {
      matched_filename = full_filename;
      VLOG(0) << "RunBackend() - Will load PTX from file: " << full_filename;
      break;
    }
  }
  if (!module_config.debug_options().xla_gpu_ptx_file().empty() &&
      matched_filename.empty()) {
    VLOG(0) << "RunBackend() - For module with prefix '" << prefix
            << "', we did not find a PTX file to load.";
  }

  if (!matched_filename.empty()) {
    std::ifstream ifs(matched_filename, std::ifstream::in);
    *ptx = std::string(std::istreambuf_iterator<char>(ifs),
                       std::istreambuf_iterator<char>());
    CHECK(!ptx->empty()) << "Empty or non-existent PTX file: "
                         << matched_filename;
    return true;
  }
  return false;
}

}  // namespace

// Prints a warning if the ptx->sass JIT in the driver has known bugs.
//
// Using such a driver is only a problem if we fail to use ptxas to compile our
// ptx and have to use the driver instead, so you should only call this
// function if we're going to use the driver JIT.
//
// Only prints a warning the first time it's called.
void WarnIfBadDriverJITVersion() {
  static absl::once_flag run_once;
  absl::call_once(run_once, [] {
    auto version_or_status = se::cuda::Diagnostician::FindKernelDriverVersion();
    if (!version_or_status.ok()) {
      LOG(WARNING) << "Couldn't read CUDA driver version.";
      return;
    }
    se::cuda::DriverVersion version = version_or_status.ValueOrDie();

    // The following versions of the driver JIT miscompile some address
    // calculations with large offsets (e.g. "load ptr + large_constant"),
    // b/70245379:
    //
    //  - 384.x before 384.108
    //  - 387.x before 387.40
    //  - 390.x before 390.10.
    //
    // In addition, only >= 396.20 contains ptxas >= 9.2.88, which contains the
    // fix for the "large multioutput fusions" miscompile, b/111107644.
    if (version < std::make_tuple(396, 20, 0)) {
      LOG(WARNING)
          << "*** WARNING *** Invoking the PTX->SASS JIT from driver version "
          << se::cuda::DriverVersionToString(version)
          << ", which is older than 396.20.0. These versions are known to "
             "miscompile XLA code, leading to incorrect results or "
             "invalid-address errors.\nXLA only uses the driver JIT if it "
             "cannot find ptxas; you don't need to update your driver if "
             "you can point XLA to ptxas 9.2.88 or newer.";
    }
  });
}

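// NVPTXCompiler targets CUDA devices: it configures the base GpuCompiler with
// the CUDA platform id and the NVPTX target triple and data layout used for
// LLVM code generation.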
NVPTXCompiler::NVPTXCompiler()
    : GpuCompiler(stream_executor::cuda::kCudaPlatformId, nvptx::kTargetTriple,
                  nvptx::kDataLayout) {}

HloDataflowAnalysis::CanShareBuffer NVPTXCompiler::GetCanShareBuffer() {
  return &CanShareBufferHint;
}

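// Returns the device's CUDA compute capability as a (major, minor) pair,
// defaulting to sm_20 if it cannot be queried.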
GpuVersion NVPTXCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
  int cc_major, cc_minor;
  if (!stream_exec->GetDeviceDescription().cuda_compute_capability(&cc_major,
                                                                   &cc_minor)) {
    LOG(WARNING)
        << "Couldn't get compute capability for device; assuming sm_20.";
    cc_major = 2;
    cc_minor = 0;
  }

  return std::make_pair(cc_major, cc_minor);
}

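// Lowers the LLVM module to PTX (or loads PTX from a file given via
// --xla_gpu_ptx_file, for debugging) and assembles it into a cubin, returning
// the (PTX, cubin) pair.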
StatusOr<std::pair<std::string, std::vector<uint8>>>
NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
                                   llvm::Module* llvm_module,
                                   GpuVersion gpu_version,
                                   se::StreamExecutor* stream_exec,
                                   bool relocatable,
                                   const HloModule* debug_module) {
  std::pair<int, int> compute_capability =
      absl::get<std::pair<int, int>>(gpu_version);

  std::string libdevice_dir;
  {
    tensorflow::mutex_lock lock(mutex_);

    // Find the directory containing libdevice. To avoid searching for it every
    // time, we have a one-element cache, keyed on the module's config's
    // cuda_data_dir.
    if (cached_libdevice_dir_.empty()) {
      cached_libdevice_dir_ = GetLibdeviceDir(module_config);
    }
    libdevice_dir = cached_libdevice_dir_;
  }
  VLOG(2) << "Libdevice dir = " << libdevice_dir << "\n";

  string ptx;
  if (!(debug_module &&
        MaybeLoadPtxFromFile(module_config, debug_module, &ptx))) {
    XLA_SCOPED_LOGGING_TIMER(
        "NVPTXCompiler::CompileTargetBinary - CompileToPtx");
    TF_ASSIGN_OR_RETURN(ptx, nvptx::CompileToPtx(llvm_module, gpu_version,
                                                 module_config, libdevice_dir));
  }

  std::vector<uint8> cubin = CompileGpuAsmOrGetCachedResult(
      stream_exec, ptx, compute_capability.first, compute_capability.second,
      module_config, relocatable);

  return std::pair<std::string, std::vector<uint8>>(std::move(ptx),
                                                    std::move(cubin));
}

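// Assembles `ptx` into a cubin with ptxas, memoizing the result in
// compilation_cache_ so each (ptx, compute capability, relocatable) key is
// compiled only once; concurrent callers with the same key block until the
// first compilation finishes. If ptxas is missing and the unsafe fallback
// flag is set, returns an empty cubin so the driver JIT is used instead.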
std::vector<uint8> NVPTXCompiler::CompileGpuAsmOrGetCachedResult(
    se::StreamExecutor* stream_exec, const string& ptx, int cc_major,
    int cc_minor, const HloModuleConfig& hlo_module_config, bool relocatable) {
  XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::CompileGpuAsmOrGetCachedResult");
  tensorflow::profiler::TraceMe activity(
      "PTX->CUBIN", tensorflow::profiler::TraceMeLevel::kInfo);
  bool inserted;
  decltype(compilation_cache_.begin()) iter;
  // Pointers into compilation_cache_ where the ptx and (optional) cubin are
  // stored.
  const string* cache_ptx = nullptr;
  CompilationCacheValue* cache_value = nullptr;

  {
    tensorflow::mutex_lock lock(mutex_);
    std::tie(iter, inserted) = compilation_cache_.emplace(
        std::piecewise_construct,
        std::forward_as_tuple(ptx, cc_major, cc_minor, relocatable),
        std::forward_as_tuple());
    cache_ptx = &iter->first.ptx;
    cache_value = &iter->second;
  }

  // Compile the ptx if it wasn't in the cache before we called this function.
  // Other threads asking for the same compilation key will block on
  // cache_value->mutex_ until compilation is done.
  {
    tensorflow::mutex_lock lock(cache_value->mutex_);
    if (inserted) {
      CHECK(!cache_value->compilation_done);
      if (!ptx.empty()) {
        auto ptxas_config = PtxOptsFromConfig(hlo_module_config);
        if (relocatable) {
          ptxas_config.extra_flags.push_back("-c");
        }
        StatusOr<std::vector<uint8>> maybe_cubin = se::CompileGpuAsm(
            stream_exec->device_ordinal(), cache_ptx->c_str(), ptxas_config);

        if (maybe_cubin.ok()) {
          cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie();
          VLOG(2) << "Compiled PTX size: " << ptx.size()
                  << " CUBIN size: " << cache_value->cubin_data.size();
        } else {
          if (maybe_cubin.status().code() ==
              tensorflow::error::Code::NOT_FOUND) {
            if (!hlo_module_config.debug_options()
                     .xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found()) {
              PrintCantFindCudaMessage(
                  "Can't find ptxas binary in ${CUDA_DIR}/bin. Custom ptxas "
                  "location can be specified using $PATH.",
                  hlo_module_config);
              LOG(FATAL)
                  << "Can't find ptxas binary. You can pass the flag "
                     "--xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found "
                     "to use the GPU driver for compiling ptx instead. "
                     "However, this option is discouraged and can lead to "
                     "increased memory consumption and other subtle runtime "
                     "issues.";
            }
            // Missing ptxas is expected in some environments where CUDA SDK
            // binaries are not available. We don't want to spam logs with
            // identical warnings in this case.

            // TODO(jlebar): we should implement a LOG_FIRST_N and LOG_EVERY_N
            // for more general usage.
            static std::atomic<bool> warning_done(false);
            bool log_warning = !warning_done.exchange(true);
            if (log_warning) {
              PrintCantFindCudaMessage(
                  "Can't find ptxas binary in ${CUDA_DIR}/bin. Will fall back "
                  "to the GPU driver for PTX -> sass compilation. This is OK "
                  "so long as you don't see a warning below about an "
                  "out-of-date driver version. Custom ptxas location can be "
                  "specified using $PATH.",
                  hlo_module_config);
            }
          } else if (maybe_cubin.status().code() !=
                     tensorflow::error::Code::UNIMPLEMENTED) {
            // If UNIMPLEMENTED is returned, we fall back to the driver.
            LOG(FATAL) << "ptxas returned an error during compilation of ptx "
                          "to sass: '"
                       << maybe_cubin.status() << "' "
                       << "If the error message indicates that a file could "
                          "not be written, please verify that sufficient "
                          "filesystem space is provided.";
          }

          // We're going to use the driver to JIT our PTX->SASS, so warn if
          // the JIT in the driver has known bugs.
          WarnIfBadDriverJITVersion();
        }
      }
      cache_value->compilation_done = true;
      cache_value->compilation_done_cv_.notify_all();
    } else {
      while (!cache_value->compilation_done) {
        cache_value->compilation_done_cv_.wait(lock);
      }
    }
  }

  CHECK(cache_value != nullptr);
  CHECK(cache_value->compilation_done);
  return cache_value->cubin_data;
}

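// Links separately compiled (relocatable) cubins into a single GPU binary
// using the stream_executor LinkGpuAsm helper on the executor's GPU context.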
StatusOr<std::vector<uint8>> NVPTXCompiler::LinkModules(
    se::StreamExecutor* stream_exec, std::vector<std::vector<uint8>> modules) {
  std::vector<stream_executor::CubinOrPTXImage> images;
  images.reserve(modules.size());
  for (auto& module : modules) {
    images.push_back({"", std::move(module)});
  }
  return LinkGpuAsm(static_cast<se::gpu::GpuContext*>(
                        stream_exec->implementation()->GpuContextHack()),
                    images);
}

}  // namespace gpu
}  // namespace xla