• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
17 
18 #include <stddef.h>
19 #include <string.h>
20 
21 #include <map>
22 #include <memory>
23 #include <string>
24 #include <unordered_map>
25 #include <utility>
26 #include <vector>
27 
28 // IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc"
29 // IWYU pragma: no_include "llvm/Config/Targets.def.inc"
30 #include "absl/base/call_once.h"
31 #include "absl/memory/memory.h"
32 #include "absl/strings/str_cat.h"
33 #include "llvm/ADT/StringRef.h"
34 #include "llvm/ADT/Triple.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/LLVMContext.h"
37 #include "llvm/IR/Mangler.h"
38 #include "llvm/IR/Module.h"
39 #include "llvm/IR/Verifier.h"
40 #include "llvm/Object/ObjectFile.h"
41 #include "llvm/Support/CommandLine.h"
42 #include "llvm/Support/Error.h"
43 #include "llvm/Support/TargetRegistry.h"
44 #include "llvm/Support/TargetSelect.h"
45 #include "llvm/Target/TargetMachine.h"
46 #include "llvm/Target/TargetOptions.h"
47 #include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
48 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
49 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"  // from @llvm-project
50 #include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
51 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
52 #include "mlir/Dialect/Vector/VectorOps.h"  // from @llvm-project
53 #include "mlir/InitAllDialects.h"  // from @llvm-project
54 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
55 #include "tensorflow/compiler/xla/literal.h"
56 #include "tensorflow/compiler/xla/map_util.h"
57 #include "tensorflow/compiler/xla/protobuf_util.h"
58 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
59 #include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
60 #include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
61 #include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
62 #include "tensorflow/compiler/xla/service/batchnorm_expander.h"
63 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
64 #include "tensorflow/compiler/xla/service/call_inliner.h"
65 #include "tensorflow/compiler/xla/service/cholesky_expander.h"
66 #include "tensorflow/compiler/xla/service/comparison_expander.h"
67 #include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
68 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
69 #include "tensorflow/compiler/xla/service/conditional_to_select.h"
70 #include "tensorflow/compiler/xla/service/convolution_group_converter.h"
71 #include "tensorflow/compiler/xla/service/copy_insertion.h"
72 #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
73 #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
74 #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
75 #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
76 #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
77 #include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"
78 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
79 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
80 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
81 #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
82 #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
83 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
84 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
85 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
86 #include "tensorflow/compiler/xla/service/dump.h"
87 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
88 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
89 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
90 #include "tensorflow/compiler/xla/service/gather_expander.h"
91 #include "tensorflow/compiler/xla/service/hlo.pb.h"
92 #include "tensorflow/compiler/xla/service/hlo_computation.h"
93 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
94 #include "tensorflow/compiler/xla/service/hlo_cse.h"
95 #include "tensorflow/compiler/xla/service/hlo_dce.h"
96 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
97 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
98 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
99 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
100 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
101 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
102 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
103 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
104 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
105 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
106 #include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
107 #include "tensorflow/compiler/xla/service/llvm_compiler.h"
108 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
109 #include "tensorflow/compiler/xla/service/logistic_expander.h"
110 #include "tensorflow/compiler/xla/service/map_inliner.h"
111 #include "tensorflow/compiler/xla/service/operand_upcaster.h"
112 #include "tensorflow/compiler/xla/service/qr_expander.h"
113 #include "tensorflow/compiler/xla/service/reshape_mover.h"
114 #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
115 #include "tensorflow/compiler/xla/service/rng_expander.h"
116 #include "tensorflow/compiler/xla/service/scatter_expander.h"
117 #include "tensorflow/compiler/xla/service/slice_sinker.h"
118 #include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
119 #include "tensorflow/compiler/xla/service/sort_simplifier.h"
120 #include "tensorflow/compiler/xla/service/topk_rewriter.h"
121 #include "tensorflow/compiler/xla/service/transpose_folding.h"
122 #include "tensorflow/compiler/xla/service/tree_reduction_rewriter.h"
123 #include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
124 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
125 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
126 #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
127 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
128 #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
129 #include "tensorflow/compiler/xla/status_macros.h"
130 #include "tensorflow/compiler/xla/statusor.h"
131 #include "tensorflow/compiler/xla/types.h"
132 #include "tensorflow/compiler/xla/util.h"
133 #include "tensorflow/compiler/xla/xla_data.pb.h"
134 #include "tensorflow/core/platform/dynamic_annotations.h"
135 
136 namespace {
137 
138 // We need to explicitly load all the dialects we will involved in emitting the
139 // IR. This is only needed because of how MLIR is bolted into XLA and does not
140 // make use of the MLIR infrastructure (like using a proper pass pipeline).
141 // Hopefully this will all go away at some point in favor of a better
142 // integration.
LoadMLIRDialects(mlir::MLIRContext & context)143 void LoadMLIRDialects(mlir::MLIRContext& context) {
144   context.loadDialect<mlir::linalg::LinalgDialect, mlir::scf::SCFDialect,
145                       mlir::vector::VectorDialect, mlir::StandardOpsDialect,
146                       mlir::AffineDialect>();
147 }
148 
149 }  // namespace
150 
151 namespace xla {
152 namespace cpu {
153 using BufferInfo = cpu_function_runtime::BufferInfo;
154 
// Constructs AOT compilation options. All string parameters are accepted by
// value and moved into the members to avoid extra copies.
CpuAotCompilationOptions::CpuAotCompilationOptions(
    string triple, string cpu_name, string features, string entry_point_name,
    RelocationModel relocation_model)
    : triple_(std::move(triple)),
      cpu_name_(std::move(cpu_name)),
      features_(std::move(features)),
      entry_point_name_(std::move(entry_point_name)),
      relocation_model_(relocation_model) {}
163 
164 CpuAotCompilationOptions::~CpuAotCompilationOptions() = default;
165 
// AOT results produced by this compiler always target the host (CPU)
// platform.
se::Platform::Id CpuAotCompilationOptions::PlatformId() const {
  return se::host::kHostPlatformId;
}
169 
// Bundles the outputs of an AOT compilation: the generated object file, the
// buffer layout metadata, the index of the result buffer, and (optionally)
// profiling printer data. All heavyweight members are moved in.
CpuAotCompilationResult::CpuAotCompilationResult(
    ObjectFileData object_file_data, std::vector<BufferInfo> buffer_infos,
    int64 result_buffer_index,
    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data)
    : object_file_data_(std::move(object_file_data)),
      buffer_infos_(std::move(buffer_infos)),
      result_buffer_index_(result_buffer_index),
      hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {}
178 
179 CpuAotCompilationResult::~CpuAotCompilationResult() = default;
180 
CpuCompiler()181 CpuCompiler::CpuCompiler() {
182   // Initialize LLVM the first time the CpuCompiler is initialized.
183   static bool llvm_initialized = []() {
184     InitializeLLVMTarget();
185     return true;
186   }();
187   (void)llvm_initialized;
188 }
189 
Compile(std::unique_ptr<HloModuleGroup> module_group,std::vector<std::vector<se::StreamExecutor * >> stream_execs,const CompileOptions & options)190 StatusOr<std::vector<std::unique_ptr<Executable>>> CpuCompiler::Compile(
191     std::unique_ptr<HloModuleGroup> module_group,
192     std::vector<std::vector<se::StreamExecutor*>> stream_execs,
193     const CompileOptions& options) {
194   for (const std::vector<se::StreamExecutor*>& se_vector : stream_execs) {
195     if (se_vector.size() != 1) {
196       return Unimplemented(
197           "Model partitioning not implemented for the CPU compiler");
198     }
199   }
200   return LLVMCompiler::Compile(std::move(module_group), stream_execs, options);
201 }
202 
// One-time LLVM setup: registers the native target and its assembly printer
// so the JIT/AOT backends can emit machine code for the host.
/* static */ void CpuCompiler::InitializeLLVMTarget() {
  // Initialize LLVM's MC layer for the native target.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
}
208 
209 namespace {
210 
// LLVM makes certain options configurable only through its command-line
// options; it provides the ParseCommandLineOptions function that lets us set
// flags at runtime. However, since these flags are global we want to avoid
// multiple invocations of the LLVM compilation pipeline with a different set
// of flags. Therefore, we only pass command-line flags to LLVM once, before
// the first module is compiled.
absl::once_flag llvm_command_line_options_initialized;
218 
219 // This visitor records which HLO instructions should have profiling information
220 // recorded.
221 class CollectProfileCandidates : public DfsHloVisitorWithDefault {
222  public:
223   static StatusOr<std::unordered_map<const HloInstruction*, int64>>
GetCandidatesForComputation(const HloComputation & computation,const std::unordered_map<const HloInstruction *,int64> & assigned_indices)224   GetCandidatesForComputation(
225       const HloComputation& computation,
226       const std::unordered_map<const HloInstruction*, int64>&
227           assigned_indices) {
228     std::unordered_map<const HloInstruction*, int64> hlo_to_profile_idx;
229     CollectProfileCandidates profile_candidates_for_computation(
230         &hlo_to_profile_idx, assigned_indices);
231     TF_RETURN_IF_ERROR(computation.Accept(&profile_candidates_for_computation));
232     return hlo_to_profile_idx;
233   }
234 
235  private:
CollectProfileCandidates(std::unordered_map<const HloInstruction *,int64> * hlo_to_profile_idx,const std::unordered_map<const HloInstruction *,int64> & assigned_indices)236   CollectProfileCandidates(
237       std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx,
238       const std::unordered_map<const HloInstruction*, int64>& assigned_indices)
239       : hlo_to_profile_idx_(hlo_to_profile_idx),
240         assigned_indices_(assigned_indices) {}
241 
DefaultAction(HloInstruction * hlo_instruction)242   Status DefaultAction(HloInstruction* hlo_instruction) override {
243     hlo_to_profile_idx_->insert(
244         {hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)});
245     return Status::OK();
246   }
247 
HandleCall(HloInstruction * call)248   Status HandleCall(HloInstruction* call) override {
249     TF_RETURN_IF_ERROR(DefaultAction(call));
250     CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_,
251                                                  assigned_indices_);
252     TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call));
253     return Status::OK();
254   }
255 
256   // Skip constants, there is nothing to profile.
HandleConstant(HloInstruction *)257   Status HandleConstant(HloInstruction*) override { return Status::OK(); }
258   // Skip parameters, they are a simple load.
HandleParameter(HloInstruction *)259   Status HandleParameter(HloInstruction*) override { return Status::OK(); }
260   // It is important to recurse for "while" or else we risk overly coarse
261   // profiling information.
HandleWhile(HloInstruction * xla_while)262   Status HandleWhile(HloInstruction* xla_while) override {
263     TF_RETURN_IF_ERROR(DefaultAction(xla_while));
264 
265     CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_,
266                                                       assigned_indices_);
267     TF_RETURN_IF_ERROR(
268         xla_while->while_condition()->Accept(&candidates_for_condition));
269 
270     CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_,
271                                                  assigned_indices_);
272     TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&candidates_for_body));
273 
274     return Status::OK();
275   }
276 
277   std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx_;
278   const std::unordered_map<const HloInstruction*, int64>& assigned_indices_;
279 };
280 
281 }  // namespace
282 
// Runs the layout-insensitive part of the HLO pipeline: expansion of
// composite ops, simplification to a fixed point, then layout assignment and
// instruction fusion. The layout-sensitive passes run afterwards in
// RunHloPassesAfterLayoutAssn. NOTE(review): pass ordering here is
// significant; several comments below record the known ordering constraints.
Status CpuCompiler::RunHloPassesThroughLayoutAssn(
    HloModule* module, bool /*is_aot_compile*/,
    LLVMTargetMachineFeatures* target_machine_features) {
  HloPassPipeline pipeline("HLO passes through layout assignment");
  // Layouts are not assigned yet, so verify without layout sensitivity.
  pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                            /*allow_mixed_precision=*/false);

  pipeline.AddPass<OperandUpcaster>();

  // Expand random number generation.
  pipeline.AddPass<RngExpander>();
  pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);

  // Remove zero-sized HLO from the input so that other passes don't have to
  // handle it.
  pipeline.AddPass<ZeroSizedHloElimination>();

  pipeline.AddPass<DynamicIndexSplitter>();

  pipeline.AddPass<ConditionalToSelect>();
  pipeline.AddPass<MapInliner>();

  // Expander passes lower ops the CPU backend has no direct emitter for.
  pipeline.AddPass<ComparisonExpander>();
  pipeline.AddPass<CholeskyExpander>();
  pipeline.AddPass<QrExpander>();
  pipeline.AddPass<TriangularSolveExpander>();
  pipeline.AddPass<AllGatherDecomposer>();
  pipeline.AddPass<AllToAllDecomposer>();

  // Inline computations with a single call site.
  pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
  pipeline.AddPass<BatchDotSimplification>();
  pipeline.AddPass<DotDecomposer>();
  // Convert BF16 operations to F32 operations so that the CPU backend can
  // support BF16 operations without directly implementing a BF16 lowering for
  // most ops.
  pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
  // After canonicalization, there may be more batch dots that can be
  // simplified.
  pipeline.AddPass<BatchDotSimplification>();
  auto cost_model = [](HloInstruction* conv) {
    // We need a cost model for CPUs. Currently, do nothing.
    return false;
  };
  // Run the group converter twice: once for batch groups, once for feature
  // groups.
  pipeline.AddPass<ConvolutionGroupConverter>(
      cost_model,
      /*convert_batch_groups_only=*/true);
  pipeline.AddPass<ConvolutionGroupConverter>(
      cost_model,
      /*convert_batch_groups_only=*/false);
  pipeline.AddPass<BatchNormExpander>(
      /*rewrite_training_op=*/true,
      /*rewrite_inference_op=*/true,
      /*rewrite_grad_op=*/true);
  pipeline.AddPass<LogisticExpander>(
      /*expansion_type=*/LogisticExpansionType::kExp);
  pipeline.AddPass<ConditionalCanonicalizer>();
  pipeline.AddPass<DynamicPadder>();
  pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
  pipeline.AddPass<ConvCanonicalization>(target_machine_features);
  {
    // Simplification passes run to a fixed point (HloPassFix).
    auto& pass =
        pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
    pass.AddInvariantCheckerDebug<HloVerifier>(/*layout_sensitive=*/false,
                                               /*allow_mixed_precision=*/false);

    pass.AddPass<TreeReductionRewriter>();
    AlgebraicSimplifierOptions options;
    options.set_enable_dot_strength_reduction(false);
    pass.AddPass<AlgebraicSimplifier>(options);
    pass.AddPass<SortSimplifier>();
    pass.AddPass<HloDCE>();
    pass.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);

    // BatchNormExpander can create zero-sized ops, so zero-sized HLO
    // elimination has to come after that pass.
    pass.AddPass<ZeroSizedHloElimination>();

    pass.AddPass<WhileLoopInvariantCodeMotion>();
    pass.AddPass<TupleSimplifier>();
    pass.AddPass<WhileLoopConstantSinking>();
    pass.AddPass<WhileLoopSimplifier>();

    // TODO(b/134075051): Re-enable after b/134075051 is fixed.
    // pass.AddPass<SliceSinker>();

    pass.AddPass<HloDCE>();
    pass.AddPass<ReshapeMover>();
    pass.AddPass<HloConstantFolding>();
    pass.AddPass<ConditionalSimplifier>();
  }
  // Only rewrite sorts over F32 data into TopK.
  pipeline.AddPass<TopkRewriter>([](const HloSortInstruction* sort, int64) {
    return sort->operand(0)->shape().element_type() == F32;
  });
  pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
  // Fold a transpose into a dot only when the dot emitter can handle the
  // transposed operand directly; never fold into convolutions.
  pipeline.AddPass<TransposeFolding>(
      [&](const HloInstruction& dot,
          const TransposeFolding::OperandIndices& candidate_operands) {
        return DotImplementationCanHandleTranspose(dot,
                                                   *target_machine_features)
                   ? candidate_operands
                   : TransposeFolding::OperandIndices{};
      },
      TransposeFolding::NeverFoldTranspose);
  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);

  // Layout assignment uses alias analysis, which requires the call graph to be
  // flattened.
  pipeline.AddPass<FlattenCallGraph>();
  pipeline.AddPass<CpuLayoutAssignment>(
      module->mutable_entry_computation_layout(),
      LayoutAssignment::InstructionCanChangeLayout, target_machine_features);

  pipeline.AddPass<CpuInstructionFusion>();

  return pipeline.Run(module).status();
}
400 
// Runs the layout-sensitive passes that must come after layout assignment:
// cleanup of copies left by LayoutAssignment, optional parallel task
// assignment (JIT only), and copy insertion immediately before IR emission.
Status CpuCompiler::RunHloPassesAfterLayoutAssn(
    HloModule* module, bool is_aot_compile,
    LLVMTargetMachineFeatures* target_machine_features) {
  HloPassPipeline pipeline("HLO passes after layout assignment");
  // After layout assignment, use a layout-sensitive verifier.

  pipeline.AddPass<HloPassPipeline>("after layout assignment")
      .AddInvariantCheckerDebug<HloVerifier>(
          /*layout_sensitive=*/true,
          /*allow_mixed_precision=*/false);

  // The LayoutAssignment pass may leave behind kCopy instructions which are
  // duplicate or NOPs, so remove them with algebraic simplification and CSE.
  {
    auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
        "simplification after layout assignment");
    pass.AddInvariantCheckerDebug<HloVerifier>(
        /*layout_sensitive=*/true,
        /*allow_mixed_precision=*/false,
        LayoutAssignment::InstructionCanChangeLayout);
    AlgebraicSimplifierOptions options;
    options.set_is_layout_sensitive(true);
    options.set_enable_dot_strength_reduction(false);
    pass.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
    pass.AddPass<HloDCE>();
    pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
  }

  // Outline ops in the entry computation into calls to subcomputations.
  // Parallelism defaults to the number of schedulable CPUs unless the module
  // config explicitly requests a thread count.
  const int max_parallelism =
      module->config().intra_op_parallelism_threads() > 0
          ? module->config().intra_op_parallelism_threads()
          : tensorflow::port::NumSchedulableCPUs();
  if (!is_aot_compile) {
    // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module.
    // Note this is not run for AOT because it would bring in thread pool
    // and thread synchronization dependencies which would likely increase
    // binary size (and most AOT applications are single-threaded).
    // TODO(b/29630486) Support multi-threaded AOT.
    pipeline.AddPass<ParallelTaskAssigner>(
        max_parallelism, ShapeSizeBytesFunction(), target_machine_features);
  }
  // Copy insertion should be performed immediately before IR emission to
  // avoid inserting unnecessary copies (later pass adds an instruction which
  // materializes the value) or missing a necessary copy (later pass removes
  // an instruction which materializes a value). DCE must be run immediately
  // before (and sometime after) copy insertion, to avoid dead code from
  // interfering with the rewrites.
  pipeline.AddPass<HloDCE>();
  pipeline.AddPass<CopyInsertion>();
  pipeline.AddPass<HloDCE>();
  return pipeline.Run(module).status();
}
454 
// Runs the full HLO optimization pipeline on `module`: first the
// layout-insensitive passes (through layout assignment), then the
// layout-sensitive ones. `target_machine` supplies the feature information
// several passes consult.
Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
                                 llvm::TargetMachine* target_machine) {
  LLVMTargetMachineFeatures target_machine_features(target_machine);
  TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile,
                                                   &target_machine_features));
  return RunHloPassesAfterLayoutAssn(module, is_aot_compile,
                                     &target_machine_features);
}
463 
464 namespace {
465 
// Alignment (in bytes) for all allocated buffers, independent of buffer
// color. Delegates to the minimum alignment required by the CPU runtime.
int64 memory_alignment(LogicalBuffer::Color) {
  return cpu_function_runtime::kMinAlign;
}
470 
CompilerTargetOptions(const HloModuleConfig & module_config)471 llvm::TargetOptions CompilerTargetOptions(
472     const HloModuleConfig& module_config) {
473   llvm::TargetOptions target_options;
474   // Always allow FMA fusion. This increases precision instead of decreasing it.
475   target_options.AllowFPOpFusion = llvm::FPOpFusion::Fast;
476   return target_options;
477 }
478 
CodeGenOptLevel(const HloModuleConfig & module_config)479 llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) {
480   VLOG(2) << "backend_optimization_level: "
481           << module_config.debug_options().xla_backend_optimization_level();
482   switch (module_config.debug_options().xla_backend_optimization_level()) {
483     case 1:
484       return llvm::CodeGenOpt::Less;
485     case 2:
486       return llvm::CodeGenOpt::Default;
487     case 3:
488       return llvm::CodeGenOpt::Aggressive;
489     default:
490       return llvm::CodeGenOpt::None;
491   }
492 }
493 
// Returns a {pre-optimization, post-optimization} pair of LLVM IR hooks for
// `hlo_module`. Each hook first invokes the corresponding user-supplied hook
// (if any) and then dumps the IR when dumping is enabled.
std::pair<LLVMCompiler::ModuleHook, LLVMCompiler::ModuleHook> GetIRModuleHooks(
    const HloModule& hlo_module,
    const LLVMCompiler::ModuleHook& user_pre_optimization_hook,
    const LLVMCompiler::ModuleHook& user_post_optimization_hook) {
  // Create the IR hooks. If applicable, each IR hook does the following:
  //
  //  * Calls the user supplied module hook.
  //  * Writes out the IR to a file in the output directory designated by
  //    --xla_dump_to
  // NOTE: the raw pointer is captured by value; the caller must keep the
  // HloModule alive for as long as the hooks may run.
  const HloModule* hlo_module_ptr = &hlo_module;
  auto hook = [user_pre_optimization_hook, user_post_optimization_hook,
               hlo_module_ptr](bool optimized,
                               const llvm::Module& llvm_module) {
    const auto& user_hook =
        !optimized ? user_pre_optimization_hook : user_post_optimization_hook;
    if (user_hook) {
      user_hook(llvm_module);
    }
    llvm_ir::DumpIrIfEnabled(*hlo_module_ptr, llvm_module, optimized);
  };
  // Bind `optimized` to produce the two single-argument hooks.
  return {[hook](const llvm::Module& llvm_module) {
            return hook(/*optimized=*/false, llvm_module);
          },
          [hook](const llvm::Module& llvm_module) {
            return hook(/*optimized=*/true, llvm_module);
          }};
}
521 
// Runs the LLVM IR verifier over `llvm_module` and returns an error Status
// (including the verifier's diagnostics) if the module is malformed. A
// failure here indicates a bug in the HLO -> LLVM IR lowering, not bad user
// input.
Status VerifyLlvmModule(const llvm::Module& llvm_module) {
  XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier");

  std::string err;
  llvm::raw_string_ostream err_stream(err);

  // verifyModule() returns true if the module is broken.
  TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream))
      << "Invalid LLVM IR before optimizations:\n"
      << err_stream.str()
      << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
         "Rerun with --xla_dump_to to get the IR. ";
  return Status::OK();
}
536 
// Builds all the artifacts required for HLO-level profiling of `module`:
//  * `hlo_profile_index_map`  - profile index for every instruction/computation,
//  * `instruction_to_profile_idx` - the subset of entry-computation
//    instructions that should actually get counters (see
//    CollectProfileCandidates),
//  * `computation_to_profile_idx` - per-computation profile indices,
//  * `hlo_profile_printer_data`   - data for pretty-printing profile results.
// All four out-parameters are overwritten on success.
Status CreateHloProfilingArtifacts(
    const HloModule& module,
    std::unordered_map<const HloInstruction*, int64>*
        instruction_to_profile_idx,
    std::unordered_map<const HloComputation*, int64>*
        computation_to_profile_idx,
    std::unique_ptr<HloProfileIndexMap>* hlo_profile_index_map,
    std::unique_ptr<HloProfilePrinterData>* hlo_profile_printer_data) {
  *hlo_profile_index_map = absl::make_unique<HloProfileIndexMap>(module);
  const HloComputation& entry_computation = *module.entry_computation();

  TF_ASSIGN_OR_RETURN(
      *instruction_to_profile_idx,
      CollectProfileCandidates::GetCandidatesForComputation(
          entry_computation,
          (*hlo_profile_index_map)->instruction_to_profile_idx()));

  // Shape-size callback used by the cost analysis below.
  auto shape_size_bytes = [](const Shape& shape) {
    // On the cpu, opaques are pointers.
    if (shape.IsOpaque()) {
      return static_cast<int64>(sizeof(void*));
    }
    return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
  };

  // Run cost analysis so the profile printer can show per-op cost estimates.
  HloCostAnalysis cost_analysis(shape_size_bytes);
  TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis));
  *hlo_profile_printer_data = CreateHloProfilePrinterData(
      **hlo_profile_index_map, cost_analysis, entry_computation.name());
  *computation_to_profile_idx =
      (*hlo_profile_index_map)->computation_to_profile_idx();

  return Status::OK();
}
571 
572 }  // namespace
573 
// Compiler-interface entry point: optimizes `module` with the full HLO
// pipeline, using a target machine inferred for the JIT. The stream executor
// and options are unused by the CPU backend.
StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
    std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/,
    const CompileOptions& /*options*/) {
  std::unique_ptr<llvm::TargetMachine> jit_target_machine =
      SimpleOrcJIT::InferTargetMachineForJIT(
          CompilerTargetOptions(module->config()),
          CodeGenOptLevel(module->config()));

  TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false,
                                  jit_target_machine.get()));
  return std::move(module);
}
586 
// Optionally optimizes `module` and then runs scheduling and buffer
// assignment, returning both the (possibly optimized) module and its buffer
// assignment. NOTE(review): "Assignement" is a long-standing misspelling in
// the Compiler interface; it must be kept for API compatibility.
StatusOr<
    std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
