/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"

#include <fstream>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/InitializePasses.h"
#include "llvm/Linker/Linker.h"
#include "llvm/PassRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FormattedStream.h"
#include "llvm/Support/Program.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"
#include "llvm/Transforms/IPO/Internalize.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Scalar.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/env_var.h"

#if !defined(PLATFORM_GOOGLE) && TENSORFLOW_USE_ROCM
#include "rocm/rocm_config.h"
#endif

namespace xla {
namespace gpu {
namespace {

static llvm::codegen::RegisterCodeGenFlags CGF;

// Inline threshold value to use in LLVM AMDGPU backend.
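// (0x100000 is 1,048,576, far above any realistic function cost, so in
// practice nearly every eligible call site gets inlined.)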
const int kAMDGPUInlineThreshold = 0x100000;

// Default inline threshold value to use in llvm.
const int kDefaultInlineThreshold = 1100;

// Gets the GPU name as it's known to LLVM for a given compute
// capability.  If we see an unrecognized compute capability, we
// return the highest one that is known and below the selected device.
static string GetSmName(se::CudaComputeCapability compute_capability) {
  int compute_capability_version =
      compute_capability.major * 10 + compute_capability.minor;
  int sm_version = 30;
  // If the current compute capability isn't known, fall back to the
  // most recent version before it.
  int supported_versions[] = {75, 72, 70, 62, 61, 60, 53,
                              52, 50, 37, 35, 32, 30};
  for (int v : supported_versions) {
    if (v <= compute_capability_version) {
      sm_version = v;
      break;
    }
  }
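  // For example, an sm_86 device (compute capability 8.6) is not in the list,
  // so the loop above picks 75 and we compile PTX for sm_75.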

  // If the current CC isn't supported by LLVM and it is newer than
  // the max supported LLVM version, do not warn about it. The end
  // user can't do anything about this. PTX compiled for SM75 will
  // run on SM80 too.
  if (sm_version != compute_capability_version &&
      compute_capability_version < supported_versions[0]) {
    LOG(WARNING) << "Unknown compute capability "
                 << compute_capability.ToString()
                 << ". Defaulting to telling LLVM that we're compiling for sm_"
                 << sm_version;
  }
  return absl::StrCat("sm_", sm_version);
}

// Convenience function for producing a name of a temporary compilation product
// from the input filename.
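// For illustration, assuming ReplaceFilenameExtension swaps the trailing
// extension: MakeNameForTempProduct("/tmp/foo.ll", "ptx") yields "foo.ptx".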
string MakeNameForTempProduct(absl::string_view input_filename,
                              absl::string_view extension) {
  return ReplaceFilenameExtension(tensorflow::io::Basename(input_filename),
                                  extension);
}

// Initializes LLVM passes. Uses the PassRegistry mechanism.
void InitializePasses(llvm::PassRegistry* pass_registry) {
  llvm::initializeCore(*pass_registry);
  llvm::initializeCodeGen(*pass_registry);
  llvm::initializeScalarOpts(*pass_registry);
  llvm::initializeObjCARCOpts(*pass_registry);
  llvm::initializeVectorization(*pass_registry);
  llvm::initializeIPO(*pass_registry);
  llvm::initializeAnalysis(*pass_registry);
  llvm::initializeTransformUtils(*pass_registry);
  llvm::initializeInstCombine(*pass_registry);
  llvm::initializeInstrumentation(*pass_registry);
  llvm::initializeTarget(*pass_registry);
  llvm::initializeCodeGenPreparePass(*pass_registry);
}

// Returns the TargetMachine, given a triple.
std::unique_ptr<llvm::TargetMachine> GetTargetMachine(
    llvm::Triple triple, absl::string_view cpu_name,
    const HloModuleConfig& hlo_module_config, absl::string_view feature_str) {
  std::string error;
  const llvm::Target* target =
      llvm::TargetRegistry::lookupTarget("", triple, error);
  if (target == nullptr) {
    LOG(FATAL) << "Unable to find Target for triple '" << triple.str() << "'"
               << " -- " << error;
    return nullptr;
  }

  llvm::TargetOptions target_options =
      llvm::codegen::InitTargetOptionsFromCodeGenFlags(llvm::Triple());

  // Set the verbose assembly options.
  target_options.MCOptions.AsmVerbose = false;

  // The selection of codegen optimization level is copied from function
  // GetCodeGenOptLevel in //third_party/llvm/llvm/tools/opt/opt.cpp.
  llvm::CodeGenOpt::Level codegen_opt_level;
  switch (hlo_module_config.debug_options().xla_backend_optimization_level()) {
    case 1:
      codegen_opt_level = llvm::CodeGenOpt::Less;
      break;
    case 2:
      codegen_opt_level = llvm::CodeGenOpt::Default;
      break;
    case 3:
      codegen_opt_level = llvm::CodeGenOpt::Aggressive;
      break;
    default:
      codegen_opt_level = llvm::CodeGenOpt::None;
  }
  return absl::WrapUnique(target->createTargetMachine(
      triple.str(), llvm_ir::AsStringRef(cpu_name),
      llvm_ir::AsStringRef(feature_str), target_options,
      llvm::codegen::getExplicitRelocModel(),
      llvm::codegen::getExplicitCodeModel(), codegen_opt_level));
}

// Adds the standard LLVM optimization passes, based on the speed optimization
// level (opt_level) and size optimization level (size_level). Both module
// and function-level passes are added, so two pass managers are passed in and
// modified by this function.
void AddOptimizationPasses(unsigned opt_level, unsigned size_level,
                           llvm::TargetMachine* target_machine,
                           llvm::legacy::PassManagerBase* module_passes,
                           llvm::legacy::FunctionPassManager* function_passes,
                           int inline_threshold) {
  llvm::PassManagerBuilder builder;
  builder.OptLevel = opt_level;
  builder.SizeLevel = size_level;

  if (opt_level > 1) {
    builder.Inliner = llvm::createFunctionInliningPass(inline_threshold);
  } else {
    // Only inline functions marked with "alwaysinline".
    builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
  }

  builder.DisableUnrollLoops = opt_level == 0;
  builder.LoopVectorize = opt_level > 0;
  builder.SLPVectorize = opt_level > 1 && size_level < 2;

  // NVPTX's early-as-possible passes include NVVM reflect.
  target_machine->adjustPassManager(builder);

  builder.populateFunctionPassManager(*function_passes);
  builder.populateModulePassManager(*module_passes);
}

// Emits the given module to a bit code file.
void EmitBitcodeToFile(const llvm::Module& module, absl::string_view filename) {
  std::error_code error_code;
  llvm::ToolOutputFile outfile(string(filename).c_str(), error_code,
                               llvm::sys::fs::OF_None);
  if (error_code) {
    LOG(FATAL) << "opening bitcode file for writing: " << error_code.message();
  }

  llvm::WriteBitcodeToFile(module, outfile.os());
  outfile.keep();
}

// Emits the given module to PTX. target_machine is an initialized TargetMachine
// for the NVPTX target.
string EmitModuleToPTX(llvm::Module* module,
                       llvm::TargetMachine* target_machine) {
  std::string ptx;
  {
    llvm::raw_string_ostream stream(ptx);
    llvm::buffer_ostream pstream(stream);
    // The extension is stripped by IrDumpingPassManager, so we need to
    // get creative to add a suffix.
    IrDumpingPassManager codegen_passes(
        MakeNameForTempProduct(module->getModuleIdentifier(), "-nvptx.dummy"),
        "", false);
    codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass(
        llvm::Triple(module->getTargetTriple())));

    target_machine->addPassesToEmitFile(codegen_passes, pstream, nullptr,
                                        llvm::CGFT_AssemblyFile);
    codegen_passes.run(*module);
  }

  return ptx;
}

// LLVM has an extensive flags mechanism of its own, which is only accessible
// through the command line. Internal libraries within LLVM register parsers for
// flags, with no other way to configure them except to pass these flags.
// To do this programmatically, we invoke ParseCommandLineOptions manually with
// a "fake argv".
// Note: setting flags with this method is stateful, since flags are just
// static globals within LLVM libraries.
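// For illustration: FeedLLVMWithFlags({"-unroll-threshold=100"}) sets LLVM's
// loop-unrolling threshold process-wide, and the value persists until a later
// call overwrites it, precisely because it lives in a static LLVM global.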
void FeedLLVMWithFlags(const std::vector<string>& cl_opts) {
  std::vector<const char*> fake_argv = {""};
  for (const string& cl_opt : cl_opts) {
    fake_argv.push_back(cl_opt.c_str());
  }
  llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]);
}

// Returns whether the module could use any device bitcode library functions.
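// e.g. a declaration of __nv_expf with no definition is a call the module
// expects libdevice (or, for the __ocml_/__ockl_ prefixes, ROCm-Device-Libs)
// to provide.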
bool CouldNeedDeviceBitcode(const llvm::Module& module) {
  for (const llvm::Function& function : module.functions()) {
    // The list of prefixes should be in sync with library functions used in
    // target_util.cc.
    if (!function.isIntrinsic() && function.isDeclaration() &&
        (function.getName().startswith("__nv_") ||
         function.getName().startswith("__ocml_") ||
         function.getName().startswith("__ockl_"))) {
      return true;
    }
  }
  return false;
}

// Links the module with a vector of paths to bitcode modules.
// The caller must guarantee that the paths exist.
Status LinkWithBitcodeVector(llvm::Module* module,
                             const std::vector<string>& bitcode_path_vector) {
  llvm::Linker linker(*module);

  for (auto& bitcode_path : bitcode_path_vector) {
    if (!tensorflow::Env::Default()->FileExists(bitcode_path).ok()) {
      LOG(ERROR) << "bitcode module is required by this HLO module but was "
                    "not found at "
                 << bitcode_path;
      return xla::InternalError("bitcode module not found at %s", bitcode_path);
    }

    std::unique_ptr<llvm::Module> bitcode_module =
        LoadIRModule(bitcode_path, &module->getContext());
    // Ignore the data layout of the module we're importing. This avoids a
    // warning from the linker.
    bitcode_module->setDataLayout(module->getDataLayout());
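    // LinkOnlyNeeded pulls in just the definitions this module actually
    // references; the callback below then internalizes everything else the
    // library defines so later passes can dead-strip it.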
    if (linker.linkInModule(
            std::move(bitcode_module), llvm::Linker::Flags::LinkOnlyNeeded,
            [](llvm::Module& M, const llvm::StringSet<>& GVS) {
              internalizeModule(M, [&GVS](const llvm::GlobalValue& GV) {
                return !GV.hasName() || (GVS.count(GV.getName()) == 0);
              });
            })) {
      return xla::InternalError("Error linking bitcode module from %s",
                                bitcode_path);
    }
  }
  return Status::OK();
}

// Links libdevice into the given module if the module needs libdevice.
Status LinkLibdeviceIfNecessary(llvm::Module* module,
                                const string& libdevice_dir_path) {
  if (!CouldNeedDeviceBitcode(*module)) {
    return Status::OK();
  }

  // CUDA 9+ uses a single libdevice file for all devices, and we don't support
  // older CUDAs.
  string libdevice_path =
      tensorflow::io::JoinPath(libdevice_dir_path, "libdevice.10.bc");
  if (!tensorflow::Env::Default()->FileExists(libdevice_path).ok()) {
    LOG(WARNING)
        << "libdevice is required by this HLO module but was not found at "
        << libdevice_path;
    return xla::InternalError("libdevice not found at %s", libdevice_path);
  }

  VLOG(1) << "Linking with libdevice from: " << libdevice_path;
  return LinkWithBitcodeVector(module, {libdevice_path});
}

Status NVPTXTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version,
                               const HloModuleConfig& hlo_module_config,
                               const string& device_bitcode_dir_path) {
  // Link the input module with libdevice, to pull in implementations of some
  // builtins.
  TF_RETURN_IF_ERROR(LinkLibdeviceIfNecessary(module, device_bitcode_dir_path));

  // Set the flush-denormals-to-zero flag on the module so the NVVM reflect pass
  // can access it.
  module->addModuleFlag(llvm::Module::Override, "nvvm-reflect-ftz",
                        hlo_module_config.debug_options().xla_gpu_ftz());

  // If ftz is enabled, set it as an attribute on every function in the module.
  if (hlo_module_config.debug_options().xla_gpu_ftz()) {
    for (llvm::Function& fn : *module) {
      fn.addFnAttr("denormal-fp-math-f32", "preserve-sign");
    }
  }

  return Status::OK();
}

std::unique_ptr<llvm::TargetMachine> NVPTXGetTargetMachine(
    llvm::Triple target_triple, se::CudaComputeCapability compute_capability,
    const HloModuleConfig& hlo_module_config) {
  // Figure out the exact name of the processor as known to the NVPTX backend
  // from the gpu_architecture flag.
  return GetTargetMachine(target_triple, GetSmName(compute_capability),
                          hlo_module_config, "+ptx60");
}

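// Callback that links target-specific device bitcode (libdevice for NVPTX,
// ROCm-Device-Libs for AMDGPU) into the module before it is optimized.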
using TargetModuleLinker = std::function<Status(
    llvm::Module*, GpuVersion, const HloModuleConfig&, const string&)>;

Status LinkAndOptimizeModule(llvm::Module* module, GpuVersion gpu_version,
                             const HloModuleConfig& hlo_module_config,
                             const string& device_bitcode_dir_path,
                             TargetModuleLinker module_linker,
                             llvm::Triple default_target_triple,
                             llvm::TargetMachine* target_machine,
                             int inline_threshold) {
  TF_RETURN_IF_ERROR(module_linker(module, gpu_version, hlo_module_config,
                                   device_bitcode_dir_path));

  bool dump_ir = hlo_module_config.debug_options().xla_gpu_dump_llvmir();
  std::string outputs_dir;
  tensorflow::io::GetTestUndeclaredOutputsDir(&outputs_dir);
  IrDumpingPassManager module_passes(module->getModuleIdentifier(), outputs_dir,
                                     dump_ir);

  // Add an appropriate TargetLibraryInfo pass for the module's triple.
  llvm::TargetLibraryInfoWrapperPass* tliwp =
      new llvm::TargetLibraryInfoWrapperPass(
          llvm::Triple(module->getTargetTriple()));
  module_passes.add(tliwp);

  // Try to fetch the target triple from the module. If not present, set a
  // default target triple.
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
  if (target_triple.getArch() == llvm::Triple::UnknownArch) {
    LOG(WARNING) << "target triple not found in the module";
    target_triple = default_target_triple;
  }

  module_passes.add(llvm::createTargetTransformInfoWrapperPass(
      target_machine->getTargetIRAnalysis()));

  // The LLVM IR verifier performs sanity checking on the IR. This helps
  // discover problems and report them in a meaningful manner, rather than let
  // later passes report obscure assertions because of unfulfilled invariants.
  module_passes.add(llvm::createVerifierPass());

  // Create the function-level pass manager. It needs data layout information
  // too.
  llvm::legacy::FunctionPassManager function_passes(module);

  int32_t opt_level =
      hlo_module_config.debug_options().xla_backend_optimization_level();

  if (opt_level < 2) {
    LOG(ERROR) << std::string(80, '*');
    LOG(ERROR) << "The XLA GPU backend doesn't support unoptimized code "
                  "generation but ";
    LOG(ERROR) << "--xla_backend_optimization_level is set to " << opt_level
               << "!";
    LOG(ERROR) << "(Supported configuration is "
                  "--xla_backend_optimization_level >= 2.)";
    LOG(ERROR) << std::string(80, '*');
  }

  // Add optimization passes, and set inliner threshold.
  AddOptimizationPasses(opt_level,
                        /*size_level=*/0, target_machine, &module_passes,
                        &function_passes, inline_threshold);

  // Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA
  // again after the standard optimization passes [http://b/13329423].
  // TODO(jingyue): SROA may further expose more optimization opportunities such
  // as more precise alias analysis and more function inlining (SROA may change
  // the inlining cost of a function). For now, running SROA already emits good
  // enough code for the evaluated benchmarks. We may want to run more
  // optimizations later.
  if (opt_level > 0) {
    // LLVM's optimizer turns on SROA when the optimization level is greater
    // than 0. We mimic this behavior here.
    module_passes.add(llvm::createSROAPass());
  }

  // Verify that the module is well formed after optimizations ran.
  module_passes.add(llvm::createVerifierPass());

  // Done populating the pass managers. Now run them.

  function_passes.doInitialization();
  for (auto func = module->begin(); func != module->end(); ++func) {
    function_passes.run(*func);
  }
  function_passes.doFinalization();
  module_passes.run(*module);

  return Status::OK();
}

// One-time module initializer.
// Must be called only once -- DO NOT CALL DIRECTLY.
void NVPTXBackendInit(const HloModuleConfig& hlo_module_config) {
  // Feed all customized flags here, so we can override them with llvm_cl_opts
  // without redeploying the compiler during development.

  // This flag tunes a threshold in branch folding. The default threshold, which
  // is one, is not suitable for CUDA programs where branches are more expensive
  // than for CPU programs. Setting the threshold to 2 improves the latency of
  // TwoDPatchDotProductKernel_IND_3_ND_48 by over 5%, and does not affect the
  // latency of other benchmarks so far.
  //
  // I also tried setting this threshold to other values:
  // * 3-6 gives similar results as 2;
  // * >6 starts hurting the performance of at least dot product kernels.
  //
  // TODO(jingyue): The current threshold only considers the number of IR
  // instructions, which does not accurately reflect the true cost. We need a
  // better cost model.
  FeedLLVMWithFlags({"-bonus-inst-threshold=2"});
  // Increase the limit when scanning memory dependencies.  This helps to
  // eliminate more redundant load instructions.
  //
  // The specific value is currently large enough for s3d in the shoc benchmark,
  // which contains a lot of load instructions and many arithmetic instructions
  // between those loads.
  FeedLLVMWithFlags({"-memdep-block-scan-limit=500"});

  // Use div.full -- it matters for some float-division heavy benchmarks.
  // Using div.approx produces an incorrect result for float32(max)/float32(max).
  FeedLLVMWithFlags({"-nvptx-prec-divf32=1"});

  llvm_ir::InitializeLLVMCommandLineOptions(hlo_module_config);

  // Initialize the NVPTX target; it's the only target we link with, so call its
  // specific initialization functions instead of the catch-all InitializeAll*.
  LLVMInitializeNVPTXTarget();
  LLVMInitializeNVPTXTargetInfo();
  LLVMInitializeNVPTXTargetMC();
  LLVMInitializeNVPTXAsmPrinter();

  // Initialize the LLVM optimization passes.
  llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry();
  InitializePasses(registry);
}

}  // namespace

namespace nvptx {

StatusOr<string> CompileToPtx(
    llvm::Module* module, GpuVersion gpu_version,
    const HloModuleConfig& hlo_module_config, const string& libdevice_dir_path,
    std::function<void(llvm::TargetMachine*)> configure_target) {
  static absl::once_flag backend_init_flag;
  absl::call_once(backend_init_flag, NVPTXBackendInit, hlo_module_config);

  string ptx;
  {
    tensorflow::profiler::TraceMe activity(
        [&] { return absl::StrCat("Compiling IR:", module->getName().str()); },
        tensorflow::profiler::TraceMeLevel::kInfo);
    XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str());

    // If the module has no functions or globals, there's nothing to compile.
    // Just return an empty string.
    if (module->empty() && module->global_empty()) {
      VLOG(2) << "Module '" << module->getName().str()
              << "' is empty. Skipping compilation.";
      return string();
    }

    auto compute_capability =
        absl::get_if<se::CudaComputeCapability>(&gpu_version);
    if (!compute_capability) {
      return xla::InternalError(
          "Incompatible compute capability was specified.");
    }

    llvm::Triple default_target_triple("nvptx64-unknown-unknown");
    // Construct LLVM TargetMachine for NVPTX.
    std::unique_ptr<llvm::TargetMachine> target_machine = NVPTXGetTargetMachine(
        default_target_triple, *compute_capability, hlo_module_config);

    // Apply target machine configuration from call-back if available.
    if (configure_target) {
      configure_target(target_machine.get());
    }

    // Link with libdevice, and optimize the LLVM module.
    TF_RETURN_IF_ERROR(LinkAndOptimizeModule(
        module, gpu_version, hlo_module_config, libdevice_dir_path,
        NVPTXTargetModuleLinker, default_target_triple, target_machine.get(),
        kDefaultInlineThreshold));

    // Lower the optimized LLVM module to PTX.
    ptx = EmitModuleToPTX(module, target_machine.get());
  }
  return ptx;
}

}  // namespace nvptx

namespace {

// Gets the ROCm-Device-Libs filenames for a particular AMDGPU version.
std::vector<string> GetROCDLPaths(std::string amdgpu_version,
                                  const string& rocdl_dir_path) {
  // AMDGPU version-neutral bitcodes.
#if TF_ROCM_VERSION >= 30900
  static std::vector<string>* rocdl_filenames = new std::vector<string>(
      {"hc.bc", "opencl.bc", "ocml.bc", "ockl.bc", "oclc_finite_only_off.bc",
       "oclc_daz_opt_off.bc", "oclc_correctly_rounded_sqrt_on.bc",
       "oclc_unsafe_math_off.bc", "oclc_wavefrontsize64_on.bc"});
#else
  static std::vector<string>* rocdl_filenames = new std::vector<string>(
      {"hc.amdgcn.bc", "opencl.amdgcn.bc", "ocml.amdgcn.bc", "ockl.amdgcn.bc",
       "oclc_finite_only_off.amdgcn.bc", "oclc_daz_opt_off.amdgcn.bc",
       "oclc_correctly_rounded_sqrt_on.amdgcn.bc",
       "oclc_unsafe_math_off.amdgcn.bc", "oclc_wavefrontsize64_on.amdgcn.bc"});
#endif

  // Construct full path to ROCDL bitcode libraries.
  std::vector<string> result;
  for (auto& filename : *rocdl_filenames) {
    result.push_back(tensorflow::io::JoinPath(rocdl_dir_path, filename));
  }

  // Add AMDGPU version-specific bitcodes.
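  // Strip the "gfx" prefix and any target-id features: e.g. an amdgpu_version
  // of "gfx906:sramecc+" reduces to "906" and selects oclc_isa_version_906.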
  std::vector<std::string> tokens = absl::StrSplit(amdgpu_version, ':');
  if (!tokens.empty() && tokens[0].size() >= 3) {
    amdgpu_version = tokens[0].substr(3);
  }
  result.push_back(tensorflow::io::JoinPath(
      rocdl_dir_path,
#if TF_ROCM_VERSION >= 30900
      absl::StrCat("oclc_isa_version_", amdgpu_version, ".bc")));
#else
      absl::StrCat("oclc_isa_version_", amdgpu_version, ".amdgcn.bc")));
#endif
  return result;
}

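// In-process cache of compiled HSA code objects, keyed by a hash of the
// module's IR together with the target gfx arch. Find() recomputes the hash
// and reports a hit only on an exact (hash, gfx, ir) match; Add() appends a
// new entry.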
struct HsacoCacheEntry {
  uint64 hash;
  std::string ir;
  std::string gfx;
  std::vector<uint8> hsaco;
};

struct HsacoCache {
 protected:
  std::vector<HsacoCacheEntry> cache;
  std::mutex m_mutex;
  int request_count = 0;
  int hit_count = 0;

 public:
  static bool Find(const std::string& ir, uint64_t& hash,
                   const std::string& gfx, std::vector<uint8>& hsaco);
  static void Add(const std::string& ir, uint64_t hash, const std::string& gfx,
                  const std::vector<uint8>& hsaco);
};

static HsacoCache g_hsacoCache;

bool HsacoCache::Find(const std::string& ir, uint64_t& hash,
                      const std::string& gfx, std::vector<uint8>& hsaco) {
  std::lock_guard<std::mutex> lg(g_hsacoCache.m_mutex);
  hash = std::hash<std::string>{}(ir);
  bool hit = false;
  for (auto& x : g_hsacoCache.cache) {
    if (x.hash != hash) continue;
    if (x.gfx != gfx) continue;
    if (x.ir != ir) continue;
    hsaco = x.hsaco;
    hit = true;
    break;
  }
  g_hsacoCache.request_count++;
  if (hit) g_hsacoCache.hit_count++;
  if (!(g_hsacoCache.request_count % 50))
    VLOG(1) << "HSACO cache: " << g_hsacoCache.request_count << " requests, "
            << g_hsacoCache.hit_count << " hits";
  return hit;
}

void HsacoCache::Add(const std::string& ir, uint64_t hash,
                     const std::string& gfx, const std::vector<uint8>& hsaco) {
  std::lock_guard<std::mutex> lg(g_hsacoCache.m_mutex);
  g_hsacoCache.cache.resize(g_hsacoCache.cache.size() + 1);
  g_hsacoCache.cache.back().ir = ir;
  g_hsacoCache.cache.back().hash = hash;
  g_hsacoCache.cache.back().gfx = gfx;
  g_hsacoCache.cache.back().hsaco = hsaco;
}

// Emits the given module to HSA Code Object. target_machine is an initialized
// TargetMachine for the AMDGPU target.
StatusOr<std::vector<uint8>> EmitModuleToHsaco(
    llvm::Module* module, llvm::TargetMachine* target_machine) {
  auto* env = tensorflow::Env::Default();
  std::vector<std::string> tempdir_vector;
  env->GetLocalTempDirectories(&tempdir_vector);
  if (tempdir_vector.empty()) {
    return xla::InternalError(
        "Unable to locate a temporary directory for compile-time artifacts.");
  }
  std::string tempdir_name = tempdir_vector.front();
  VLOG(1) << "Compile-time artifacts located at: " << tempdir_name;

  bool keep_tempfiles = false;
  TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_ROCM_KEEP_XLA_TEMPFILES",
                                             /*default_val=*/false,
                                             &keep_tempfiles));
  // Prepare filenames for all stages of compilation:
  // IR, binary ISA, and HSACO.
  std::string random_number = std::to_string(tensorflow::random::New64());
  std::string ir_filename =
      absl::StrCat(module->getModuleIdentifier(), random_number + ".ll");
  std::string ir_path = tensorflow::io::JoinPath(tempdir_name, ir_filename);

  std::string ir_opt_filename =
      absl::StrCat(module->getModuleIdentifier(), random_number + "_opt.ll");
  std::string ir_opt_path =
      tensorflow::io::JoinPath(tempdir_name, ir_opt_filename);

  std::string isabin_filename =
      absl::StrCat(module->getModuleIdentifier(), random_number + ".o");
  std::string isabin_path =
      tensorflow::io::JoinPath(tempdir_name, isabin_filename);

  std::string hsaco_filename =
      absl::StrCat(module->getModuleIdentifier(), random_number + ".hsaco");
  std::string hsaco_path =
      tensorflow::io::JoinPath(tempdir_name, hsaco_filename);

  std::error_code ec;

  // Dump LLVM IR.
  std::unique_ptr<llvm::raw_fd_ostream> ir_fs(
      new llvm::raw_fd_ostream(ir_path, ec, llvm::sys::fs::OF_None));
  module->print(*ir_fs, nullptr);
  ir_fs->flush();

  // Emit GCN ISA binary.
  // The extension is stripped by IrDumpingPassManager, so we need to
  // get creative to add a suffix.
  std::string module_id = module->getModuleIdentifier();
  IrDumpingPassManager codegen_passes(
      ReplaceFilenameExtension(tensorflow::io::Basename(module_id),
                               random_number + "-amdgpu.dummy"),
      "", false);
  codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass(
      llvm::Triple(module->getTargetTriple())));
  llvm::SmallVector<char, 0> stream;
  llvm::raw_svector_ostream pstream(stream);
  std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
      new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::OF_Text));
  module->setDataLayout(target_machine->createDataLayout());
  target_machine->addPassesToEmitFile(codegen_passes, *isabin_fs, nullptr,
                                      llvm::CGFT_ObjectFile);
  codegen_passes.run(*module);
  isabin_fs->flush();

  if (keep_tempfiles) {
    std::unique_ptr<llvm::raw_fd_ostream> ir_fs(
        new llvm::raw_fd_ostream(ir_opt_path, ec, llvm::sys::fs::OF_None));
    module->print(*ir_fs, nullptr);
    ir_fs->flush();
  }
  // Locate lld.
  // TODO(whchung@gmail.com): change to tensorflow::ROCmRoot() after
  // ROCm-Device-Libs PR.
  std::string lld_path_1 = tensorflow::io::JoinPath("/opt/rocm", "hcc/bin");
  std::string lld_path_2 = tensorflow::io::JoinPath("/opt/rocm", "llvm/bin");
  auto lld_program =
      llvm::sys::findProgramByName("ld.lld", {lld_path_1, lld_path_2});
  if (!lld_program) {
    return xla::InternalError("unable to find ld.lld in PATH: %s",
                              lld_program.getError().message());
  }
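  // Produce the HSACO by linking the ISA object into a shared object, i.e.
  // "ld.lld -flavor gnu -shared <isabin_path> -o <hsaco_path>".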
  std::vector<llvm::StringRef> lld_args{
      llvm_ir::AsStringRef("ld.lld"),
      llvm_ir::AsStringRef("-flavor"),
      llvm_ir::AsStringRef("gnu"),
      llvm_ir::AsStringRef("-shared"),
      llvm_ir::AsStringRef(isabin_path),
      llvm_ir::AsStringRef("-o"),
      llvm_ir::AsStringRef(hsaco_path),
  };

  std::string error_message;
  int lld_result =
      llvm::sys::ExecuteAndWait(*lld_program, llvm_ir::AsArrayRef(lld_args),
                                llvm::None, {}, 0, 0, &error_message);
  if (lld_result) {
    return xla::InternalError("ld.lld execution failed: %s, error code %d",
                              error_message, lld_result);
  }

  // Read HSACO.
  std::ifstream hsaco_file(hsaco_path, std::ios::binary | std::ios::ate);
  std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();

  std::vector<uint8> hsaco(hsaco_file_size);
  hsaco_file.seekg(0, std::ios::beg);
  hsaco_file.read(reinterpret_cast<char*>(&hsaco[0]), hsaco_file_size);
  hsaco_file.close();
  if (!keep_tempfiles) {
    remove(ir_path.c_str());
    remove(isabin_path.c_str());
    remove(hsaco_path.c_str());
  }
  return hsaco;
}

// Links ROCm-Device-Libs into the given module if the module needs it.
Status LinkROCDLIfNecessary(llvm::Module* module, std::string amdgpu_version,
                            const string& rocdl_dir_path) {
  if (!CouldNeedDeviceBitcode(*module)) {
    return Status::OK();
  }

  return LinkWithBitcodeVector(module,
                               GetROCDLPaths(amdgpu_version, rocdl_dir_path));
}

Status AMDGPUTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version,
                                const HloModuleConfig& hlo_module_config,
                                const string& device_bitcode_dir_path) {
  // Link the input module with ROCDL.
  auto amdgpu_version = absl::get_if<std::string>(&gpu_version);
  if (!amdgpu_version) {
    return xla::InternalError(
        "Incompatible AMD GCN ISA version was specified.");
  }
  TF_RETURN_IF_ERROR(
      LinkROCDLIfNecessary(module, *amdgpu_version, device_bitcode_dir_path));

  // If ftz is enabled, set it as an attribute on every function in the module.
  if (hlo_module_config.debug_options().xla_gpu_ftz()) {
    for (llvm::Function& fn : *module) {
      fn.addFnAttr("denormal-fp-math-f32", "preserve-sign");
    }
  }

  return Status::OK();
}

// The following routine takes a feature token extracted from the
// hipDeviceProp_t::gcnArchName string and maps it to a valid feature_str
// to be used for creating the AMDGPUTarget.
// This mapping is currently in a state of flux because TF XLA uses its
// own copy of LLVM, which is different from the LLVM version used by
// hipcc/runtime in the ROCm install. Ordinarily this is not a problem,
// but right now the LLVM version used by hipcc/runtime has "targetID"
// related changes which have not yet been upstreamed (to the LLVM repo).
// When that upstreaming happens (and the TF LLVM pointer moves past the
// upstream commit), the following mapping will need to change.
std::string MapGCNArchNameTokenToFeatureStr(const std::string& token) {
  if (token == "sramecc+") {
    return "+sramecc";
  } else if (token == "sramecc-") {
    return "-sramecc";
  } else if (token == "xnack+") {
    return "+xnack";
  } else if (token == "xnack-") {
    return "-xnack";
  }
  return "";
}

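// Splits a gcnArchName such as "gfx908:sramecc+:xnack-" into the gfx name
// and an LLVM feature string; with ROCm >= 4.0 this example yields
// ("gfx908", "+sramecc,-xnack").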
std::pair<std::string, std::string> GetFeatureStrFromGCNArchName(
    const std::string& gcn_arch_name) {
  std::string feature_str;

  std::string gfx = gcn_arch_name;
#if TF_ROCM_VERSION < 30900
  // For ROCm versions older than 3.9, hardcode it to "+code-object-v3".
  // This is simply to preserve how things were... nothing else.
  feature_str = "+code-object-v3";
#elif TF_ROCM_VERSION < 40000
  // For ROCm versions 3.9 and 3.10, hardcode it to the empty string.
  feature_str = "";
#else
  // For ROCm versions 4.0 and greater, we need to specify the correct
  // feature str, based on the underlying GPU HW, to get max performance.
  std::vector<std::string> tokens = absl::StrSplit(gcn_arch_name, ':');
  std::vector<std::string> mapped_tokens;
  if (tokens.size() > 0) gfx = tokens[0];
  for (auto it = tokens.begin(); it != tokens.end(); it++) {
    // Skip the first token, which is the gfxNNN str.
    // The rest of the tokens are the feature/targetid strings.
    if (it != tokens.begin()) {
      std::string token(*it);
      std::string mapped_token = MapGCNArchNameTokenToFeatureStr(token);
      mapped_tokens.push_back(mapped_token);
    }
  }
  feature_str = absl::StrJoin(mapped_tokens, ",");
#endif

  return std::make_pair(gfx, feature_str);
}

std::unique_ptr<llvm::TargetMachine> AMDGPUGetTargetMachine(
    llvm::Triple target_triple, GpuVersion gpu_version,
    const HloModuleConfig& hlo_module_config) {
  auto amdgpu_version = absl::get_if<std::string>(&gpu_version);
  std::string gcn_arch_name = *amdgpu_version;
  auto arch = GetFeatureStrFromGCNArchName(gcn_arch_name);
  return GetTargetMachine(std::move(target_triple), arch.first,
                          hlo_module_config, arch.second);
}

void AMDGPUBackendInit(const HloModuleConfig& hlo_module_config) {
  llvm_ir::InitializeLLVMCommandLineOptions(hlo_module_config);

  // Initialize the AMDGPU target; it's the only target we link with, so call
  // its specific initialization functions instead of the catch-all
  // InitializeAll*.
#if TENSORFLOW_USE_ROCM
  LLVMInitializeAMDGPUTarget();
  LLVMInitializeAMDGPUTargetInfo();
  LLVMInitializeAMDGPUTargetMC();
  LLVMInitializeAMDGPUAsmPrinter();

#if TF_ROCM_VERSION < 40100
  // Use code-object-v3 for ROCm versions 4.0.1 and lower, since the
  // HIP runtime for those ROCm versions expects v3 HSACO objects.
  // The default is now v4 for newer LLVM versions (starting around 210326).
  FeedLLVMWithFlags({"--amdhsa-code-object-version=3"});
#endif

#endif

  llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry();
  InitializePasses(registry);
}

}  // namespace

namespace amdgpu {
StatusOr<std::vector<uint8>> CompileToHsaco(
    llvm::Module* module, GpuVersion gpu_version,
    const HloModuleConfig& hlo_module_config, const string& rocdl_dir_path) {
  static absl::once_flag backend_init_flag;
  absl::call_once(backend_init_flag, AMDGPUBackendInit, hlo_module_config);

  std::vector<uint8> hsaco;
  std::string str;
  llvm::raw_string_ostream stream(str);
  stream << *module;
  // Delete the first two lines, since they usually vary even when the rest of
  // the code is the same (but verify that they are what we expect).
  if (str.size() >= 13 && str.substr(0, 13) == "; ModuleID = ") {
    auto pos = str.find('\n');
    if (pos != std::string::npos) str = str.substr(pos + 1);
  }
  if (str.size() >= 18 && str.substr(0, 18) == "source_filename = ") {
    auto pos = str.find('\n');
    if (pos != std::string::npos) str = str.substr(pos + 1);
  }
  str += hlo_module_config.compilation_cache_key();
  {
    tensorflow::profiler::TraceMe activity(
        [&] { return absl::StrCat("Compiling IR:", module->getName().str()); },
        tensorflow::profiler::TraceMeLevel::kInfo);
    XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str());

    auto amdgpu_version = absl::get_if<std::string>(&gpu_version);
    if (!amdgpu_version) {
      return xla::InternalError(
          "Incompatible AMD GCN ISA version was specified.");
    }
    uint64_t hash;
    if (HsacoCache::Find(str, hash, *amdgpu_version, hsaco)) {
      VLOG(1) << "HSACO cache hit";
      return hsaco;
    }
    VLOG(1) << "HSACO cache miss";
    bool dump_lls = false;
    if (dump_lls) {
      static int hsaco_count = 0;
      std::string name = "/tmp/" + std::to_string(hsaco_count) + ".ll";
      hsaco_count++;
      std::ofstream ofs(name);
      ofs << str;
      ofs.close();
    }

    llvm::Triple default_target_triple("amdgcn--amdhsa-amdgiz");
    // Construct LLVM TargetMachine for AMDGPU.
    std::unique_ptr<llvm::TargetMachine> target_machine =
        AMDGPUGetTargetMachine(default_target_triple, gpu_version,
                               hlo_module_config);

    // Link with ROCm-Device-Libs, and optimize the LLVM module.
    TF_RETURN_IF_ERROR(LinkAndOptimizeModule(
        module, gpu_version, hlo_module_config, rocdl_dir_path,
        AMDGPUTargetModuleLinker, default_target_triple, target_machine.get(),
        kAMDGPUInlineThreshold));

    // Lower the optimized LLVM module to an HSA code object.
    TF_ASSIGN_OR_RETURN(hsaco, EmitModuleToHsaco(module, target_machine.get()));
    HsacoCache::Add(str, hash, *amdgpu_version, hsaco);
  }
  return hsaco;
}

}  // namespace amdgpu

}  // namespace gpu
}  // namespace xla