/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"

#include <stddef.h>
#include <string.h>
#include <map>
#include <mutex>  // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc"
// IWYU pragma: no_include "llvm/Config/Targets.def.inc"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Mangler.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/disassembler.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/scatter_expander.h"
#include "tensorflow/compiler/xla/service/sort_simplifier.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/dynamic_annotations.h"

namespace xla {
namespace cpu {
using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo;

CpuAotCompilationOptions::CpuAotCompilationOptions(
    string triple, string cpu_name, string features, string entry_point_name,
    RelocationModel relocation_model)
    : triple_(std::move(triple)),
      cpu_name_(std::move(cpu_name)),
      features_(std::move(features)),
      entry_point_name_(std::move(entry_point_name)),
      relocation_model_(relocation_model) {}

CpuAotCompilationOptions::~CpuAotCompilationOptions() = default;

se::Platform::Id CpuAotCompilationOptions::PlatformId() const {
  return se::host::kHostPlatformId;
}

CpuAotCompilationResult::CpuAotCompilationResult(
    ObjectFileData object_file_data, std::vector<BufferInfo> buffer_infos,
    int64 result_buffer_index,
    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data)
    : object_file_data_(std::move(object_file_data)),
      buffer_infos_(std::move(buffer_infos)),
      result_buffer_index_(result_buffer_index),
      hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {}

CpuAotCompilationResult::~CpuAotCompilationResult() = default;

CpuCompiler::CpuCompiler() {
  // Initialize LLVM the first time the CpuCompiler is initialized.
  static bool llvm_initialized = []() {
    InitializeLLVMTarget();
    return true;
  }();
  (void)llvm_initialized;
}
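
// The same once-only initialization idiom in miniature (an illustrative
// sketch, not part of this file). C++11 guarantees thread-safe, exactly-once
// initialization of function-local statics:
//
//   void EnsureSetUp() {
//     static bool initialized = [] {
//       DoExpensiveSetup();  // Hypothetical setup routine; runs exactly once.
//       return true;
//     }();
//     (void)initialized;  // Silence unused-variable warnings.
//   }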

/* static */ void CpuCompiler::InitializeLLVMTarget() {
  // Initialize LLVM's MC layer for the native target.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  LLVMInitializeX86Target();
  LLVMInitializeX86TargetInfo();
  LLVMInitializeX86TargetMC();
  LLVMInitializeX86AsmPrinter();
  LLVMInitializeX86Disassembler();
  LLVMInitializeARMTarget();
  LLVMInitializeARMTargetInfo();
  LLVMInitializeARMTargetMC();
  LLVMInitializeARMAsmPrinter();
  LLVMInitializeARMDisassembler();
  LLVMInitializeAArch64Target();
  LLVMInitializeAArch64TargetInfo();
  LLVMInitializeAArch64TargetMC();
  LLVMInitializeAArch64AsmPrinter();
  LLVMInitializeAArch64Disassembler();
}

namespace {

// LLVM makes certain options configurable only through its command-line
// interface; it provides the ParseCommandLineOptions function that lets us
// set flags at runtime. However, since these flags are global, we want to
// avoid multiple invocations of the LLVM compilation pipeline with different
// sets of flags. Therefore, we only pass command-line flags to LLVM once,
// before the first module is compiled.
std::once_flag llvm_command_line_options_initialized;
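
// A minimal sketch of how such flags would be set exactly once (illustrative
// only; the real plumbing lives in llvm_ir::InitializeLLVMCommandLineOptions,
// invoked via std::call_once in RunBackend and CompileAheadOfTime below):
//
//   std::call_once(llvm_command_line_options_initialized, [] {
//     const char* fake_argv[] = {"xla", "-print-after-all"};
//     llvm::cl::ParseCommandLineOptions(2, fake_argv);
//   });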

// This visitor records which HLO instructions should have profiling
// information recorded.
class CollectProfileCandidates : public DfsHloVisitorWithDefault {
 public:
  static StatusOr<std::unordered_map<const HloInstruction*, int64>>
  GetCandidatesForComputation(
      const HloComputation& computation,
      const std::unordered_map<const HloInstruction*, int64>&
          assigned_indices) {
    std::unordered_map<const HloInstruction*, int64> hlo_to_profile_idx;
    CollectProfileCandidates profile_candidates_for_computation(
        &hlo_to_profile_idx, assigned_indices);
    TF_RETURN_IF_ERROR(computation.Accept(&profile_candidates_for_computation));
    return hlo_to_profile_idx;
  }

 private:
  CollectProfileCandidates(
      std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx,
      const std::unordered_map<const HloInstruction*, int64>& assigned_indices)
      : hlo_to_profile_idx_(hlo_to_profile_idx),
        assigned_indices_(assigned_indices) {}

  Status DefaultAction(HloInstruction* hlo_instruction) override {
    hlo_to_profile_idx_->insert(
        {hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)});
    return Status::OK();
  }

  Status HandleCall(HloInstruction* call) override {
    TF_RETURN_IF_ERROR(DefaultAction(call));
    CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_,
                                                 assigned_indices_);
    TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call));
    return Status::OK();
  }

  // Skip constants; there is nothing to profile.
  Status HandleConstant(HloInstruction*) override { return Status::OK(); }
  // Skip parameters; they are a simple load.
  Status HandleParameter(HloInstruction*) override { return Status::OK(); }
  // It is important to recurse for "while" or else we risk overly coarse
  // profiling information.
  Status HandleWhile(HloInstruction* xla_while) override {
    TF_RETURN_IF_ERROR(DefaultAction(xla_while));

    CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_,
                                                      assigned_indices_);
    TF_RETURN_IF_ERROR(
        xla_while->while_condition()->Accept(&candidates_for_condition));

    CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_,
                                                 assigned_indices_);
    TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&candidates_for_body));

    return Status::OK();
  }

  std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx_;
  const std::unordered_map<const HloInstruction*, int64>& assigned_indices_;
};
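
// Usage sketch (illustrative; the real call site is the
// CreateHloProfilingArtifacts helper below): given an entry computation and
// the index assignment from an HloProfileIndexMap,
//
//   TF_ASSIGN_OR_RETURN(
//       auto hlo_to_profile_idx,
//       CollectProfileCandidates::GetCandidatesForComputation(
//           *module.entry_computation(),
//           profile_index_map.instruction_to_profile_idx()));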

}  // namespace

Status CpuCompiler::RunHloPassesThroughLayoutAssn(
    HloModule* module, bool /*is_aot_compile*/,
    LLVMTargetMachineFeatures* target_machine_features) {
  HloPassPipeline pipeline("HLO passes through layout assignment");
  pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                            /*allow_mixed_precision=*/false);
  pipeline.AddPass<DynamicIndexSplitter>();
  pipeline.AddPass<CpuHloSupportChecker>();

  ReducePrecisionInsertion::AddPasses(
      &pipeline, module->config().debug_options(),
      ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);

  pipeline.AddPass<MapInliner>();

  pipeline.AddPass<CholeskyExpander>();
  pipeline.AddPass<TriangularSolveExpander>();

  // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner
  // pass.
  pipeline.AddPass<CallInliner>();
  pipeline.AddPass<BatchDotSimplification>();
  pipeline.AddPass<DotDecomposer>(/*decompose_batch_dot=*/false);
  auto cost_model = [](HloInstruction* conv) {
    // We need a cost model for CPUs. Currently, do nothing.
    return false;
  };
  pipeline.AddPass<ConvolutionGroupConverter>(
      cost_model,
      /*convert_batch_groups_only=*/true);
  pipeline.AddPass<ConvolutionGroupConverter>(
      cost_model,
      /*convert_batch_groups_only=*/false);
  pipeline.AddPass<ConvCanonicalization>(target_machine_features);
  {
    auto& pass =
        pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
    pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                          /*allow_mixed_precision=*/false);

    pass.AddPass<BatchNormExpander>(
        /*rewrite_training_op=*/true,
        /*rewrite_inference_op=*/true,
        /*rewrite_grad_op=*/true);
    pipeline.AddPass<HloGetDimensionSizeRewriter>();
    AlgebraicSimplifierOptions options;
    options.set_enable_dot_strength_reduction(false);
    pass.AddPass<AlgebraicSimplifier>(options);
    pass.AddPass<SortSimplifier>();
    pass.AddPass<HloDCE>();

    // BatchNormExpander can create zero-sized ops, so zero-sized HLO
    // elimination has to come after that pass.
    pass.AddPass<ZeroSizedHloElimination>();

    pass.AddPass<WhileLoopInvariantCodeMotion>();
    pass.AddPass<TupleSimplifier>();
    pass.AddPass<WhileLoopConstantSinking>();
    pass.AddPass<WhileLoopSimplifier>();
    pass.AddPass<HloDCE>();
    pass.AddPass<ReshapeMover>();
    pass.AddPass<HloConstantFolding>();
    pass.AddPass<ConditionalSimplifier>();
  }
  pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
  pipeline.AddPass<TransposeFolding>(
      [&](const HloInstruction& dot,
          const TransposeFolding::OperandIndices& candidate_operands) {
        return DotImplementationCanHandleTranspose(dot,
                                                   *target_machine_features)
                   ? candidate_operands
                   : TransposeFolding::OperandIndices{};
      },
      TransposeFolding::NeverFoldTranspose);
  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);

  pipeline.AddPass<CpuLayoutAssignment>(
      module->mutable_entry_computation_layout(),
      LayoutAssignment::InstructionCanChangeLayout, target_machine_features);

  pipeline.AddPass<CpuInstructionFusion>();

  pipeline.AddPass<ScatterExpander>();

  ReducePrecisionInsertion::AddPasses(
      &pipeline, module->config().debug_options(),
      ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
  return pipeline.Run(module).status();
}

Status CpuCompiler::RunHloPassesAfterLayoutAssn(
    HloModule* module, bool is_aot_compile,
    LLVMTargetMachineFeatures* target_machine_features) {
  HloPassPipeline pipeline("HLO passes after layout assignment");
  // After layout assignment, use a layout-sensitive verifier.
  auto& after_layout_assn =
      pipeline.AddPass<HloPassPipeline>("after layout assignment");
  after_layout_assn.AddInvariantChecker<HloVerifier>(
      /*layout_sensitive=*/true,
      /*allow_mixed_precision=*/false);

  // The LayoutAssignment pass may leave behind kCopy instructions which are
  // duplicate or NOPs, so remove them with algebraic simplification and CSE.
  {
    auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
        "simplification after layout assignment");
    pass.AddInvariantChecker<HloVerifier>(
        /*layout_sensitive=*/true,
        /*allow_mixed_precision=*/false,
        LayoutAssignment::InstructionCanChangeLayout);
    AlgebraicSimplifierOptions options;
    options.set_is_layout_sensitive(true);
    options.set_enable_dot_strength_reduction(false);
    pass.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
    pass.AddPass<HloDCE>();
    pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
  }

  pipeline.AddPass<HloElementTypeConverter>(BF16, F32);

  // Outline ops in the entry computation into calls to subcomputations.
  const int max_parallelism =
      module->config().intra_op_parallelism_threads() > 0
          ? module->config().intra_op_parallelism_threads()
          : tensorflow::port::NumSchedulableCPUs();
  if (!is_aot_compile) {
    // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module.
    // Note this is not run for AOT because it would bring in thread pool
    // and thread synchronization dependencies which would likely increase
    // binary size (and most AOT applications are single-threaded).
    // TODO(b/29630486) Support multi-threaded AOT.
    pipeline.AddPass<ParallelTaskAssigner>(
        max_parallelism, ShapeSizeBytesFunction(), target_machine_features);
  }
  // Copy insertion should be performed immediately before IR emission to
  // avoid inserting unnecessary copies (later pass adds an instruction which
  // materializes the value) or missing a necessary copy (later pass removes
  // an instruction which materializes a value). DCE must be run immediately
  // before (and sometime after) copy insertion, to avoid dead code from
  // interfering with the rewrites.
  pipeline.AddPass<HloDCE>();
  pipeline.AddPass<FlattenCallGraph>();
  pipeline.AddPass<CpuCopyInsertion>();
  pipeline.AddPass<HloDCE>();
  return pipeline.Run(module).status();
}

Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
                                 llvm::TargetMachine* target_machine) {
  LLVMTargetMachineFeatures target_machine_features(target_machine);
  TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile,
                                                   &target_machine_features));
  return RunHloPassesAfterLayoutAssn(module, is_aot_compile,
                                     &target_machine_features);
}

namespace {

// Align buffers to 16-byte boundaries.
constexpr int64 kMemoryAlignment = 16;
auto memory_alignment = [](LogicalBuffer::Color) { return kMemoryAlignment; };

llvm::TargetOptions CompilerTargetOptions(
    const HloModuleConfig& module_config) {
  llvm::TargetOptions target_options;
  // In LLVM backend flags, UnsafeFPMath does not explicitly imply NoInfs, etc.
  if (module_config.debug_options().xla_cpu_enable_fast_math()) {
    target_options.UnsafeFPMath = true;
    target_options.NoInfsFPMath =
        !module_config.debug_options().xla_cpu_fast_math_honor_infs();
    target_options.NoNaNsFPMath =
        !module_config.debug_options().xla_cpu_fast_math_honor_nans();
    target_options.NoSignedZerosFPMath = true;
  } else {
    target_options.UnsafeFPMath = false;
    target_options.NoInfsFPMath = false;
    target_options.NoNaNsFPMath = false;
    target_options.NoSignedZerosFPMath = false;
  }
  return target_options;
}
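
// Worked example of the mapping above: with --xla_cpu_enable_fast_math=true
// and --xla_cpu_fast_math_honor_nans=true, we get UnsafeFPMath=1 and
// NoNaNsFPMath=0, i.e. LLVM may reassociate FP operations but must still
// preserve NaN semantics. The "honor" flags are negated because LLVM's flags
// assert the *absence* of infs/NaNs.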

llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) {
  VLOG(2) << "backend_optimization_level: "
          << module_config.debug_options().xla_backend_optimization_level();
  switch (module_config.debug_options().xla_backend_optimization_level()) {
    case 1:
      return llvm::CodeGenOpt::Less;
    case 2:
      return llvm::CodeGenOpt::Default;
    case 3:
      return llvm::CodeGenOpt::Aggressive;
    default:
      return llvm::CodeGenOpt::None;
  }
}
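
// For example, --xla_backend_optimization_level=2 selects
// llvm::CodeGenOpt::Default; any value other than 1, 2, or 3 (including 0)
// falls through to llvm::CodeGenOpt::None.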

std::pair<LLVMCompiler::ModuleHook, LLVMCompiler::ModuleHook> GetIRModuleHooks(
    const HloModule& hlo_module,
    const LLVMCompiler::ModuleHook& user_pre_optimization_hook,
    const LLVMCompiler::ModuleHook& user_post_optimization_hook) {
  // Create the IR hooks. If applicable, each IR hook does the following:
  //
  //  * Calls the user supplied module hook.
  //  * Writes out the IR to a file in the output directory designated by
  //    --xla_dump_to
  const HloModule* hlo_module_ptr = &hlo_module;
  auto hook = [user_pre_optimization_hook, user_post_optimization_hook,
               hlo_module_ptr](bool optimized,
                               const llvm::Module& llvm_module) {
    const auto& user_hook =
        !optimized ? user_pre_optimization_hook : user_post_optimization_hook;
    if (user_hook) {
      user_hook(llvm_module);
    }
    llvm_ir::DumpIrIfEnabled(*hlo_module_ptr, llvm_module, optimized);
  };
  return {[hook](const llvm::Module& llvm_module) {
            return hook(/*optimized=*/false, llvm_module);
          },
          [hook](const llvm::Module& llvm_module) {
            return hook(/*optimized=*/true, llvm_module);
          }};
}

Status VerifyLlvmModule(const llvm::Module& llvm_module) {
  XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier");

  std::string err;
  llvm::raw_string_ostream err_stream(err);

  // verifyModule() returns true if the module is broken.
  TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream))
      << "Invalid LLVM IR before optimizations:\n"
      << err_stream.str()
      << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
         "Rerun with --xla_dump_to to get the IR.";
  return Status::OK();
}

Status CreateHloProfilingArtifacts(
    const HloModule& module,
    std::unordered_map<const HloInstruction*, int64>*
        instruction_to_profile_idx,
    std::unordered_map<const HloComputation*, int64>*
        computation_to_profile_idx,
    std::unique_ptr<HloProfileIndexMap>* hlo_profile_index_map,
    std::unique_ptr<HloProfilePrinterData>* hlo_profile_printer_data) {
  *hlo_profile_index_map = absl::make_unique<HloProfileIndexMap>(module);
  const HloComputation& entry_computation = *module.entry_computation();

  TF_ASSIGN_OR_RETURN(
      *instruction_to_profile_idx,
      CollectProfileCandidates::GetCandidatesForComputation(
          entry_computation,
          (*hlo_profile_index_map)->instruction_to_profile_idx()));

  auto shape_size_bytes = [](const Shape& shape) {
    // On the cpu, opaques are pointers.
    if (shape.IsOpaque()) {
      return static_cast<int64>(sizeof(void*));
    }
    return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
  };

  HloCostAnalysis cost_analysis(shape_size_bytes);
  TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis));
  *hlo_profile_printer_data = CreateHloProfilePrinterData(
      **hlo_profile_index_map, cost_analysis, entry_computation.name());
  *computation_to_profile_idx =
      (*hlo_profile_index_map)->computation_to_profile_idx();

  return Status::OK();
}

}  // namespace

StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
    std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/,
    DeviceMemoryAllocator* /*device_allocator*/) {
  std::unique_ptr<llvm::TargetMachine> jit_target_machine =
      SimpleOrcJIT::InferTargetMachineForJIT(
          CompilerTargetOptions(module->config()),
          CodeGenOptLevel(module->config()));

  TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false,
                                  jit_target_machine.get()));
  return std::move(module);
}

namespace {

// Post-compilation callback functor for use by SimpleOrcJIT.
//
// Dumps disassembled machine code if dumping is enabled for the module.
struct OrcJITPostCompilationHook {
  // Gets an std::function that implements this hook.
  static std::function<void(const llvm::object::ObjectFile& obj_file)> Create(
      const HloModule* module) {
    // This struct is not copyable, but std::functions must be.  So to create
    // an std::function out of this struct, we have to wrap it in a shared_ptr.
    auto wrapped = std::make_shared<OrcJITPostCompilationHook>(module);
    return [wrapped](const llvm::object::ObjectFile& obj_file) {
      (*wrapped)(obj_file);
    };
  }
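
  // The wrap-in-shared_ptr pattern above, in miniature (an illustrative
  // sketch, not part of this file):
  //
  //   struct MoveOnly {
  //     std::unique_ptr<int> state;
  //     void operator()() const { /* use *state */ }
  //   };
  //   auto wrapped = std::make_shared<MoveOnly>();
  //   std::function<void()> f = [wrapped] { (*wrapped)(); };
  //
  // Copying `f` copies only the shared_ptr, so the non-copyable functor can
  // ride along inside a copyable std::function.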

  // Constructor can't be private because we want to call it from
  // std::make_shared, but users should call Create() instead.
  explicit OrcJITPostCompilationHook(const HloModule* module)
      : module(module),
        target_machine(SimpleOrcJIT::InferTargetMachineForJIT(
            CompilerTargetOptions(module->config()),
            CodeGenOptLevel(module->config()))),
        disassembler(*target_machine) {}

 private:
  void operator()(const llvm::object::ObjectFile& obj_file) {
    if (!DumpingEnabledForHloModule(*module)) {
      return;
    }
    StatusOr<DisassemblerResult> disasm_or =
        disassembler.DisassembleObjectFile(obj_file);
    string text = disasm_or.ok() ? std::move(disasm_or).ValueOrDie().text
                                 : absl::StrCat("Error disassembling: ",
                                                disasm_or.status().ToString());
    DumpToFileInDirOrStdout(*module, /*file_suffix=*/"s", text);
  }

  const HloModule* module;
  // disassembler keeps references to data inside of target_machine.
  std::unique_ptr<llvm::TargetMachine> target_machine;
  Disassembler disassembler;
};

}  // namespace

StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    DeviceMemoryAllocator* /*device_allocator*/) {
  VLOG(1) << "Compiling: " << module->name();
  XLA_SCOPED_LOGGING_TIMER(
      absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));

  TF_RET_CHECK(stream_exec != nullptr);
  std::call_once(llvm_command_line_options_initialized,
                 &llvm_ir::InitializeLLVMCommandLineOptions, module->config());

  ModuleHook pre_optimization_ir_hook;
  ModuleHook post_optimization_ir_hook;
  std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
      GetIRModuleHooks(*module, user_pre_optimization_hook_,
                       user_post_optimization_hook_);

  // Compile must be thread-safe so create a new LLVM context for the module.
  auto llvm_context = absl::make_unique<llvm::LLVMContext>();
  auto llvm_module =
      absl::make_unique<llvm::Module>("__compute_module", *llvm_context);

  auto jit = absl::make_unique<SimpleOrcJIT>(
      CompilerTargetOptions(module->config()),
      CodeGenOptLevel(module->config()),
      options::OptimizeForSizeRequested(module->config()),
      module->config().debug_options().xla_cpu_enable_fast_math(),
      module->config().debug_options().xla_llvm_disable_expensive_passes(),
      pre_optimization_ir_hook, post_optimization_ir_hook,
      OrcJITPostCompilationHook::Create(module.get()));
  llvm_module->setDataLayout(jit->data_layout());
  llvm_module->setTargetTriple(jit->target_triple().getTriple());

  HloComputation* entry_computation = module->entry_computation();
  std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
  std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
  std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
  std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
  if (module->config().hlo_profiling_enabled()) {
    TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
        *module, &instruction_to_profile_idx, &computation_to_profile_idx,
        &hlo_profile_index_map, &hlo_profile_printer_data));
  }

  std::unique_ptr<Executable> cpu_executable;

  // Cache these flags here since we'll want to access them after the module's
  // ownership is std::moved.
  const bool embed_ir_in_executable =
      module->config().debug_options().xla_embed_ir_in_executable();

  // Select an order for emitting the HLO instructions for each
  // computation. Using this sequence enables tighter buffer liveness analysis
  // and reduced memory usage (as compared to using DependencyHloOrdering).
  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      ScheduleModule(module.get(), BufferSizeBytesFunction(),
                                     DFSMemoryScheduler));

  // Run buffer allocation on the HLO graph.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<BufferAssignment> assignment,
      BufferAssigner::Run(module.get(),
                          absl::make_unique<SequentialHloOrdering>(schedule),
                          BufferSizeBytesFunction(), memory_alignment,
                          /*allow_input_output_aliasing=*/false,
                          /*allocate_buffers_for_constants=*/true));
  // BufferAssignment::ToString() includes a header, so no need for us to
  // print one ourselves.
  if (DumpingEnabledForHloModule(*module)) {
    DumpToFileInDirOrStdout(*module, "buffer_assignment",
                            assignment->ToString());
  }
  DumpHloModuleIfEnabled(*module, *assignment, "after_optimizations");

  // Each computation is a single function.  Emit all embedded computations
  // before the entry computation. The order of computations returned from
  // GetEmbeddedComputations guarantees that a called computation occurs
  // before a caller computation.

  LLVMTargetMachineFeatures target_machine_features(jit->target_machine());
  IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
                       std::move(instruction_to_profile_idx),
                       std::move(computation_to_profile_idx),
                       &target_machine_features,
#ifdef MEMORY_SANITIZER
                       /*emit_code_for_msan=*/true
#else
                       /*emit_code_for_msan=*/false
#endif
  );

  TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());

  for (auto embedded_computation :
       entry_computation->MakeEmbeddedComputationsList()) {
    if (embedded_computation->IsFusionComputation()) {
      continue;
    }
    TF_RETURN_IF_ERROR(
        ir_emitter
            .EmitComputation(
                embedded_computation, embedded_computation->name(),
                /*is_top_level_computation=*/false,
                schedule.sequence(embedded_computation).instructions())
            .status());
  }
  string function_name_prefix = entry_computation->name().empty()
                                    ? "__compute"
                                    : entry_computation->name();
  TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
                      ir_emitter.EmitComputation(
                          entry_computation, function_name_prefix,
                          /*is_top_level_computation=*/true,
                          schedule.sequence(entry_computation).instructions()));

  string function_name = [&]() {
    llvm::SmallVector<char, 40> function_name_vector;
    llvm::Mangler::getNameWithPrefix(
        function_name_vector, entry_function->getName(), jit->data_layout());
    return string(function_name_vector.begin(), function_name_vector.end());
  }();

  string ir_module_string;
  if (embed_ir_in_executable) {
    ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
  }

  TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));

  // JIT compile the LLVM IR module to in-memory machine code.
  jit->AddModule(std::move(llvm_module));
  cpu_executable.reset(new CpuExecutable(
      std::move(jit), std::move(assignment), std::move(module), function_name,
      std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));

  if (embed_ir_in_executable) {
    static_cast<CpuExecutable&>(*cpu_executable)
        .set_ir_module_string(ir_module_string);
  }

  VLOG(1) << "Compilation finished";
  return std::move(cpu_executable);
}

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                                const AotCompilationOptions& aot_options) {
  TF_RET_CHECK(!module_group->empty());
  std::vector<std::unique_ptr<HloModule>> modules =
      module_group->ConsumeModules();

  std::call_once(llvm_command_line_options_initialized,
                 &llvm_ir::InitializeLLVMCommandLineOptions,
                 modules[0]->config());

  // We can pass just one llvm::TargetOptions when we compile the LLVM module,
  // so we bail if the configs have conflicting flags. At the moment, the only
  // flag that needs to be consistent is fast-math.
  const bool fast_math_enabled =
      modules[0]->config().debug_options().xla_cpu_enable_fast_math();
  for (const auto& module : modules) {
    if (module->config().debug_options().xla_cpu_enable_fast_math() !=
        fast_math_enabled) {
      return InvalidArgument(
          "All HLO module configs must have the same value for "
          "xla_cpu_enable_fast_math.");
    }
  }

  if (aot_options.PlatformId() != se::host::kHostPlatformId) {
    return InvalidArgument("Incompatible AOT compilation platform");
  }
  const CpuAotCompilationOptions& options =
      static_cast<const CpuAotCompilationOptions&>(aot_options);
  llvm::Triple triple(llvm::Triple::normalize(options.triple()));
  std::string error;
  const llvm::Target* target =
      llvm::TargetRegistry::lookupTarget(triple.getTriple(), error);
  if (target == nullptr) {
    return InternalError("TargetRegistry::lookupTarget failed: %s", error);
  }

  llvm::Reloc::Model reloc_model = llvm::Reloc::Static;
  llvm::PICLevel::Level pic_level = llvm::PICLevel::NotPIC;
  llvm::PIELevel::Level pie_level = llvm::PIELevel::Default;
  switch (options.relocation_model()) {
    case CpuAotCompilationOptions::RelocationModel::Static:
      reloc_model = llvm::Reloc::Static;
      pic_level = llvm::PICLevel::NotPIC;
      pie_level = llvm::PIELevel::Default;
      break;
    case CpuAotCompilationOptions::RelocationModel::SmallPic:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::SmallPIC;
      pie_level = llvm::PIELevel::Default;
      break;
    case CpuAotCompilationOptions::RelocationModel::BigPic:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::BigPIC;
      pie_level = llvm::PIELevel::Default;
      break;
    case CpuAotCompilationOptions::RelocationModel::SmallPie:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::SmallPIC;
      pie_level = llvm::PIELevel::Small;
      break;
    case CpuAotCompilationOptions::RelocationModel::BigPie:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::BigPIC;
      pie_level = llvm::PIELevel::Large;
      break;
  }
  llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config());
  std::unique_ptr<llvm::TargetMachine> target_machine =
      absl::WrapUnique(target->createTargetMachine(
          triple.getTriple(), options.cpu_name(), options.features(),
          CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::None,
          opt_level));

  // Compile must be thread-safe so create a new LLVM context for the module.
  llvm::LLVMContext llvm_context;
  llvm::Module llvm_module("__compute_module", llvm_context);
  llvm_module.setDataLayout(target_machine->createDataLayout());
  llvm_module.setTargetTriple(triple.getTriple());
  if (pic_level != llvm::PICLevel::NotPIC) {
    llvm_module.setPICLevel(pic_level);
  }
  if (pie_level != llvm::PIELevel::Default) {
    llvm_module.setPIELevel(pie_level);
  }

  std::vector<std::unique_ptr<AotCompilationResult>> results;
  for (size_t i = 0; i < modules.size(); ++i) {
    HloModule* module = modules[i].get();
    VLOG(1) << "Compiling ahead-of-time: " << module->name();

    TF_RETURN_IF_ERROR(
        RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get()));

    TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                        ScheduleModule(module, BufferSizeBytesFunction()));

    // Run buffer analysis on the HLO graph. This analysis figures out which
    // temporary buffers are required to run the computation.
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<BufferAssignment> assignment,
        BufferAssigner::Run(module,
                            absl::make_unique<SequentialHloOrdering>(schedule),
                            BufferSizeBytesFunction(), memory_alignment,
                            /*allow_input_output_aliasing=*/false,
                            /*allocate_buffers_for_constants=*/true));
    // BufferAssignment::ToString() includes a header, so no need for us to
    // print one ourselves.
    if (DumpingEnabledForHloModule(*module)) {
      DumpToFileInDirOrStdout(*module, "buffer_assignment",
                              assignment->ToString());
    }
    DumpHloModuleIfEnabled(*module, *assignment, "after_optimizations");

    std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
    std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
    std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;

    if (module->config().hlo_profiling_enabled()) {
      TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
          *module, &instruction_to_profile_idx, &computation_to_profile_idx,
          &hlo_profile_index_map, &hlo_profile_printer_data));
    }

    LLVMTargetMachineFeatures target_machine_features(target_machine.get());
    IrEmitter ir_emitter(*module, *assignment, &llvm_module,
                         std::move(instruction_to_profile_idx),
                         std::move(computation_to_profile_idx),
                         &target_machine_features,
                         // TODO(b/66051036): Run full msan for AOT.
                         /*emit_code_for_msan=*/false);

    TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());

    HloComputation* computation = module->entry_computation();
    for (auto embedded_computation :
         computation->MakeEmbeddedComputationsList()) {
      if (embedded_computation->IsFusionComputation()) {
        continue;
      }
      TF_RETURN_IF_ERROR(
          ir_emitter
              .EmitComputation(
                  embedded_computation, embedded_computation->name(),
                  /*is_top_level_computation=*/false,
                  schedule.sequence(embedded_computation).instructions())
              .status());
    }
    const string& entry_point_name = options.entry_point_name();
    TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
                        ir_emitter.EmitComputation(
                            computation, entry_point_name,
                            /*is_top_level_computation=*/true,
                            schedule.sequence(computation).instructions()));

    CHECK(entry_function->getName() == entry_point_name);

    ModuleHook pre_optimization_ir_hook;
    ModuleHook post_optimization_ir_hook;
    std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
        GetIRModuleHooks(*module, user_pre_optimization_hook_,
                         user_post_optimization_hook_);

    // Run the LLVM verifier over the unoptimized LLVM IR.  If it fails, run
    // the pre-optimization IR dump hook before returning.
    {
      Status verify_status = VerifyLlvmModule(llvm_module);
      if (!verify_status.ok() && pre_optimization_ir_hook) {
        pre_optimization_ir_hook(llvm_module);
      }
      TF_RETURN_IF_ERROR(verify_status);
    }

    auto post_codegen_hook = [&](const llvm::object::ObjectFile& obj_file) {
      if (!DumpingEnabledForHloModule(*module)) {
        return;
      }
      StatusOr<DisassemblerResult> disasm_or =
          Disassembler(*target_machine).DisassembleObjectFile(obj_file);
      string text = disasm_or.ok()
                        ? std::move(disasm_or).ValueOrDie().text
                        : absl::StrCat("Error disassembling: ",
                                       disasm_or.status().ToString());
      DumpToFileInDirOrStdout(*module, /*file_suffix=*/"s", text);
    };

    CompilerFunctor compiler_functor(
        target_machine.get(), opt_level,
        options::OptimizeForSizeRequested(module->config()),
        module->config().debug_options().xla_cpu_enable_fast_math(),
        module->config().debug_options().xla_llvm_disable_expensive_passes(),
        pre_optimization_ir_hook, post_optimization_ir_hook, post_codegen_hook);
    std::unique_ptr<llvm::MemoryBuffer> object_file =
        compiler_functor(llvm_module);
    ObjectFileData object_file_data(object_file->getBufferStart(),
                                    object_file->getBufferEnd());

    std::vector<BufferInfo> buffer_infos =
        CreateBufferInfosFromBufferAssignment(*assignment);

    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
                        assignment->GetUniqueTopLevelOutputSlice());

    results.emplace_back(absl::make_unique<CpuAotCompilationResult>(
        std::move(object_file_data), std::move(buffer_infos),
        result_slice.index(), std::move(hlo_profile_printer_data)));
  }

  VLOG(1) << "Compilation finished";
  return std::move(results);
}

se::Platform::Id CpuCompiler::PlatformId() const {
  return se::host::kHostPlatformId;
}

HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const {
  return CpuExecutable::ShapeSizeBytes;
}

}  // namespace cpu
}  // namespace xla

static bool InitModule() {
  xla::Compiler::RegisterCompilerFactory(
      stream_executor::host::kHostPlatformId,
      []() { return absl::make_unique<xla::cpu::CpuCompiler>(); });
  return true;
}
static bool module_initialized = InitModule();
960