CpuCompiler::RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
                                              se::StreamExecutor* executor,
                                              bool optimize,
                                              const CompileOptions& options) {
  if (optimize) {
    TF_ASSIGN_OR_RETURN(module,
                        RunHloPasses(std::move(module), executor, options));
  }

  // Select an order for emitting the HLO instructions for each computation.
  // Using this sequence enables tighter buffer liveness analysis and reduced
  // memory usage (as compared to using DependencyHloOrdering).
  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      ScheduleModule(module.get(), BufferSizeBytesFunction(),
                                     ComputationSchedulerToModuleScheduler(
                                         DFSMemoryScheduler)));

  // Run buffer allocation on the HLO graph.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<BufferAssignment> assignment,
      BufferAssigner::Run(module.get(),
                          absl::make_unique<SequentialHloOrdering>(schedule),
                          BufferSizeBytesFunction(), memory_alignment,
                          /*allocate_buffers_for_constants=*/true));

  return std::make_tuple(std::move(module), std::move(assignment));
}
616 
617 namespace {
618 
// Post-compilation callback functor for use by SimpleOrcJIT.
//
// Dumps machine code if dumping is enabled for the module.
struct OrcJITPostCompilationHook {
  // Gets an std::function that implements this hook.
  static std::function<void(const llvm::object::ObjectFile& obj_file)> Create(
      const HloModule* module) {
    // This struct is not copyable, but std::functions must be.  So to create an
    // std::function out of this struct, we have to wrap it in a shared_ptr.
    auto wrapped = std::make_shared<OrcJITPostCompilationHook>(module);
    return [wrapped](const llvm::object::ObjectFile& obj_file) {
      (*wrapped)(obj_file);
    };
  }

  // Constructor can't be private because we want to call it from
  // std::make_shared, but users should call Create() instead.
  // NOTE: `module` is captured as a raw pointer; the caller must keep the
  // HloModule alive for as long as the hook may fire.
  explicit OrcJITPostCompilationHook(const HloModule* module)
      : module(module) {}

 private:
  // Dumps the compiled object file next to the module's other dump artifacts
  // (suffix ".o") when dumping is enabled; otherwise a no-op.
  void operator()(const llvm::object::ObjectFile& obj_file) {
    if (!DumpingEnabledForHloModule(*module)) {
      return;
    }
    DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o",
                    absl::string_view(obj_file.getData().data(),
                                      obj_file.getData().size()));
  }

  const HloModule* module;
};
651 
652 }  // namespace
653 
// JIT-compiles an already-optimized HLO module into a CpuExecutable.
//
// Pipeline: schedule HLO instructions, run buffer assignment, lower each
// computation to LLVM IR via IrEmitter, then hand the IR to SimpleOrcJIT.
// No HLO optimization passes are run here (see RunHloPasses for that).
// Ownership of `module` is transferred into the returned executable.
// `stream_exec` must be non-null.  NOTE(review): `options` is not referenced
// in this body beyond its presence in the signature.
StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    const CompileOptions& options) {
  VLOG(1) << "Compiling: " << module->name();
  XLA_SCOPED_LOGGING_TIMER(
      absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));
  std::string slow_compilation_msg =
      absl::StrCat("Compiling module ", module->name());
  auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg);

  TF_RET_CHECK(stream_exec != nullptr);
  // LLVM command-line options are process-global, so they are initialized
  // exactly once across all compilations.
  absl::call_once(llvm_command_line_options_initialized,
                  &llvm_ir::InitializeLLVMCommandLineOptions, module->config());

  ModuleHook pre_optimization_ir_hook;
  ModuleHook post_optimization_ir_hook;
  std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
      GetIRModuleHooks(*module, user_pre_optimization_hook_,
                       user_post_optimization_hook_);

  // Compile must be thread-safe so create a new LLVM context for the module.
  mlir::MLIRContext mlir_context;
  LoadMLIRDialects(mlir_context);
  auto llvm_context = std::make_unique<llvm::LLVMContext>();
  auto llvm_module =
      absl::make_unique<llvm::Module>("__compute_module", *llvm_context);

  auto jit = SimpleOrcJIT::Create(
      CompilerTargetOptions(module->config()),
      CodeGenOptLevel(module->config()),
      options::OptimizeForSizeRequested(module->config()),
      module->config().debug_options().xla_llvm_disable_expensive_passes(),
      llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook,
      post_optimization_ir_hook,
      OrcJITPostCompilationHook::Create(module.get()));
  if (!jit) {
    return InternalError("Creating JIT failed: %s",
                         llvm::toString(jit.takeError()));
  }
  // The module must use the JIT's data layout and target triple so the
  // emitted IR matches what the JIT will compile for.
  llvm_module->setDataLayout((*jit)->data_layout());
  llvm_module->setTargetTriple((*jit)->target_triple().getTriple());

  HloComputation* entry_computation = module->entry_computation();
  std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
  std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
  std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
  std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
  if (module->config().hlo_profiling_enabled()) {
    TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
        *module, &instruction_to_profile_idx, &computation_to_profile_idx,
        &hlo_profile_index_map, &hlo_profile_printer_data));
  }

  std::unique_ptr<Executable> cpu_executable;

  // Cache these flags here since we'll want to access them after the module's
  // ownership is std::moved.
  const bool embed_ir_in_executable =
      module->config().debug_options().xla_embed_ir_in_executable();

  // Select an order for emitting the HLO instructions for each
  // computation. Using this sequence enables tighter buffer liveness analysis
  // and reduced memory usage (as compared to using DependencyHloOrdering).
  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      ScheduleModule(module.get(), BufferSizeBytesFunction(),
                                     ComputationSchedulerToModuleScheduler(
                                         DFSMemoryScheduler)));

  // Run buffer allocation on the HLO graph.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<BufferAssignment> assignment,
      BufferAssigner::Run(module.get(),
                          absl::make_unique<SequentialHloOrdering>(schedule),
                          BufferSizeBytesFunction(), memory_alignment,
                          /*allocate_buffers_for_constants=*/true));
  DumpHloModuleIfEnabled(*module, *assignment, "after_optimizations");

  // Each computation is a single function.  Emit all embedded computations
  // before the entry computation. The order of computations returned from
  // GetEmbeddedComputations guarantees that a called computation occurs
  // before a caller computation.

  LLVMTargetMachineFeatures target_machine_features((*jit)->target_machine());
  IrEmitter ir_emitter(&mlir_context, *module, *assignment, llvm_module.get(),
                       std::move(instruction_to_profile_idx),
                       std::move(computation_to_profile_idx),
                       &target_machine_features,
#ifdef MEMORY_SANITIZER
                       /*emit_code_for_msan=*/true
#else
                       /*emit_code_for_msan=*/false
#endif
  );

  TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());

  // Fusion computations are emitted inline by their callers, so they are
  // skipped here.
  for (auto embedded_computation :
       entry_computation->MakeEmbeddedComputationsList()) {
    if (embedded_computation->IsFusionComputation()) {
      continue;
    }
    TF_RETURN_IF_ERROR(
        ir_emitter
            .EmitComputation(
                embedded_computation, embedded_computation->name(),
                /*is_top_level_computation=*/false,
                schedule.sequence(embedded_computation).instructions())
            .status());
  }
  string function_name_prefix = entry_computation->name().empty()
                                    ? "__compute"
                                    : entry_computation->name();
  TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
                      ir_emitter.EmitComputation(
                          entry_computation, function_name_prefix,
                          /*is_top_level_computation=*/true,
                          schedule.sequence(entry_computation).instructions()));

  // Compute the mangled symbol name the JIT will publish for the entry
  // function (the data layout's mangling rules may add a global prefix).
  string function_name = [&]() {
    llvm::SmallVector<char, 40> function_name_vector;
    llvm::Mangler::getNameWithPrefix(
        function_name_vector, entry_function->getName(), (*jit)->data_layout());
    return string(function_name_vector.begin(), function_name_vector.end());
  }();

  // Capture the textual IR now, before the module is moved into the JIT.
  string ir_module_string;
  if (embed_ir_in_executable) {
    ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
  }

  // Verify the IR before handing it to the JIT so malformed IR surfaces as a
  // returned error here rather than failing inside LLVM.
  TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));

  // JIT compile the LLVM IR module to in-memory machine code.
  llvm::orc::ThreadSafeModule thread_safe_module(std::move(llvm_module),
                                                 std::move(llvm_context));
  cantFail((*jit)->AddModule(std::move(thread_safe_module)));
  // The executable takes ownership of the JIT, the buffer assignment, and the
  // HLO module; none of them may be used directly after this point.
  cpu_executable.reset(new CpuExecutable(
      std::move(*jit), std::move(assignment), std::move(module), function_name,
      std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));

  if (embed_ir_in_executable) {
    static_cast<CpuExecutable&>(*cpu_executable)
        .set_ir_module_string(ir_module_string);
  }

  VLOG(1) << "Compilation finished";
  return std::move(cpu_executable);
}
802 
// Ahead-of-time compiles every module in `module_group` to native object
// code for the target triple/CPU described by `aot_options`, which must be a
// CpuAotCompilationOptions for the host platform.  Returns one
// CpuAotCompilationResult (object file bytes, buffer layout, entry-point
// info) per input module, in the same order as the inputs.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                                const AotCompilationOptions& aot_options) {
  TF_RET_CHECK(!module_group->empty());
  std::vector<std::unique_ptr<HloModule>> modules =
      module_group->ConsumeModules();

  // LLVM command-line options are process-global; initialize them once using
  // the first module's config.
  absl::call_once(llvm_command_line_options_initialized,
                  &llvm_ir::InitializeLLVMCommandLineOptions,
                  modules[0]->config());

  // We can pass just one llvm::TargetOptions when we compile the LLVM module,
  // so we bail if the configs have conflicting flags. At the moment, the only
  // flags that need to be consistent are for fast-math.
  for (const auto& fn_and_name :
       {std::make_pair(&DebugOptions::xla_cpu_enable_fast_math,
                       "xla_cpu_enable_fast_math"),
        std::make_pair(&DebugOptions::xla_cpu_fast_math_honor_infs,
                       "xla_cpu_fast_math_honor_infs"),
        std::make_pair(&DebugOptions::xla_cpu_fast_math_honor_nans,
                       "xla_cpu_fast_math_honor_nans")}) {
    // This only works because each of the method pointers above returns a bool.
    // Otherwise we'd have to do some template magic.
    const auto& field_method_ptr = fn_and_name.first;
    const auto& field_name = fn_and_name.second;
    bool first_module_val =
        (modules[0]->config().debug_options().*field_method_ptr)();
    // NOTE(review): i == 0 compares the first module against itself, which
    // trivially passes; starting at 1 would be equivalent.
    for (int64 i = 0; i < modules.size(); ++i) {
      bool cur_module_val =
          (modules[i]->config().debug_options().*field_method_ptr)();
      if (first_module_val != cur_module_val) {
        return InvalidArgument(
            "All HLO module configs must have the same value for %s, but "
            "module 0 and %d have different values (%d vs %d).",
            field_name, i, first_module_val, cur_module_val);
      }
    }
  }

  if (aot_options.PlatformId() != se::host::kHostPlatformId) {
    return InvalidArgument("Incompatible AOT compilation platform");
  }
  const CpuAotCompilationOptions& options =
      static_cast<const CpuAotCompilationOptions&>(aot_options);
  llvm::Triple triple(llvm::Triple::normalize(options.triple()));
  std::string error;
  const llvm::Target* target =
      llvm::TargetRegistry::lookupTarget(triple.getTriple(), error);
  if (target == nullptr) {
    return InternalError("TargetRegistry::lookupTarget failed: %s", error);
  }

  // Translate the requested relocation model into the LLVM relocation model
  // plus the matching PIC/PIE module-level flags.
  llvm::Reloc::Model reloc_model = llvm::Reloc::Static;
  llvm::PICLevel::Level pic_level = llvm::PICLevel::NotPIC;
  llvm::PIELevel::Level pie_level = llvm::PIELevel::Default;
  switch (options.relocation_model()) {
    case CpuAotCompilationOptions::RelocationModel::Static:
      reloc_model = llvm::Reloc::Static;
      pic_level = llvm::PICLevel::NotPIC;
      pie_level = llvm::PIELevel::Default;
      break;
    case CpuAotCompilationOptions::RelocationModel::SmallPic:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::SmallPIC;
      pie_level = llvm::PIELevel::Default;
      break;
    case CpuAotCompilationOptions::RelocationModel::BigPic:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::BigPIC;
      pie_level = llvm::PIELevel::Default;
      break;
    case CpuAotCompilationOptions::RelocationModel::SmallPie:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::SmallPIC;
      pie_level = llvm::PIELevel::Small;
      break;
    case CpuAotCompilationOptions::RelocationModel::BigPie:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::BigPIC;
      pie_level = llvm::PIELevel::Large;
      break;
  }
  llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config());
  std::unique_ptr<llvm::TargetMachine> target_machine =
      absl::WrapUnique(target->createTargetMachine(
          triple.getTriple(), options.cpu_name(), options.features(),
          CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::None,
          opt_level));

  // Compile must be thread-safe so create a new LLVM context for the module.
  mlir::MLIRContext mlir_context;
  LoadMLIRDialects(mlir_context);
  llvm::LLVMContext llvm_context;
  // NOTE(review): this single llvm::Module is reused for every HLO module in
  // the loop below, so later object files accumulate the IR emitted for
  // earlier modules — confirm this is intended before relying on per-module
  // object file contents.
  llvm::Module llvm_module("__compute_module", llvm_context);
  llvm_module.setDataLayout(target_machine->createDataLayout());
  llvm_module.setTargetTriple(triple.getTriple());
  if (pic_level != llvm::PICLevel::NotPIC) {
    llvm_module.setPICLevel(pic_level);
  }
  if (pie_level != llvm::PIELevel::Default) {
    llvm_module.setPIELevel(pie_level);
  }

  std::vector<std::unique_ptr<AotCompilationResult>> results;
  for (size_t i = 0; i < modules.size(); ++i) {
    HloModule* module = modules[i].get();
    VLOG(1) << "Compiling ahead-of-time: " << module->name();

    // Unlike RunBackend, AOT compilation runs the HLO optimization pipeline
    // itself before lowering.
    TF_RETURN_IF_ERROR(
        RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get()));

    TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                        ScheduleModule(module, BufferSizeBytesFunction()));

    // Run buffer analysis on the HLO graph. This analysis figures out which
    // temporary buffers are required to run the computation.
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<BufferAssignment> assignment,
        BufferAssigner::Run(module,
                            absl::make_unique<SequentialHloOrdering>(schedule),
                            BufferSizeBytesFunction(), memory_alignment,
                            /*allocate_buffers_for_constants=*/true));
    // BufferAssignment::ToString() includes a header, so no need for us to
    // print one ourselves.
    if (DumpingEnabledForHloModule(*module)) {
      DumpToFileInDirOrStdout(*module, "", "buffer_assignment",
                              assignment->ToString());
    }
    DumpHloModuleIfEnabled(*module, *assignment, "after_optimizations");

    std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
    std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
    std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;

    if (module->config().hlo_profiling_enabled()) {
      TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
          *module, &instruction_to_profile_idx, &computation_to_profile_idx,
          &hlo_profile_index_map, &hlo_profile_printer_data));
    }

    LLVMTargetMachineFeatures target_machine_features(target_machine.get());
    IrEmitter ir_emitter(&mlir_context, *module, *assignment, &llvm_module,
                         std::move(instruction_to_profile_idx),
                         std::move(computation_to_profile_idx),
                         &target_machine_features,
                         // TODO(b/66051036): Run full msan for AOT.
                         /*emit_code_for_msan=*/false);

    TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());

    // Emit embedded computations before the entry computation; fusion
    // computations are emitted inline by their callers and skipped here.
    HloComputation* computation = module->entry_computation();
    for (auto embedded_computation :
         computation->MakeEmbeddedComputationsList()) {
      if (embedded_computation->IsFusionComputation()) {
        continue;
      }
      TF_RETURN_IF_ERROR(
          ir_emitter
              .EmitComputation(
                  embedded_computation, embedded_computation->name(),
                  /*is_top_level_computation=*/false,
                  schedule.sequence(embedded_computation).instructions())
              .status());
    }
    const string& entry_point_name = options.entry_point_name();
    TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
                        ir_emitter.EmitComputation(
                            computation, entry_point_name,
                            /*is_top_level_computation=*/true,
                            schedule.sequence(computation).instructions()));

    CHECK(entry_function->getName() == entry_point_name);

    ModuleHook pre_optimization_ir_hook;
    ModuleHook post_optimization_ir_hook;
    std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
        GetIRModuleHooks(*module, user_pre_optimization_hook_,
                         user_post_optimization_hook_);

    // Run the LLVM verifier over the unoptimized LLVM IR.  If it fails, run the
    // pre-optimization IR dump hook before returning.
    {
      Status verify_status = VerifyLlvmModule(llvm_module);
      if (!verify_status.ok() && pre_optimization_ir_hook) {
        pre_optimization_ir_hook(llvm_module);
      }
      TF_RETURN_IF_ERROR(verify_status);
    }

    // Dumps the object file produced by codegen, if dumping is enabled for
    // this module (mirrors OrcJITPostCompilationHook in the JIT path).
    auto post_codegen_hook = [&](const llvm::object::ObjectFile& obj_file) {
      if (!DumpingEnabledForHloModule(*module)) {
        return;
      }
      DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o",
                      absl::string_view(obj_file.getData().data(),
                                        obj_file.getData().size()));
    };

    // Run LLVM optimizations and codegen over the IR module, producing an
    // in-memory object file.
    CompilerFunctor compiler_functor(
        target_machine.get(), opt_level,
        options::OptimizeForSizeRequested(module->config()),
        module->config().debug_options().xla_llvm_disable_expensive_passes(),
        llvm_ir::GetCpuFastMathFlags(module->config()),
        pre_optimization_ir_hook, post_optimization_ir_hook, post_codegen_hook);
    std::unique_ptr<llvm::MemoryBuffer> object_file =
        cantFail(compiler_functor(llvm_module));
    ObjectFileData object_file_data(object_file->getBufferStart(),
                                    object_file->getBufferEnd());

    std::vector<BufferInfo> buffer_infos =
        CreateBufferInfosFromBufferAssignment(*assignment);

    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
                        assignment->GetUniqueTopLevelOutputSlice());

    results.emplace_back(absl::make_unique<CpuAotCompilationResult>(
        std::move(object_file_data), std::move(buffer_infos),
        result_slice.index(), std::move(hlo_profile_printer_data)));
  }

  VLOG(1) << "Compilation finished";
  return std::move(results);
}
1027 
PlatformId() const1028 se::Platform::Id CpuCompiler::PlatformId() const {
1029   return se::host::kHostPlatformId;
1030 }
1031 
ShapeSizeBytesFunction() const1032 HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const {
1033   return CpuExecutable::ShapeSizeBytes;
1034 }
1035 
1036 }  // namespace cpu
1037 }  // namespace xla
1038 
InitModule()1039 static bool InitModule() {
1040   xla::Compiler::RegisterCompilerFactory(
1041       stream_executor::host::kHostPlatformId,
1042       []() { return absl::make_unique<xla::cpu::CpuCompiler>(); });
1043   return true;
1044 }
1045 static bool module_initialized = InitModule();
1046