/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"

#include <fstream>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/InitializePasses.h"
#include "llvm/Linker/Linker.h"
#include "llvm/PassRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FormattedStream.h"
#include "llvm/Support/Program.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"
#include "llvm/Transforms/IPO/Internalize.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "llvm/Transforms/Scalar.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/env_var.h"

#if !defined(PLATFORM_GOOGLE) && TENSORFLOW_USE_ROCM
#include "rocm/rocm_config.h"
#endif

namespace xla {
namespace gpu {
namespace {

static llvm::codegen::RegisterCodeGenFlags CGF;

// Inline threshold value to use in the LLVM AMDGPU backend.
const int kAMDGPUInlineThreshold = 0x100000;

// Default inline threshold value to use in LLVM.
const int kDefaultInlineThreshold = 1100;

// Gets the GPU name as it's known to LLVM for a given compute
// capability.  If we see an unrecognized compute capability, we
// return the highest one that is known and below the selected device.
static string GetSmName(std::pair<int, int> compute_capability) {
  int compute_capability_version =
      compute_capability.first * 10 + compute_capability.second;
  int sm_version = 30;
  // If the current compute capability isn't known, fall back to the
  // most recent version before it.
  int supported_versions[] = {75, 72, 70, 62, 61, 60, 53,
                              52, 50, 37, 35, 32, 30};
  for (int v : supported_versions) {
    if (v <= compute_capability_version) {
      sm_version = v;
      break;
    }
  }

  // If the current CC isn't supported by LLVM and it is newer than
  // the max supported LLVM version, do not warn about it. The end
  // user can't do anything about this. PTX compiled for SM75 will
  // run on SM80 too.
  if (sm_version != compute_capability_version &&
      compute_capability_version < supported_versions[0]) {
    LOG(WARNING) << "Unknown compute capability ("
                 << compute_capability.first << ", "
                 << compute_capability.second << "). "
                 << "Defaulting to telling LLVM that we're compiling for sm_"
                 << sm_version;
  }
  return absl::StrCat("sm_", sm_version);
}
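// Illustrative example (added for clarity, not part of the original file):
// an unknown compute capability below the newest supported entry falls back
// to the nearest known one, e.g. GetSmName({6, 3}) yields "sm_62" and logs a
// warning, while GetSmName({8, 0}) yields "sm_75" without a warning.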

// Convenience function for producing a name of a temporary compilation product
// from the input filename.
string MakeNameForTempProduct(absl::string_view input_filename,
                              absl::string_view extension) {
  return ReplaceFilenameExtension(tensorflow::io::Basename(input_filename),
                                  extension);
}

// Initializes LLVM passes. Uses the PassRegistry mechanism.
void InitializePasses(llvm::PassRegistry* pass_registry) {
  llvm::initializeCore(*pass_registry);
  llvm::initializeCodeGen(*pass_registry);
  llvm::initializeScalarOpts(*pass_registry);
  llvm::initializeObjCARCOpts(*pass_registry);
  llvm::initializeVectorization(*pass_registry);
  llvm::initializeIPO(*pass_registry);
  llvm::initializeAnalysis(*pass_registry);
  llvm::initializeTransformUtils(*pass_registry);
  llvm::initializeInstCombine(*pass_registry);
  llvm::initializeInstrumentation(*pass_registry);
  llvm::initializeTarget(*pass_registry);
  llvm::initializeCodeGenPreparePass(*pass_registry);
}

// Returns the TargetMachine, given a triple.
std::unique_ptr<llvm::TargetMachine> GetTargetMachine(
    llvm::Triple triple, absl::string_view cpu_name,
    const HloModuleConfig& hlo_module_config, absl::string_view feature_str) {
  std::string error;
  const llvm::Target* target =
      llvm::TargetRegistry::lookupTarget("", triple, error);
  if (target == nullptr) {
    LOG(FATAL) << "Unable to find Target for triple '" << triple.str() << "'"
               << " -- " << error;
    return nullptr;
  }

  llvm::TargetOptions target_options =
      llvm::codegen::InitTargetOptionsFromCodeGenFlags(llvm::Triple());

  // Set the verbose assembly options.
  target_options.MCOptions.AsmVerbose = false;

  // The selection of codegen optimization level is copied from function
  // GetCodeGenOptLevel in //third_party/llvm/llvm/tools/opt/opt.cpp.
  llvm::CodeGenOpt::Level codegen_opt_level;
  switch (hlo_module_config.debug_options().xla_backend_optimization_level()) {
    case 1:
      codegen_opt_level = llvm::CodeGenOpt::Less;
      break;
    case 2:
      codegen_opt_level = llvm::CodeGenOpt::Default;
      break;
    case 3:
      codegen_opt_level = llvm::CodeGenOpt::Aggressive;
      break;
    default:
      codegen_opt_level = llvm::CodeGenOpt::None;
  }
  return absl::WrapUnique(target->createTargetMachine(
      triple.str(), llvm_ir::AsStringRef(cpu_name),
      llvm_ir::AsStringRef(feature_str), target_options,
      llvm::codegen::getExplicitRelocModel(),
      llvm::codegen::getExplicitCodeModel(), codegen_opt_level));
}

// Adds the standard LLVM optimization passes, based on the speed optimization
// level (opt_level) and size optimization level (size_level). Both module-
// and function-level passes are added, so two pass managers are passed in and
// modified by this function.
void AddOptimizationPasses(unsigned opt_level, unsigned size_level,
                           llvm::TargetMachine* target_machine,
                           llvm::legacy::PassManagerBase* module_passes,
                           llvm::legacy::FunctionPassManager* function_passes,
                           int inline_threshold) {
  llvm::PassManagerBuilder builder;
  builder.OptLevel = opt_level;
  builder.SizeLevel = size_level;

  if (opt_level > 1) {
    builder.Inliner = llvm::createFunctionInliningPass(inline_threshold);
  } else {
    // Only inline functions marked with "alwaysinline".
    builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
  }

  builder.DisableUnrollLoops = opt_level == 0;
  builder.LoopVectorize = opt_level > 0;
  builder.SLPVectorize = opt_level > 1 && size_level < 2;

  // NVPTX's early-as-possible passes include NVVM reflect.
  target_machine->adjustPassManager(builder);

  builder.populateFunctionPassManager(*function_passes);
  builder.populateModulePassManager(*module_passes);
}
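// Illustrative usage (added for clarity, not part of the original file): a
// caller typically wires both legacy pass managers through this helper, e.g.
//   llvm::legacy::PassManager mpm;
//   llvm::legacy::FunctionPassManager fpm(module);
//   AddOptimizationPasses(/*opt_level=*/3, /*size_level=*/0, tm, &mpm, &fpm,
//                         kDefaultInlineThreshold);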

// Emits the given module to a bitcode file.
void EmitBitcodeToFile(const llvm::Module& module, absl::string_view filename) {
  std::error_code error_code;
  llvm::ToolOutputFile outfile(string(filename).c_str(), error_code,
                               llvm::sys::fs::F_None);
  if (error_code) {
    LOG(FATAL) << "opening bitcode file for writing: " << error_code.message();
  }

  llvm::WriteBitcodeToFile(module, outfile.os());
  outfile.keep();
}

// Emits the given module to PTX. target_machine is an initialized
// TargetMachine for the NVPTX target.
string EmitModuleToPTX(llvm::Module* module,
                       llvm::TargetMachine* target_machine) {
  std::string ptx;
  {
    llvm::raw_string_ostream stream(ptx);
    llvm::buffer_ostream pstream(stream);
    // The extension is stripped by IrDumpingPassManager, so we need to
    // get creative to add a suffix.
    IrDumpingPassManager codegen_passes(
        MakeNameForTempProduct(module->getModuleIdentifier(), "-nvptx.dummy"),
        "", false);
    codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass(
        llvm::Triple(module->getTargetTriple())));

    target_machine->addPassesToEmitFile(codegen_passes, pstream, nullptr,
                                        llvm::CGFT_AssemblyFile);
    codegen_passes.run(*module);
  }

  return ptx;
}

// LLVM has an extensive flags mechanism of its own, which is only accessible
// through the command line. Internal libraries within LLVM register parsers
// for flags, with no other way to configure them except to pass these flags.
// To do this programmatically, we invoke ParseCommandLineOptions manually
// with a "fake argv".
// Note: setting flags with this method is stateful, since flags are just
// static globals within LLVM libraries.
void FeedLLVMWithFlags(const std::vector<string>& cl_opts) {
  std::vector<const char*> fake_argv = {""};
  for (const string& cl_opt : cl_opts) {
    fake_argv.push_back(cl_opt.c_str());
  }
  llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]);
}
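// Illustrative example (added for clarity, not part of the original file):
// each call parses its options as if they had been passed on a command line;
//   FeedLLVMWithFlags({"-bonus-inst-threshold=2"});
// behaves like invoking a tool with argv = {"", "-bonus-inst-threshold=2"}.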

// Returns whether the module could use any device bitcode library functions.
bool CouldNeedDeviceBitcode(const llvm::Module& module) {
  for (const llvm::Function& function : module.functions()) {
    // The list of prefixes should be in sync with library functions used in
    // target_util.cc.
    if (!function.isIntrinsic() && function.isDeclaration() &&
        (function.getName().startswith("__nv_") ||
         function.getName().startswith("__ocml_") ||
         function.getName().startswith("__ockl_"))) {
      return true;
    }
  }
  return false;
}
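// Illustrative example (added for clarity, not part of the original file): a
// module containing a declared-but-undefined function with one of these
// prefixes, e.g. "__nv_expf", makes this predicate return true, so the
// matching device bitcode library gets linked in below.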

// Links the module with a vector of paths to bitcode modules.
// The caller must guarantee that the paths exist.
Status LinkWithBitcodeVector(llvm::Module* module,
                             const std::vector<string>& bitcode_path_vector) {
  llvm::Linker linker(*module);

  for (auto& bitcode_path : bitcode_path_vector) {
    if (!tensorflow::Env::Default()->FileExists(bitcode_path).ok()) {
      LOG(ERROR) << "bitcode module is required by this HLO module but was "
                    "not found at "
                 << bitcode_path;
      return xla::InternalError("bitcode module not found at %s", bitcode_path);
    }

    std::unique_ptr<llvm::Module> bitcode_module =
        LoadIRModule(bitcode_path, &module->getContext());
    // Ignore the data layout of the module we're importing. This avoids a
    // warning from the linker.
    bitcode_module->setDataLayout(module->getDataLayout());
    if (linker.linkInModule(
            std::move(bitcode_module), llvm::Linker::Flags::LinkOnlyNeeded,
            [](llvm::Module& M, const llvm::StringSet<>& GVS) {
              internalizeModule(M, [&GVS](const llvm::GlobalValue& GV) {
                return !GV.hasName() || (GVS.count(GV.getName()) == 0);
              });
            })) {
      return xla::InternalError("Error linking bitcode module from %s",
                                bitcode_path);
    }
  }
  return Status::OK();
}
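// Note (added for clarity): LinkOnlyNeeded pulls in just the definitions the
// destination module actually references, and the internalize callback gives
// internal linkage to the symbols linked in from the library (the names in
// GVS), so they can be optimized away once inlined or dead.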

// Links libdevice into the given module if the module needs libdevice.
Status LinkLibdeviceIfNecessary(llvm::Module* module,
                                std::pair<int, int> compute_capability,
                                const string& libdevice_dir_path) {
  if (!CouldNeedDeviceBitcode(*module)) {
    return Status::OK();
  }

  // CUDA 9+ uses a single libdevice file for all devices, and we don't support
  // older CUDAs.
  string libdevice_path =
      tensorflow::io::JoinPath(libdevice_dir_path, "libdevice.10.bc");
  if (!tensorflow::Env::Default()->FileExists(libdevice_path).ok()) {
    LOG(WARNING)
        << "libdevice is required by this HLO module but was not found at "
        << libdevice_path;
    return xla::InternalError("libdevice not found at %s", libdevice_path);
  }

  VLOG(1) << "Linking with libdevice from: " << libdevice_path;
  return LinkWithBitcodeVector(module, {libdevice_path});
}

Status NVPTXTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version,
                               const HloModuleConfig& hlo_module_config,
                               const string& device_bitcode_dir_path) {
  // Link the input module with libdevice, to pull in implementations of some
  // builtins.
  auto compute_capability = absl::get_if<std::pair<int, int>>(&gpu_version);
  if (!compute_capability) {
    return xla::InternalError("Incompatible compute capability was specified.");
  }
  TF_RETURN_IF_ERROR(LinkLibdeviceIfNecessary(module, *compute_capability,
                                              device_bitcode_dir_path));

  // Set the flush-denormals-to-zero flag on the module so the NVVM reflect
  // pass can access it.
  module->addModuleFlag(llvm::Module::Override, "nvvm-reflect-ftz",
                        hlo_module_config.debug_options().xla_gpu_ftz());

  // If ftz is enabled, set it as an attribute on every function in the module.
  if (hlo_module_config.debug_options().xla_gpu_ftz()) {
    for (llvm::Function& fn : *module) {
      fn.addFnAttr("denormal-fp-math-f32", "preserve-sign");
    }
  }

  return Status::OK();
}

std::unique_ptr<llvm::TargetMachine> NVPTXGetTargetMachine(
    llvm::Triple target_triple, std::pair<int, int> compute_capability,
    const HloModuleConfig& hlo_module_config) {
  // Figure out the exact name of the processor as known to the NVPTX backend
  // from the gpu_architecture flag.
  return GetTargetMachine(target_triple, GetSmName(compute_capability),
                          hlo_module_config, "+ptx60");
}

using TargetModuleLinker = std::function<Status(
    llvm::Module*, GpuVersion, const HloModuleConfig&, const string&)>;

Status LinkAndOptimizeModule(llvm::Module* module, GpuVersion gpu_version,
                             const HloModuleConfig& hlo_module_config,
                             const string& device_bitcode_dir_path,
                             TargetModuleLinker module_linker,
                             llvm::Triple default_target_triple,
                             llvm::TargetMachine* target_machine,
                             int inline_threshold) {
  TF_RETURN_IF_ERROR(module_linker(module, gpu_version, hlo_module_config,
                                   device_bitcode_dir_path));

  IrDumpingPassManager module_passes(module->getModuleIdentifier(), "", false);

  // Add an appropriate TargetLibraryInfo pass for the module's triple.
  llvm::TargetLibraryInfoWrapperPass* tliwp =
      new llvm::TargetLibraryInfoWrapperPass(
          llvm::Triple(module->getTargetTriple()));
  module_passes.add(tliwp);

  // Try to fetch the target triple from the module. If not present, set a
  // default target triple.
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
  if (target_triple.getArch() == llvm::Triple::UnknownArch) {
    LOG(WARNING) << "target triple not found in the module";
    target_triple = default_target_triple;
  }

  module_passes.add(llvm::createTargetTransformInfoWrapperPass(
      target_machine->getTargetIRAnalysis()));

  // The LLVM IR verifier performs sanity checking on the IR. This helps
  // discover problems and report them in a meaningful manner, rather than let
  // later passes report obscure assertions because of unfulfilled invariants.
  module_passes.add(llvm::createVerifierPass());

  // Create the function-level pass manager. It needs data layout information
  // too.
  llvm::legacy::FunctionPassManager function_passes(module);

  int32 opt_level =
      hlo_module_config.debug_options().xla_backend_optimization_level();

  if (opt_level < 2) {
    LOG(ERROR) << std::string(80, '*');
    LOG(ERROR) << "The XLA GPU backend doesn't support unoptimized code "
                  "generation but ";
    LOG(ERROR) << "--xla_backend_optimization_level is set to " << opt_level
               << "!";
    LOG(ERROR) << "(Supported configuration is "
                  "--xla_backend_optimization_level >= 2.)";
    LOG(ERROR) << std::string(80, '*');
  }

  // Add optimization passes, and set inliner threshold.
  AddOptimizationPasses(opt_level,
                        /*size_level=*/0, target_machine, &module_passes,
                        &function_passes, inline_threshold);

  // Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA
  // again after the standard optimization passes [http://b/13329423].
  // TODO(jingyue): SROA may further expose more optimization opportunities
  // such as more precise alias analysis and more function inlining (SROA may
  // change the inlining cost of a function). For now, running SROA already
  // emits good enough code for the evaluated benchmarks. We may want to run
  // more optimizations later.
  if (opt_level > 0) {
    // LLVM's optimizer turns on SROA when the optimization level is greater
    // than 0. We mimic this behavior here.
    module_passes.add(llvm::createSROAPass());
  }

  // Verify that the module is well formed after optimizations ran.
  module_passes.add(llvm::createVerifierPass());

  // Done populating the pass managers. Now run them.

  function_passes.doInitialization();
  for (auto func = module->begin(); func != module->end(); ++func) {
    function_passes.run(*func);
  }
  function_passes.doFinalization();
  module_passes.run(*module);

  return Status::OK();
}

// One-time module initializer.
// Must be called only once -- DO NOT CALL DIRECTLY.
void NVPTXBackendInit(const HloModuleConfig& hlo_module_config) {
  // Feed all customized flags here, so we can override them with llvm_cl_opts
  // without redeploying the compiler during development.

  // This flag tunes a threshold in branch folding. The default threshold,
  // which is one, is not suitable for CUDA programs where branches are more
  // expensive than for CPU programs. Setting the threshold to 2 improves the
  // latency of TwoDPatchDotProductKernel_IND_3_ND_48 by over 5%, and does not
  // affect the latency of other benchmarks so far.
  //
  // I also tried setting this threshold to other values:
  // * 3-6 gives similar results as 2;
  // * >6 starts hurting the performance of at least dot product kernels.
  //
  // TODO(jingyue): The current threshold only considers the number of IR
  // instructions, which does not accurately reflect the true cost. We need a
  // better cost model.
  FeedLLVMWithFlags({"-bonus-inst-threshold=2"});
  // Increase the limit when scanning memory dependencies. This helps to
  // remove more redundant load instructions.
  //
  // The specific value is currently large enough for s3d in the SHOC
  // benchmark, which contains a lot of load instructions and many arithmetic
  // instructions between those loads.
  FeedLLVMWithFlags({"-memdep-block-scan-limit=500"});

  // Use div.full -- it matters for some float-division heavy benchmarks.
  // Using div.approx produces incorrect result for float32(max)/float32(max).
  FeedLLVMWithFlags({"-nvptx-prec-divf32=1"});

  llvm_ir::InitializeLLVMCommandLineOptions(hlo_module_config);

  // Initialize the NVPTX target; it's the only target we link with, so call
  // its specific initialization functions instead of the catch-all
  // InitializeAll*.
  LLVMInitializeNVPTXTarget();
  LLVMInitializeNVPTXTargetInfo();
  LLVMInitializeNVPTXTargetMC();
  LLVMInitializeNVPTXAsmPrinter();

  // Initialize the LLVM optimization passes.
  llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry();
  InitializePasses(registry);
}

}  // namespace

namespace nvptx {

StatusOr<string> CompileToPtx(
    llvm::Module* module, GpuVersion gpu_version,
    const HloModuleConfig& hlo_module_config, const string& libdevice_dir_path,
    std::function<void(llvm::TargetMachine*)> configure_target) {
  static absl::once_flag backend_init_flag;
  absl::call_once(backend_init_flag, NVPTXBackendInit, hlo_module_config);

  string ptx;
  {
    tensorflow::profiler::TraceMe activity(
        [&] { return absl::StrCat("Compiling IR:", module->getName().str()); },
        tensorflow::profiler::TraceMeLevel::kInfo);
    XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str());

    // If the module has no functions or globals, there's nothing to compile.
    // Just return an empty string.
    if (module->empty() && module->global_empty()) {
      VLOG(2) << "Module '" << module->getName().str()
              << "' is empty. Skipping compilation.";
      return string();
    }

    auto compute_capability = absl::get_if<std::pair<int, int>>(&gpu_version);
    if (!compute_capability) {
      return xla::InternalError(
          "Incompatible compute capability was specified.");
    }

    llvm::Triple default_target_triple("nvptx64-unknown-unknown");
    // Construct LLVM TargetMachine for NVPTX.
    std::unique_ptr<llvm::TargetMachine> target_machine = NVPTXGetTargetMachine(
        default_target_triple, *compute_capability, hlo_module_config);

    // Apply target machine configuration from call-back if available.
    if (configure_target) {
      configure_target(target_machine.get());
    }

    // Link with libdevice, and optimize the LLVM module.
    TF_RETURN_IF_ERROR(LinkAndOptimizeModule(
        module, gpu_version, hlo_module_config, libdevice_dir_path,
        NVPTXTargetModuleLinker, default_target_triple, target_machine.get(),
        kDefaultInlineThreshold));

    // Lower optimized LLVM module to PTX.
    ptx = EmitModuleToPTX(module, target_machine.get());
  }
  return ptx;
}
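
// Illustrative call sequence (added for clarity, not part of the original
// file): a caller with an XLA-produced llvm::Module and a CUDA compute
// capability of (7, 0) might do
//   TF_ASSIGN_OR_RETURN(
//       string ptx, CompileToPtx(module, std::make_pair(7, 0), config,
//                                libdevice_dir, /*configure_target=*/nullptr));
// where libdevice_dir is the directory containing libdevice.10.bc.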

}  // namespace nvptx

namespace {

// Gets the ROCm-Device-Libs filenames for a particular AMDGPU version.
static std::vector<string> GetROCDLPaths(int amdgpu_version,
                                         const string& rocdl_dir_path) {
  // AMDGPU version-neutral bitcodes.
#if TF_ROCM_VERSION >= 30900
  static std::vector<string>* rocdl_filenames = new std::vector<string>(
      {"hc.bc", "opencl.bc", "ocml.bc", "ockl.bc", "oclc_finite_only_off.bc",
       "oclc_daz_opt_off.bc", "oclc_correctly_rounded_sqrt_on.bc",
       "oclc_unsafe_math_off.bc", "oclc_wavefrontsize64_on.bc"});
#else
  static std::vector<string>* rocdl_filenames = new std::vector<string>(
      {"hc.amdgcn.bc", "opencl.amdgcn.bc", "ocml.amdgcn.bc", "ockl.amdgcn.bc",
       "oclc_finite_only_off.amdgcn.bc", "oclc_daz_opt_off.amdgcn.bc",
       "oclc_correctly_rounded_sqrt_on.amdgcn.bc",
       "oclc_unsafe_math_off.amdgcn.bc", "oclc_wavefrontsize64_on.amdgcn.bc"});
#endif

  // Construct full path to ROCDL bitcode libraries.
  std::vector<string> result;
  for (auto& filename : *rocdl_filenames) {
    result.push_back(tensorflow::io::JoinPath(rocdl_dir_path, filename));
  }

  // Add AMDGPU version-specific bitcodes.
  result.push_back(tensorflow::io::JoinPath(
      rocdl_dir_path,
#if TF_ROCM_VERSION >= 30900
      absl::StrCat("oclc_isa_version_", amdgpu_version, ".bc")));
#else
      absl::StrCat("oclc_isa_version_", amdgpu_version, ".amdgcn.bc")));
#endif
  return result;
}
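// Illustrative example (added for clarity, not part of the original file):
// for amdgpu_version 908 on ROCm >= 3.9, the version-specific entry resolves
// to <rocdl_dir_path>/oclc_isa_version_908.bc, appended after the
// version-neutral libraries listed above.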

struct HsacoCacheEntry {
  uint64 hash;
  std::string ir;
  std::string gfx;
  std::vector<uint8> hsaco;
};

struct HsacoCache {
 protected:
  std::vector<HsacoCacheEntry> cache;
  std::mutex m_mutex;
  int request_count = 0;
  int hit_count = 0;

 public:
  static bool Find(const std::string& ir, uint64_t& hash,
                   const std::string& gfx, std::vector<uint8>& hsaco);
  static void Add(const std::string& ir, uint64_t hash, const std::string& gfx,
                  const std::vector<uint8>& hsaco);
};
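
// Note (added for clarity): Find hashes the IR text and reports the hash back
// through the out-parameter, so a caller that misses can hand the same hash
// to Add after compiling. Entries match only when the hash, the gfx string,
// and the full IR text all agree.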

static HsacoCache g_hsacoCache;

bool HsacoCache::Find(const std::string& ir, uint64_t& hash,
                      const std::string& gfx, std::vector<uint8>& hsaco) {
  std::lock_guard<std::mutex> lg(g_hsacoCache.m_mutex);
  hash = std::hash<std::string>{}(ir);
  bool hit = false;
  for (auto& x : g_hsacoCache.cache) {
    if (x.hash != hash) continue;
    if (x.gfx != gfx) continue;
    if (x.ir != ir) continue;
    hsaco = x.hsaco;
    hit = true;
    break;
  }
  g_hsacoCache.request_count++;
  if (hit) g_hsacoCache.hit_count++;
  if (!(g_hsacoCache.request_count % 50))
    VLOG(1) << "HSACO cache: " << g_hsacoCache.request_count << " requests, "
            << g_hsacoCache.hit_count << " hits";
  return hit;
}

void HsacoCache::Add(const std::string& ir, uint64_t hash,
                     const std::string& gfx, const std::vector<uint8>& hsaco) {
  std::lock_guard<std::mutex> lg(g_hsacoCache.m_mutex);
  g_hsacoCache.cache.resize(g_hsacoCache.cache.size() + 1);
  g_hsacoCache.cache.back().ir = ir;
  g_hsacoCache.cache.back().hash = hash;
  g_hsacoCache.cache.back().gfx = gfx;
  g_hsacoCache.cache.back().hsaco = hsaco;
}

// Emits the given module to HSA Code Object. target_machine is an initialized
// TargetMachine for the AMDGPU target.
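// Pipeline summary (added for clarity): (1) dump the LLVM IR to a temp .ll
// file, (2) run codegen passes to emit a GCN ISA object file, (3) link that
// object into a shared HSA code object with ld.lld, and (4) read the .hsaco
// bytes back into memory, deleting the temp files unless
// TF_ROCM_KEEP_XLA_TEMPFILES is set.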
StatusOr<std::vector<uint8>> EmitModuleToHsaco(
    llvm::Module* module, llvm::TargetMachine* target_machine) {
  auto* env = tensorflow::Env::Default();
  std::vector<std::string> tempdir_vector;
  env->GetLocalTempDirectories(&tempdir_vector);
  if (tempdir_vector.empty()) {
    return xla::InternalError(
        "Unable to locate a temporary directory for compile-time artifacts.");
  }
  std::string tempdir_name = tempdir_vector.front();
  VLOG(1) << "Compile-time artifacts located at: " << tempdir_name;

  bool keep_tempfiles = false;
  TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_ROCM_KEEP_XLA_TEMPFILES",
                                             /*default_val=*/false,
                                             &keep_tempfiles));
  // Prepare filenames for all stages of compilation:
  // IR, binary ISA, and HSACO.
  std::string random_number = std::to_string(tensorflow::random::New64());
  std::string ir_filename =
      absl::StrCat(module->getModuleIdentifier(), random_number + ".ll");
  std::string ir_path = tensorflow::io::JoinPath(tempdir_name, ir_filename);

  std::string ir_opt_filename =
      absl::StrCat(module->getModuleIdentifier(), random_number + "_opt.ll");
  std::string ir_opt_path =
      tensorflow::io::JoinPath(tempdir_name, ir_opt_filename);

  std::string isabin_filename =
      absl::StrCat(module->getModuleIdentifier(), random_number + ".o");
  std::string isabin_path =
      tensorflow::io::JoinPath(tempdir_name, isabin_filename);

  std::string hsaco_filename =
      absl::StrCat(module->getModuleIdentifier(), random_number + ".hsaco");
  std::string hsaco_path =
      tensorflow::io::JoinPath(tempdir_name, hsaco_filename);

  std::error_code ec;

  // Dump LLVM IR.
  std::unique_ptr<llvm::raw_fd_ostream> ir_fs(
      new llvm::raw_fd_ostream(ir_path, ec, llvm::sys::fs::F_None));
  module->print(*ir_fs, nullptr);
  ir_fs->flush();

  // Emit GCN ISA binary.
  // The extension is stripped by IrDumpingPassManager, so we need to
  // get creative to add a suffix.
  std::string module_id = module->getModuleIdentifier();
  IrDumpingPassManager codegen_passes(
      ReplaceFilenameExtension(tensorflow::io::Basename(module_id),
                               random_number + "-amdgpu.dummy"),
      "", false);
  codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass(
      llvm::Triple(module->getTargetTriple())));
  llvm::SmallVector<char, 0> stream;
  llvm::raw_svector_ostream pstream(stream);
  std::unique_ptr<llvm::raw_fd_ostream> isabin_fs(
      new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::F_Text));
  module->setDataLayout(target_machine->createDataLayout());
  target_machine->addPassesToEmitFile(codegen_passes, *isabin_fs, nullptr,
                                      llvm::CGFT_ObjectFile);
  codegen_passes.run(*module);
  isabin_fs->flush();

  if (keep_tempfiles) {
    std::unique_ptr<llvm::raw_fd_ostream> ir_fs(
        new llvm::raw_fd_ostream(ir_opt_path, ec, llvm::sys::fs::F_None));
    module->print(*ir_fs, nullptr);
    ir_fs->flush();
  }
  // Locate lld.
  // TODO(whchung@gmail.com): change to tensorflow::ROCmRoot() after
  // ROCm-Device-Libs PR.
  std::string lld_path_1 = tensorflow::io::JoinPath("/opt/rocm", "hcc/bin");
  std::string lld_path_2 = tensorflow::io::JoinPath("/opt/rocm", "llvm/bin");
  auto lld_program =
      llvm::sys::findProgramByName("ld.lld", {lld_path_1, lld_path_2});
  if (!lld_program) {
    return xla::InternalError("unable to find ld.lld in PATH: %s",
                              lld_program.getError().message());
  }
  std::vector<llvm::StringRef> lld_args{
      llvm_ir::AsStringRef("ld.lld"),
      llvm_ir::AsStringRef("-flavor"),
      llvm_ir::AsStringRef("gnu"),
      llvm_ir::AsStringRef("-shared"),
      llvm_ir::AsStringRef(isabin_path),
      llvm_ir::AsStringRef("-o"),
      llvm_ir::AsStringRef(hsaco_path),
  };

  std::string error_message;
  int lld_result =
      llvm::sys::ExecuteAndWait(*lld_program, llvm_ir::AsArrayRef(lld_args),
                                llvm::None, {}, 0, 0, &error_message);
  if (lld_result) {
    return xla::InternalError("ld.lld execution failed: %s, error code %d",
                              error_message, lld_result);
  }

  // Read HSACO.
  std::ifstream hsaco_file(hsaco_path, std::ios::binary | std::ios::ate);
  std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();

  std::vector<uint8> hsaco(hsaco_file_size);
  hsaco_file.seekg(0, std::ios::beg);
  hsaco_file.read(reinterpret_cast<char*>(&hsaco[0]), hsaco_file_size);
  hsaco_file.close();
  if (!keep_tempfiles) {
    remove(ir_path.c_str());
    remove(isabin_path.c_str());
    remove(hsaco_path.c_str());
  }
  return hsaco;
}

// Links ROCm-Device-Libs into the given module if the module needs it.
Status LinkROCDLIfNecessary(llvm::Module* module, int amdgpu_version,
                            const string& rocdl_dir_path) {
  if (!CouldNeedDeviceBitcode(*module)) {
    return Status::OK();
  }

  return LinkWithBitcodeVector(module,
                               GetROCDLPaths(amdgpu_version, rocdl_dir_path));
}

Status AMDGPUTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version,
                                const HloModuleConfig& hlo_module_config,
                                const string& device_bitcode_dir_path) {
  // Link the input module with ROCDL.
  auto amdgpu_version = absl::get_if<std::pair<int, std::string>>(&gpu_version);
  if (!amdgpu_version) {
    return xla::InternalError(
        "Incompatible AMD GCN ISA version was specified.");
  }
  TF_RETURN_IF_ERROR(LinkROCDLIfNecessary(module, amdgpu_version->first,
                                          device_bitcode_dir_path));

  return Status::OK();
}

// Maps a feature token extracted from the hipDeviceProp_t::gcnArchName
// string to a valid feature_str to be used for creating the AMDGPUTarget.
// This mapping is currently in a state of flux because TF XLA uses its
// own copy of LLVM, which is different from the LLVM version used by
// hipcc/runtime in the ROCm install. Ordinarily this is not a problem,
// but right now, the LLVM version used by hipcc/runtime has "targetID"
// related changes which have not yet been upstreamed (to the LLVM repo).
// When that upstreaming happens (and the TF LLVM pointer moves past the
// upstream commit), the following mapping will need to change.
std::string MapGCNArchNameTokenToFeatureStr(const std::string& token) {
  if (token == "sramecc+") {
    return "+sram-ecc";
  } else if (token == "sramecc-") {
    return "-sram-ecc";
  } else if (token == "xnack+") {
    return "+xnack";
  } else if (token == "xnack-") {
    return "-xnack";
  }
  return "";
}

std::string GetFeatureStrFromGCNArchName(const std::string& gcn_arch_name) {
  std::string feature_str;

#if TF_ROCM_VERSION < 30900
  // For ROCm versions older than 3.9, hardcode it to "+code-object-v3".
  // This is simply to preserve how things were... nothing else.
  feature_str = "+code-object-v3";
#elif TF_ROCM_VERSION < 40000
  // For ROCm versions 3.9 and 3.10, hardcode it to the empty string.
  feature_str = "";
#else
  // For ROCm versions 4.0 and greater, we need to specify the correct
  // feature str, based on the underlying GPU HW, to get max performance.
  std::vector<std::string> tokens = absl::StrSplit(gcn_arch_name, ':');
  std::vector<std::string> mapped_tokens;
  for (auto it = tokens.begin(); it != tokens.end(); it++) {
    // Skip the first token, which is the gfxNNN string.
    // The rest of the tokens are the feature/targetid strings.
    if (it != tokens.begin()) {
      std::string token(*it);
      std::string mapped_token = MapGCNArchNameTokenToFeatureStr(token);
      mapped_tokens.push_back(mapped_token);
    }
  }
  feature_str = absl::StrJoin(mapped_tokens, ",");
#endif

  return feature_str;
}
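// Illustrative example (added for clarity, not part of the original file):
// on ROCm >= 4.0, a gcnArchName such as "gfx906:sramecc+:xnack-" maps to the
// feature string "+sram-ecc,-xnack"; the leading "gfx906" token is dropped
// here and used separately as the CPU name below.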

std::unique_ptr<llvm::TargetMachine> AMDGPUGetTargetMachine(
    llvm::Triple target_triple, GpuVersion gpu_version,
    const HloModuleConfig& hlo_module_config) {
  auto amdgpu_version = absl::get_if<std::pair<int, std::string>>(&gpu_version);
  int gcn_arch_value = amdgpu_version->first;
  std::string gcn_arch_name = amdgpu_version->second;
  std::string feature_str = GetFeatureStrFromGCNArchName(gcn_arch_name);
  return GetTargetMachine(std::move(target_triple),
                          absl::StrCat("gfx", gcn_arch_value),
                          hlo_module_config, feature_str);
}

void AMDGPUBackendInit(const HloModuleConfig& hlo_module_config) {
  llvm_ir::InitializeLLVMCommandLineOptions(hlo_module_config);

  // Initialize the AMDGPU target; it's the only target we link with, so call
  // its specific initialization functions instead of the catch-all
  // InitializeAll*.
#if TENSORFLOW_USE_ROCM
  LLVMInitializeAMDGPUTarget();
  LLVMInitializeAMDGPUTargetInfo();
  LLVMInitializeAMDGPUTargetMC();
  LLVMInitializeAMDGPUAsmPrinter();
#endif

  llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry();
  InitializePasses(registry);
}

}  // namespace

namespace amdgpu {
StatusOr<std::vector<uint8>> CompileToHsaco(
    llvm::Module* module, GpuVersion gpu_version,
    const HloModuleConfig& hlo_module_config, const string& rocdl_dir_path) {
  static absl::once_flag backend_init_flag;
  absl::call_once(backend_init_flag, AMDGPUBackendInit, hlo_module_config);

  std::vector<uint8> hsaco;
  std::string str;
  llvm::raw_string_ostream stream(str);
  stream << *module;
  // Delete the first two lines, since they usually vary even when the rest of
  // the code is the same (but verify that they are what we expect).
  if (str.size() >= 13 && str.substr(0, 13) == "; ModuleID = ") {
    auto pos = str.find('\n');
    if (pos != std::string::npos) str = str.substr(pos + 1);
  }
  if (str.size() >= 18 && str.substr(0, 18) == "source_filename = ") {
    auto pos = str.find('\n');
    if (pos != std::string::npos) str = str.substr(pos + 1);
  }
  str += hlo_module_config.compilation_cache_key();
  {
    tensorflow::profiler::TraceMe activity(
        [&] { return absl::StrCat("Compiling IR", module->getName().str()); },
        tensorflow::profiler::TraceMeLevel::kInfo);
    XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str());

    auto amdgpu_version =
        absl::get_if<std::pair<int, std::string>>(&gpu_version);
    if (!amdgpu_version) {
      return xla::InternalError(
          "Incompatible AMD GCN ISA version was specified.");
    }
    uint64_t hash;
    if (HsacoCache::Find(str, hash, amdgpu_version->second, hsaco)) {
      VLOG(1) << "HSACO cache hit";
      return hsaco;
    }
    VLOG(1) << "HSACO cache miss";
    bool dump_lls = false;
    if (dump_lls) {
      static int hsaco_count = 0;
      std::string name = "/tmp/" + std::to_string(hsaco_count) + ".ll";
      hsaco_count++;
      std::ofstream ofs(name);
      ofs << str;
      ofs.close();
    }

    llvm::Triple default_target_triple("amdgcn--amdhsa-amdgiz");
    // Construct LLVM TargetMachine for AMDGPU.
    std::unique_ptr<llvm::TargetMachine> target_machine =
        AMDGPUGetTargetMachine(default_target_triple, gpu_version,
                               hlo_module_config);

    // Link with ROCm-Device-Libs, and optimize the LLVM module.
    TF_RETURN_IF_ERROR(LinkAndOptimizeModule(
        module, gpu_version, hlo_module_config, rocdl_dir_path,
        AMDGPUTargetModuleLinker, default_target_triple, target_machine.get(),
        kAMDGPUInlineThreshold));

    // Lower optimized LLVM module to HSA code object.
    TF_ASSIGN_OR_RETURN(hsaco, EmitModuleToHsaco(module, target_machine.get()));
    HsacoCache::Add(str, hash, amdgpu_version->second, hsaco);
  }
  return hsaco;
}

}  // namespace amdgpu

}  // namespace gpu
}  // namespace xla