• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
17 
18 #include <stddef.h>
19 #include <string.h>
20 
21 #include <map>
22 #include <memory>
23 #include <string>
24 #include <unordered_map>
25 #include <utility>
26 #include <vector>
27 
28 // IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc"
29 // IWYU pragma: no_include "llvm/Config/Targets.def.inc"
30 #include "absl/base/call_once.h"
31 #include "absl/memory/memory.h"
32 #include "absl/strings/str_cat.h"
33 #include "llvm/ADT/StringRef.h"
34 #include "llvm/ADT/Triple.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/LLVMContext.h"
37 #include "llvm/IR/Mangler.h"
38 #include "llvm/IR/Module.h"
39 #include "llvm/IR/Verifier.h"
40 #include "llvm/Object/ObjectFile.h"
41 #include "llvm/Support/CommandLine.h"
42 #include "llvm/Support/Error.h"
43 #include "llvm/Support/TargetRegistry.h"
44 #include "llvm/Support/TargetSelect.h"
45 #include "llvm/Target/TargetMachine.h"
46 #include "llvm/Target/TargetOptions.h"
47 #include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
48 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
49 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"  // from @llvm-project
50 #include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
51 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
52 #include "mlir/Dialect/Vector/VectorOps.h"  // from @llvm-project
53 #include "mlir/InitAllDialects.h"  // from @llvm-project
54 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
55 #include "tensorflow/compiler/xla/literal.h"
56 #include "tensorflow/compiler/xla/map_util.h"
57 #include "tensorflow/compiler/xla/protobuf_util.h"
58 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
59 #include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
60 #include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
61 #include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
62 #include "tensorflow/compiler/xla/service/batchnorm_expander.h"
63 #include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
64 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
65 #include "tensorflow/compiler/xla/service/call_inliner.h"
66 #include "tensorflow/compiler/xla/service/cholesky_expander.h"
67 #include "tensorflow/compiler/xla/service/comparison_expander.h"
68 #include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
69 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
70 #include "tensorflow/compiler/xla/service/conditional_to_select.h"
71 #include "tensorflow/compiler/xla/service/convolution_group_converter.h"
72 #include "tensorflow/compiler/xla/service/copy_insertion.h"
73 #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
74 #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
75 #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
76 #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
77 #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
78 #include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"
79 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
80 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
81 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
82 #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
83 #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
84 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
85 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
86 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
87 #include "tensorflow/compiler/xla/service/dump.h"
88 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
89 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
90 #include "tensorflow/compiler/xla/service/eigh_expander.h"
91 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
92 #include "tensorflow/compiler/xla/service/gather_expander.h"
93 #include "tensorflow/compiler/xla/service/hlo.pb.h"
94 #include "tensorflow/compiler/xla/service/hlo_computation.h"
95 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
96 #include "tensorflow/compiler/xla/service/hlo_cse.h"
97 #include "tensorflow/compiler/xla/service/hlo_dce.h"
98 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
99 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
100 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
101 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
102 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
103 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
104 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
105 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
106 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
107 #include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
108 #include "tensorflow/compiler/xla/service/llvm_compiler.h"
109 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
110 #include "tensorflow/compiler/xla/service/logistic_expander.h"
111 #include "tensorflow/compiler/xla/service/map_inliner.h"
112 #include "tensorflow/compiler/xla/service/operand_upcaster.h"
113 #include "tensorflow/compiler/xla/service/qr_expander.h"
114 #include "tensorflow/compiler/xla/service/reduce_scatter_decomposer.h"
115 #include "tensorflow/compiler/xla/service/reshape_mover.h"
116 #include "tensorflow/compiler/xla/service/result_caster.h"
117 #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
118 #include "tensorflow/compiler/xla/service/rng_expander.h"
119 #include "tensorflow/compiler/xla/service/scatter_expander.h"
120 #include "tensorflow/compiler/xla/service/slice_sinker.h"
121 #include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
122 #include "tensorflow/compiler/xla/service/sort_simplifier.h"
123 #include "tensorflow/compiler/xla/service/topk_rewriter.h"
124 #include "tensorflow/compiler/xla/service/transpose_folding.h"
125 #include "tensorflow/compiler/xla/service/tree_reduction_rewriter.h"
126 #include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
127 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
128 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
129 #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
130 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
131 #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
132 #include "tensorflow/compiler/xla/status_macros.h"
133 #include "tensorflow/compiler/xla/statusor.h"
134 #include "tensorflow/compiler/xla/types.h"
135 #include "tensorflow/compiler/xla/util.h"
136 #include "tensorflow/compiler/xla/xla_data.pb.h"
137 #include "tensorflow/core/platform/dynamic_annotations.h"
138 
139 namespace {
140 
141 // We need to explicitly load all the dialects we will involved in emitting the
142 // IR. This is only needed because of how MLIR is bolted into XLA and does not
143 // make use of the MLIR infrastructure (like using a proper pass pipeline).
144 // Hopefully this will all go away at some point in favor of a better
145 // integration.
LoadMLIRDialects(mlir::MLIRContext & context)146 void LoadMLIRDialects(mlir::MLIRContext& context) {
147   context.loadDialect<mlir::linalg::LinalgDialect, mlir::scf::SCFDialect,
148                       mlir::vector::VectorDialect, mlir::StandardOpsDialect,
149                       mlir::AffineDialect>();
150 }
151 
152 }  // namespace
153 
154 namespace xla {
155 namespace cpu {
156 using BufferInfo = cpu_function_runtime::BufferInfo;
157 
CpuAotCompilationOptions(string triple,string cpu_name,string features,string entry_point_name,RelocationModel relocation_model)158 CpuAotCompilationOptions::CpuAotCompilationOptions(
159     string triple, string cpu_name, string features, string entry_point_name,
160     RelocationModel relocation_model)
161     : triple_(std::move(triple)),
162       cpu_name_(std::move(cpu_name)),
163       features_(std::move(features)),
164       entry_point_name_(std::move(entry_point_name)),
165       relocation_model_(relocation_model) {}
166 
167 CpuAotCompilationOptions::~CpuAotCompilationOptions() = default;
168 
PlatformId() const169 se::Platform::Id CpuAotCompilationOptions::PlatformId() const {
170   return se::host::kHostPlatformId;
171 }
172 
CpuAotCompilationResult(ObjectFileData object_file_data,std::vector<BufferInfo> buffer_infos,int64_t result_buffer_index,std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data)173 CpuAotCompilationResult::CpuAotCompilationResult(
174     ObjectFileData object_file_data, std::vector<BufferInfo> buffer_infos,
175     int64_t result_buffer_index,
176     std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data)
177     : object_file_data_(std::move(object_file_data)),
178       buffer_infos_(std::move(buffer_infos)),
179       result_buffer_index_(result_buffer_index),
180       hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {}
181 
182 CpuAotCompilationResult::~CpuAotCompilationResult() = default;
183 
CpuCompiler()184 CpuCompiler::CpuCompiler() {
185   // Initialize LLVM the first time the CpuCompiler is initialized.
186   static bool llvm_initialized = []() {
187     InitializeLLVMTarget();
188     return true;
189   }();
190   (void)llvm_initialized;
191 }
192 
Compile(std::unique_ptr<HloModuleGroup> module_group,std::vector<std::vector<se::StreamExecutor * >> stream_execs,const CompileOptions & options)193 StatusOr<std::vector<std::unique_ptr<Executable>>> CpuCompiler::Compile(
194     std::unique_ptr<HloModuleGroup> module_group,
195     std::vector<std::vector<se::StreamExecutor*>> stream_execs,
196     const CompileOptions& options) {
197   for (const std::vector<se::StreamExecutor*>& se_vector : stream_execs) {
198     if (se_vector.size() != 1) {
199       return Unimplemented(
200           "Model partitioning not implemented for the CPU compiler");
201     }
202   }
203   return LLVMCompiler::Compile(std::move(module_group), stream_execs, options);
204 }
205 
InitializeLLVMTarget()206 /* static */ void CpuCompiler::InitializeLLVMTarget() {
207   // Initialize LLVM's MC layer for the native target.
208   llvm::InitializeNativeTarget();
209   llvm::InitializeNativeTargetAsmPrinter();
210 }
211 
212 namespace {
213 
214 // LLVM makes certain options configurable only through its command-line
215 // options; it provide the ParseCommandLineOptions function that lets us set
216 // flags at runtime. However, since these flags are global we want to avoid
217 // multiple invocations of the LLVM compilation pipeline with a different set of
218 // flags. Therefore, we only pass command-line flags to LLVM once, before the
219 // first module is compiled.
220 absl::once_flag llvm_command_line_options_initialized;
221 
222 // This visitor records which HLO instructions should have profiling information
223 // recorded.
224 class CollectProfileCandidates : public DfsHloVisitorWithDefault {
225  public:
226   static StatusOr<std::unordered_map<const HloInstruction*, int64>>
GetCandidatesForComputation(const HloComputation & computation,const std::unordered_map<const HloInstruction *,int64> & assigned_indices)227   GetCandidatesForComputation(
228       const HloComputation& computation,
229       const std::unordered_map<const HloInstruction*, int64>&
230           assigned_indices) {
231     std::unordered_map<const HloInstruction*, int64> hlo_to_profile_idx;
232     CollectProfileCandidates profile_candidates_for_computation(
233         &hlo_to_profile_idx, assigned_indices);
234     TF_RETURN_IF_ERROR(computation.Accept(&profile_candidates_for_computation));
235     return hlo_to_profile_idx;
236   }
237 
238  private:
CollectProfileCandidates(std::unordered_map<const HloInstruction *,int64> * hlo_to_profile_idx,const std::unordered_map<const HloInstruction *,int64> & assigned_indices)239   CollectProfileCandidates(
240       std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx,
241       const std::unordered_map<const HloInstruction*, int64>& assigned_indices)
242       : hlo_to_profile_idx_(hlo_to_profile_idx),
243         assigned_indices_(assigned_indices) {}
244 
DefaultAction(HloInstruction * hlo_instruction)245   Status DefaultAction(HloInstruction* hlo_instruction) override {
246     hlo_to_profile_idx_->insert(
247         {hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)});
248     return Status::OK();
249   }
250 
HandleCall(HloInstruction * call)251   Status HandleCall(HloInstruction* call) override {
252     TF_RETURN_IF_ERROR(DefaultAction(call));
253     CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_,
254                                                  assigned_indices_);
255     TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call));
256     return Status::OK();
257   }
258   // Recurse into "conditional" so we can profile inside of it.
HandleConditional(HloInstruction * conditional)259   Status HandleConditional(HloInstruction* conditional) override {
260     TF_RETURN_IF_ERROR(DefaultAction(conditional));
261 
262     CollectProfileCandidates candidates_for_true(hlo_to_profile_idx_,
263                                                  assigned_indices_);
264     TF_RETURN_IF_ERROR(
265         conditional->true_computation()->Accept(&candidates_for_true));
266 
267     CollectProfileCandidates candidates_for_false(hlo_to_profile_idx_,
268                                                   assigned_indices_);
269     TF_RETURN_IF_ERROR(
270         conditional->false_computation()->Accept(&candidates_for_false));
271 
272     return Status::OK();
273   }
274 
275   // Skip constants, there is nothing to profile.
HandleConstant(HloInstruction *)276   Status HandleConstant(HloInstruction*) override { return Status::OK(); }
277   // Skip parameters, they are a simple load.
HandleParameter(HloInstruction *)278   Status HandleParameter(HloInstruction*) override { return Status::OK(); }
279   // It is important to recurse for "while" or else we risk overly coarse
280   // profiling information.
HandleWhile(HloInstruction * xla_while)281   Status HandleWhile(HloInstruction* xla_while) override {
282     TF_RETURN_IF_ERROR(DefaultAction(xla_while));
283 
284     CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_,
285                                                       assigned_indices_);
286     TF_RETURN_IF_ERROR(
287         xla_while->while_condition()->Accept(&candidates_for_condition));
288 
289     CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_,
290                                                  assigned_indices_);
291     TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&candidates_for_body));
292 
293     return Status::OK();
294   }
295 
296   std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx_;
297   const std::unordered_map<const HloInstruction*, int64>& assigned_indices_;
298 };
299 
300 }  // namespace
301 
// Runs the layout-agnostic part of the CPU HLO pipeline on `module`. It starts
// with expansion/canonicalization passes, then a fixed-point simplification
// pipeline, and ends with CPU layout assignment followed by instruction
// fusion.
//
// module: the HLO module, rewritten in place.
// is_aot_compile: unused by this half of the pipeline.
// target_machine_features: used by ConvCanonicalization, by the TransposeFolding
//   callback (via DotImplementationCanHandleTranspose) and by
//   CpuLayoutAssignment. It must stay valid for the duration of this call,
//   since the pipeline runs before returning.
// Returns the status produced by running the pipeline.
//
// Pass order is significant throughout; for example, the second
// BatchDotSimplification relies on the canonicalization before it, and
// FlattenCallGraph must precede layout assignment. Check each pass's
// preconditions before reordering.
Status CpuCompiler::RunHloPassesThroughLayoutAssn(
    HloModule* module, bool /*is_aot_compile*/,
    LLVMTargetMachineFeatures* target_machine_features) {
  HloPassPipeline pipeline("HLO passes through layout assignment");
  // Layout-insensitive verification runs around the passes of this pipeline.
  pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                            /*allow_mixed_precision=*/false);

  pipeline.AddPass<OperandUpcaster>();
  pipeline.AddPass<ResultCaster>();

  // Expand random number generation.
  pipeline.AddPass<RngExpander>();
  pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);

  // Remove zero-sized HLO from the input so that other passes don't have to
  // handle it.
  pipeline.AddPass<ZeroSizedHloElimination>();

  pipeline.AddPass<DynamicIndexSplitter>();

  pipeline.AddPass<ConditionalToSelect>();
  pipeline.AddPass<MapInliner>();

  // Expand/decompose ops for which the CPU backend relies on generic
  // expansions instead of a dedicated lowering.
  pipeline.AddPass<ComparisonExpander>();
  pipeline.AddPass<CholeskyExpander>();
  pipeline.AddPass<QrExpander>();
  pipeline.AddPass<EighExpander>();
  pipeline.AddPass<TriangularSolveExpander>();
  pipeline.AddPass<AllGatherDecomposer>();
  pipeline.AddPass<AllToAllDecomposer>();
  pipeline.AddPass<ReduceScatterDecomposer>();

  // Inline computations with a single call site.
  pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
  pipeline.AddPass<BatchDotSimplification>();
  pipeline.AddPass<DotDecomposer>();
  // Convert BF16 operations to F32 operations so that the CPU backend can
  // support BF16 operations without directly implementing a BF16 lowering for
  // most ops.
  // `bf16` is passed by pointer; it is a local that outlives pipeline.Run()
  // below.
  BFloat16Support bf16;
  pipeline.AddPass<BFloat16Normalization>(&bf16);
  // After canonicalization, there may be more batch dots that can be
  // simplified.
  pipeline.AddPass<BatchDotSimplification>();
  // Placeholder cost model shared by both ConvolutionGroupConverter passes:
  // it always returns false and ignores its argument.
  auto cost_model = [](HloInstruction* conv) {
    // We need a cost model for CPUs. Currently, do nothing.
    return false;
  };
  pipeline.AddPass<ConvolutionGroupConverter>(
      cost_model,
      /*convert_batch_groups_only=*/true);
  pipeline.AddPass<ConvolutionGroupConverter>(
      cost_model,
      /*convert_batch_groups_only=*/false);
  pipeline.AddPass<BatchNormExpander>(
      /*rewrite_training_op=*/true,
      /*rewrite_inference_op=*/true,
      /*rewrite_grad_op=*/true);
  pipeline.AddPass<LogisticExpander>(
      /*expansion_type=*/LogisticExpansionType::kExp);
  pipeline.AddPass<ConditionalCanonicalizer>();
  pipeline.AddPass<DynamicPadder>();
  pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
  pipeline.AddPass<ConvCanonicalization>(target_machine_features);

  // Run the following passes to a fixed point.
  // The immediately-invoked lambda binds `pipeline` to the nested
  // HloPassFix pipeline, so the AddPass calls inside populate that
  // sub-pipeline rather than the outer one.
  [&pipeline =
       pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification")] {
    pipeline.AddInvariantCheckerDebug<HloVerifier>(
        /*layout_sensitive=*/false,
        /*allow_mixed_precision=*/false);

    pipeline.AddPass<TreeReductionRewriter>();
    AlgebraicSimplifierOptions options;
    options.set_enable_dot_strength_reduction(false);
    pipeline.AddPass<AlgebraicSimplifier>(options);
    pipeline.AddPass<SortSimplifier>();
    pipeline.AddPass<HloDCE>();
    pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);

    // BatchNormExpander can create zero-sized ops, so zero-sized HLO
    // elimination has to come after that pass.
    pipeline.AddPass<ZeroSizedHloElimination>();

    pipeline.AddPass<WhileLoopInvariantCodeMotion>();
    pipeline.AddPass<TupleSimplifier>();
    pipeline.AddPass<WhileLoopConstantSinking>();
    pipeline.AddPass<WhileLoopSimplifier>();

    // TODO(b/134075051): Re-enable after b/134075051 is fixed.
    // pipeline.AddPass<SliceSinker>();

    pipeline.AddPass<HloDCE>();
    pipeline.AddPass<ReshapeMover>();
    pipeline.AddPass<HloConstantFolding>();
    pipeline.AddPass<ConditionalSimplifier>();
  }();

  // TopkRewriter is only allowed to rewrite sorts whose first operand is F32.
  pipeline.AddPass<TopkRewriter>([](const HloSortInstruction* sort, int64_t) {
    return sort->operand(0)->shape().element_type() == F32;
  });
  pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
  // Fold transposes into dot operands only when the CPU dot emitter can handle
  // the transposed form for this target; never fold into convolutions.
  pipeline.AddPass<TransposeFolding>(
      [&](const HloInstruction& dot,
          const TransposeFolding::OperandIndices& candidate_operands) {
        return DotImplementationCanHandleTranspose(dot,
                                                   *target_machine_features)
                   ? candidate_operands
                   : TransposeFolding::OperandIndices{};
      },
      TransposeFolding::NeverFoldTranspose);
  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);

  // Layout assignment uses alias analysis, which requires the call graph to be
  // flattened.
  pipeline.AddPass<FlattenCallGraph>();
  pipeline.AddPass<CpuLayoutAssignment>(
      module->mutable_entry_computation_layout(), target_machine_features);

  pipeline.AddPass<CpuInstructionFusion>();

  return pipeline.Run(module).status();
}
425 
// Runs the layout-sensitive HLO passes that follow layout assignment, in order:
//  - fixed-point cleanup of copies left behind by layout assignment;
//  - parallel task assignment (JIT only);
//  - copy insertion, with DCE immediately before and after.
//
// module: the HLO module, rewritten in place.
// is_aot_compile: when true, ParallelTaskAssigner is not added (see below).
// target_machine_features: forwarded to ParallelTaskAssigner; it must stay
//   valid for the duration of this call.
// Returns the status produced by running the pipeline.
Status CpuCompiler::RunHloPassesAfterLayoutAssn(
    HloModule* module, bool is_aot_compile,
    LLVMTargetMachineFeatures* target_machine_features) {
  HloPassPipeline pipeline("HLO passes after layout assignment");
  // After layout assignment, use a layout-sensitive verifier.
  // NOTE(review): the sub-pipeline below contains no passes. It appears to
  // exist only to host the layout-sensitive HloVerifier invariant checker
  // (added via AddInvariantCheckerDebug; confirm whether that is debug-only).

  pipeline.AddPass<HloPassPipeline>("after layout assignment")
      .AddInvariantCheckerDebug<HloVerifier>(
          /*layout_sensitive=*/true,
          /*allow_mixed_precision=*/false);

  // The LayoutAssignment pass may leave behind kCopy instructions which are
  // duplicate or NOPs, so remove them with algebraic simplification and CSE.
  // Run this to a fixed point.
  // The immediately-invoked lambda binds `pipeline` to the nested fixed-point
  // pipeline, so the AddPass calls inside populate that sub-pipeline.
  [&pipeline = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
       "simplification after layout assignment")] {
    pipeline.AddInvariantCheckerDebug<HloVerifier>(
        /*layout_sensitive=*/true,
        /*allow_mixed_precision=*/false,
        LayoutAssignment::InstructionCanChangeLayout);
    AlgebraicSimplifierOptions options;
    options.set_is_layout_sensitive(true);
    options.set_enable_dot_strength_reduction(false);
    pipeline.AddPass<AlgebraicSimplifier>(options);
    pipeline.AddPass<HloDCE>();
    pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
  }();

  // Outline ops in the entry computation into calls to subcomputations.
  // Thread budget for ParallelTaskAssigner: the module's configured intra-op
  // parallelism when positive, otherwise the number of schedulable CPUs.
  const int max_parallelism =
      module->config().intra_op_parallelism_threads() > 0
          ? module->config().intra_op_parallelism_threads()
          : tensorflow::port::NumSchedulableCPUs();
  if (!is_aot_compile) {
    // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module.
    // Note this is not run for AOT because it would bring in thread pool
    // and thread synchronization dependencies which would likely increase
    // binary size (and most AOT applications are single-threaded).
    // TODO(b/29630486) Support multi-threaded AOT.
    pipeline.AddPass<ParallelTaskAssigner>(
        max_parallelism, ShapeSizeBytesFunction(), target_machine_features);
  }
  // Copy insertion should be performed immediately before IR emission to
  // avoid inserting unnecessary copies (later pass adds an instruction which
  // materializes the value) or missing a necessary copy (later pass removes
  // an instruction which materializes a value). DCE must be run immediately
  // before (and sometime after) copy insertion, to avoid dead code from
  // interfering with the rewrites.
  pipeline.AddPass<HloDCE>();
  pipeline.AddPass<CopyInsertion>();
  pipeline.AddPass<HloDCE>();
  return pipeline.Run(module).status();
}
479 
RunHloPasses(HloModule * module,bool is_aot_compile,llvm::TargetMachine * target_machine)480 Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
481                                  llvm::TargetMachine* target_machine) {
482   LLVMTargetMachineFeatures target_machine_features(target_machine);
483   TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile,
484                                                    &target_machine_features));
485   return RunHloPassesAfterLayoutAssn(module, is_aot_compile,
486                                      &target_machine_features);
487 }
488 
489 namespace {
490 
491 // Align buffers to 16-byte boundaries.
memory_alignment(LogicalBuffer::Color)492 int64 memory_alignment(LogicalBuffer::Color) {
493   return cpu_function_runtime::kMinAlign;
494 }
495 
CompilerTargetOptions(const HloModuleConfig & module_config)496 llvm::TargetOptions CompilerTargetOptions(
497     const HloModuleConfig& module_config) {
498   llvm::TargetOptions target_options;
499   // Always allow FMA fusion. This increases precision instead of decreasing it.
500   target_options.AllowFPOpFusion = llvm::FPOpFusion::Fast;
501   return target_options;
502 }
503 
CodeGenOptLevel(const HloModuleConfig & module_config)504 llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) {
505   VLOG(2) << "backend_optimization_level: "
506           << module_config.debug_options().xla_backend_optimization_level();
507   switch (module_config.debug_options().xla_backend_optimization_level()) {
508     case 1:
509       return llvm::CodeGenOpt::Less;
510     case 2:
511       return llvm::CodeGenOpt::Default;
512     case 3:
513       return llvm::CodeGenOpt::Aggressive;
514     default:
515       return llvm::CodeGenOpt::None;
516   }
517 }
518 
GetIRModuleHooks(const HloModule & hlo_module,const LLVMCompiler::ModuleHook & user_pre_optimization_hook,const LLVMCompiler::ModuleHook & user_post_optimization_hook)519 std::pair<LLVMCompiler::ModuleHook, LLVMCompiler::ModuleHook> GetIRModuleHooks(
520     const HloModule& hlo_module,
521     const LLVMCompiler::ModuleHook& user_pre_optimization_hook,
522     const LLVMCompiler::ModuleHook& user_post_optimization_hook) {
523   // Create the IR hooks. If applicable, each IR hook does the following:
524   //
525   //  * Calls the user supplied module hook.
526   //  * Writes out the IR to a file in the output directory designated by
527   //    --xla_dump_to
528   const HloModule* hlo_module_ptr = &hlo_module;
529   auto hook = [user_pre_optimization_hook, user_post_optimization_hook,
530                hlo_module_ptr](bool optimized,
531                                const llvm::Module& llvm_module) {
532     const auto& user_hook =
533         !optimized ? user_pre_optimization_hook : user_post_optimization_hook;
534     if (user_hook) {
535       user_hook(llvm_module);
536     }
537     llvm_ir::DumpIrIfEnabled(*hlo_module_ptr, llvm_module, optimized);
538   };
539   return {[hook](const llvm::Module& llvm_module) {
540             return hook(/*optimized=*/false, llvm_module);
541           },
542           [hook](const llvm::Module& llvm_module) {
543             return hook(/*optimized=*/true, llvm_module);
544           }};
545 }
546 
// Runs the LLVM IR verifier over `llvm_module`. Returns OK when the module is
// well formed. Otherwise returns an error status whose message contains the
// verifier's diagnostics plus a hint to rerun with --xla_dump_to.
//
// NOTE(review): TF_RET_CHECK is understood to embed its stringified condition
// in the error message (confirm in status_macros.h). If so, the local names
// inside it (`llvm_module`, `err_stream`) are observable; keep them stable.
Status VerifyLlvmModule(const llvm::Module& llvm_module) {
  XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier");

  // Verifier diagnostics are streamed into `err` through `err_stream`.
  std::string err;
  llvm::raw_string_ostream err_stream(err);

  // verifyModule() returns true if the module is broken.
  TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream))
      << "Invalid LLVM IR before optimizations:\n"
      << err_stream.str()
      << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
         "Rerun with --xla_dump_to to get the IR. ";
  return Status::OK();
}
561 
CreateHloProfilingArtifacts(const HloModule & module,std::unordered_map<const HloInstruction *,int64> * instruction_to_profile_idx,std::unordered_map<const HloComputation *,int64> * computation_to_profile_idx,std::unique_ptr<HloProfileIndexMap> * hlo_profile_index_map,std::unique_ptr<HloProfilePrinterData> * hlo_profile_printer_data)562 Status CreateHloProfilingArtifacts(
563     const HloModule& module,
564     std::unordered_map<const HloInstruction*, int64>*
565         instruction_to_profile_idx,
566     std::unordered_map<const HloComputation*, int64>*
567         computation_to_profile_idx,
568     std::unique_ptr<HloProfileIndexMap>* hlo_profile_index_map,
569     std::unique_ptr<HloProfilePrinterData>* hlo_profile_printer_data) {
570   *hlo_profile_index_map = absl::make_unique<HloProfileIndexMap>(module);
571   const HloComputation& entry_computation = *module.entry_computation();
572 
573   TF_ASSIGN_OR_RETURN(
574       *instruction_to_profile_idx,
575       CollectProfileCandidates::GetCandidatesForComputation(
576           entry_computation,
577           (*hlo_profile_index_map)->instruction_to_profile_idx()));
578 
579   auto shape_size_bytes = [](const Shape& shape) {
580     // On the cpu, opaques are pointers.
581     if (shape.IsOpaque()) {
582       return static_cast<int64>(sizeof(void*));
583     }
584     return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
585   };
586 
587   HloCostAnalysis cost_analysis(shape_size_bytes);
588   TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis));
589   *hlo_profile_printer_data = CreateHloProfilePrinterData(
590       **hlo_profile_index_map, cost_analysis, entry_computation.name());
591   *computation_to_profile_idx =
592       (*hlo_profile_index_map)->computation_to_profile_idx();
593 
594   return Status::OK();
595 }
596 
597 }  // namespace
598 
RunHloPasses(std::unique_ptr<HloModule> module,se::StreamExecutor *,const CompileOptions &)599 StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
600     std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/,
601     const CompileOptions& /*options*/) {
602   std::unique_ptr<llvm::TargetMachine> jit_target_machine =
603       SimpleOrcJIT::InferTargetMachineForJIT(
604           CompilerTargetOptions(module->config()),
605           CodeGenOptLevel(module->config()));
606 
607   TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false,
608                                   jit_target_machine.get()));
609   return std::move(module);
610 }
611 
612 StatusOr<
613     std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,se::StreamExecutor * executor,bool optimize,const CompileOptions & options)614 CpuCompiler::RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
615                                               se::StreamExecutor* executor,
616                                               bool optimize,
617                                               const CompileOptions& options) {
618   if (optimize) {
619     TF_ASSIGN_OR_RETURN(module,
620                         RunHloPasses(std::move(module), executor, options));
621   }
622 
623   // Select an order for emitting the HLO instructions for each computation.
624   // Using this sequence enables tighter buffer liveness analysis and reduced
625   // memory usage (as compared to using DependencyHloOrdering).
626   TF_ASSIGN_OR_RETURN(HloSchedule schedule,
627                       ScheduleModule(module.get(), BufferSizeBytesFunction(),
628                                      ComputationSchedulerToModuleScheduler(
629                                          DFSMemoryScheduler)));
630 
631   // Run buffer allocation on the HLO graph.
632   TF_ASSIGN_OR_RETURN(
633       std::unique_ptr<BufferAssignment> assignment,
634       BufferAssigner::Run(module.get(),
635                           absl::make_unique<SequentialHloOrdering>(schedule),
636                           BufferSizeBytesFunction(), memory_alignment,
637                           /*allocate_buffers_for_constants=*/true));
638 
639   return std::make_tuple(std::move(module), std::move(assignment));
640 }
641 
642 namespace {
643 
644 // Post-compilation callback functor for use by SimpleOrcJIT.
645 //
646 // Dumps machine code if dumping is enabled for the module.
647 struct OrcJITPostCompilationHook {
648   // Gets an std::function that implements this hook.
Createxla::cpu::__anond5cc82ac0e11::OrcJITPostCompilationHook649   static std::function<void(const llvm::object::ObjectFile& obj_file)> Create(
650       const HloModule* module) {
651     // This struct is not copyable, but std::functions must be.  So to create an
652     // std::function out of this struct, we have to wrap it in a shared_ptr.
653     auto wrapped = std::make_shared<OrcJITPostCompilationHook>(module);
654     return [wrapped](const llvm::object::ObjectFile& obj_file) {
655       (*wrapped)(obj_file);
656     };
657   }
658 
659   // Constructor can't be private because we want to call it from
660   // std::make_shared, but users should call Create() instead.
OrcJITPostCompilationHookxla::cpu::__anond5cc82ac0e11::OrcJITPostCompilationHook661   explicit OrcJITPostCompilationHook(const HloModule* module)
662       : module(module) {}
663 
664  private:
operator ()xla::cpu::__anond5cc82ac0e11::OrcJITPostCompilationHook665   void operator()(const llvm::object::ObjectFile& obj_file) {
666     if (!DumpingEnabledForHloModule(*module)) {
667       return;
668     }
669     DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o",
670                     absl::string_view(obj_file.getData().data(),
671                                       obj_file.getData().size()));
672   }
673 
674   const HloModule* module;
675 };
676 
677 }  // namespace
678 
// JIT-compiles `module` into a CpuExecutable.
//
// This function does not run HLO optimization passes itself (NOTE(review):
// presumably the caller has already run RunHloPasses; the dump below is
// labelled "cpu_after_optimizations"). In order, it:
//   1. creates a SimpleOrcJIT configured from the module's config and the
//      user IR hooks;
//   2. schedules the module with the DFS memory scheduler and runs buffer
//      assignment;
//   3. emits LLVM IR for constants, every non-fusion embedded computation,
//      and finally the entry computation;
//   4. verifies the IR and adds it to the JIT, which is then owned by the
//      returned CpuExecutable together with the buffer assignment and module.
//
// `stream_exec` and the `options` parameter are not referenced in this body
// (`options::OptimizeForSizeRequested` names a namespace, not the parameter).
StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    const CompileOptions& options) {
  VLOG(1) << "Compiling: " << module->name();
  XLA_SCOPED_LOGGING_TIMER(
      absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));
  std::string slow_compilation_msg =
      absl::StrCat("Compiling module ", module->name());
  auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg);

  // LLVM command-line options are initialized only once per process, from the
  // config of whichever module reaches this point first.
  absl::call_once(llvm_command_line_options_initialized,
                  &llvm_ir::InitializeLLVMCommandLineOptions, module->config());

  ModuleHook pre_optimization_ir_hook;
  ModuleHook post_optimization_ir_hook;
  std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
      GetIRModuleHooks(*module, user_pre_optimization_hook_,
                       user_post_optimization_hook_);

  // Compile must be thread-safe so create a new LLVM context for the module.
  mlir::MLIRContext mlir_context;
  LoadMLIRDialects(mlir_context);
  auto llvm_context = std::make_unique<llvm::LLVMContext>();
  auto llvm_module =
      absl::make_unique<llvm::Module>("__compute_module", *llvm_context);

  auto jit = SimpleOrcJIT::Create(
      CompilerTargetOptions(module->config()),
      CodeGenOptLevel(module->config()),
      options::OptimizeForSizeRequested(module->config()),
      module->config().debug_options().xla_llvm_disable_expensive_passes(),
      llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook,
      post_optimization_ir_hook,
      OrcJITPostCompilationHook::Create(module.get()));
  if (!jit) {
    return InternalError("Creating JIT failed: %s",
                         llvm::toString(jit.takeError()));
  }
  // Emit IR against the JIT's own data layout and triple so the generated
  // module matches the machine code target.
  llvm_module->setDataLayout((*jit)->data_layout());
  llvm_module->setTargetTriple((*jit)->target_triple().getTriple());

  HloComputation* entry_computation = module->entry_computation();
  std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
  std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
  std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
  std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
  if (module->config().hlo_profiling_enabled()) {
    TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
        *module, &instruction_to_profile_idx, &computation_to_profile_idx,
        &hlo_profile_index_map, &hlo_profile_printer_data));
  }

  // Cache these flags here since we'll want to access them after the module's
  // ownership is std::moved.
  const bool embed_ir_in_executable =
      module->config().debug_options().xla_embed_ir_in_executable();

  // Select an order for emitting the HLO instructions for each
  // computation. Using this sequence enables tighter buffer liveness analysis
  // and reduced memory usage (as compared to using DependencyHloOrdering).
  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      ScheduleModule(module.get(), BufferSizeBytesFunction(),
                                     ComputationSchedulerToModuleScheduler(
                                         DFSMemoryScheduler)));

  // Run buffer allocation on the HLO graph.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<BufferAssignment> assignment,
      BufferAssigner::Run(module.get(),
                          absl::make_unique<SequentialHloOrdering>(schedule),
                          BufferSizeBytesFunction(), memory_alignment,
                          /*allocate_buffers_for_constants=*/true));
  DumpHloModuleIfEnabled(*module, *assignment, "cpu_after_optimizations");

  // Each computation is a single function.  Emit all embedded computations
  // before the entry computation. The order of computations returned from
  // GetEmbeddedComputations guarantees that a called computation occurs
  // before a caller computation.

  LLVMTargetMachineFeatures target_machine_features((*jit)->target_machine());
  // The final argument is true only in MEMORY_SANITIZER builds, selecting the
  // emitter's MSan-specific code path.
  IrEmitter ir_emitter(&mlir_context, *module, *assignment, llvm_module.get(),
                       std::move(instruction_to_profile_idx),
                       std::move(computation_to_profile_idx),
                       &target_machine_features,
#ifdef MEMORY_SANITIZER
                       /*emit_code_for_msan=*/true
#else
                       /*emit_code_for_msan=*/false
#endif
  );

  TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());

  // Fusion computations do not get a standalone function in this loop.
  for (auto embedded_computation :
       entry_computation->MakeEmbeddedComputationsList()) {
    if (embedded_computation->IsFusionComputation()) {
      continue;
    }
    TF_RETURN_IF_ERROR(
        ir_emitter
            .EmitComputation(
                embedded_computation, embedded_computation->name(),
                /*is_top_level_computation=*/false,
                schedule.sequence(embedded_computation).instructions())
            .status());
  }
  // Fall back to "__compute" when the entry computation is unnamed.
  string function_name_prefix = entry_computation->name().empty()
                                    ? "__compute"
                                    : entry_computation->name();
  TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
                      ir_emitter.EmitComputation(
                          entry_computation, function_name_prefix,
                          /*is_top_level_computation=*/true,
                          schedule.sequence(entry_computation).instructions()));

  // Symbol name of the entry function as seen by the JIT: the IR name with the
  // data layout's global prefix applied by llvm::Mangler.
  string function_name = [&]() {
    llvm::SmallVector<char, 40> function_name_vector;
    llvm::Mangler::getNameWithPrefix(
        function_name_vector, entry_function->getName(), (*jit)->data_layout());
    return string(function_name_vector.begin(), function_name_vector.end());
  }();

  // Capture the IR text now: llvm_module is moved into the JIT below.
  string ir_module_string;
  if (embed_ir_in_executable) {
    ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
  }

  TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));

  // JIT compile the LLVM IR module to in-memory machine code.
  // NOTE(review): cantFail() makes any AddModule error fatal instead of
  // returning a Status.
  llvm::orc::ThreadSafeModule thread_safe_module(std::move(llvm_module),
                                                 std::move(llvm_context));
  cantFail((*jit)->AddModule(std::move(thread_safe_module)));

  // From here on `module` has been moved; access it via cpu_executable.
  auto cpu_executable = absl::make_unique<CpuExecutable>(
      std::move(*jit), std::move(assignment), std::move(module), function_name,
      std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map));

  if (embed_ir_in_executable) {
    cpu_executable->set_ir_module_string(ir_module_string);
  }

  // Dump computation proto state and buffer assignment for debug and test, if
  // dump is enabled.
  if (DumpingEnabledForHloModule(cpu_executable->module())) {
    auto hlo_proto = absl::make_unique<HloProto>();
    *hlo_proto->mutable_hlo_module() = cpu_executable->module().ToProto();
    *hlo_proto->mutable_buffer_assignment() =
        cpu_executable->buffer_assignment().ToProto();
    cpu_executable->set_hlo_proto(std::move(hlo_proto));
  }

  cpu_executable->set_debug_info(
      cpu_executable->buffer_assignment().GetStats().ToString());
  VLOG(1) << "Compilation finished";
  return std::unique_ptr<Executable>(std::move(cpu_executable));
}
836 
837 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,const AotCompilationOptions & aot_options)838 CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
839                                 const AotCompilationOptions& aot_options) {
840   TF_RET_CHECK(!module_group->empty());
841   std::vector<std::unique_ptr<HloModule>> modules =
842       module_group->ConsumeModules();
843 
844   absl::call_once(llvm_command_line_options_initialized,
845                   &llvm_ir::InitializeLLVMCommandLineOptions,
846                   modules[0]->config());
847 
848   // We can pass just one llvm::TargetOptions when we compile the LLVM module,
849   // so we bail if the configs have conflicting flags. At the moment, the only
850   // flags that need to be consistent are for fast-math.
851   for (const auto& fn_and_name :
852        {std::make_pair(&DebugOptions::xla_cpu_enable_fast_math,
853                        "xla_cpu_enable_fast_math"),
854         std::make_pair(&DebugOptions::xla_cpu_fast_math_honor_infs,
855                        "xla_cpu_fast_math_honor_infs"),
856         std::make_pair(&DebugOptions::xla_cpu_fast_math_honor_nans,
857                        "xla_cpu_fast_math_honor_nans")}) {
858     // This only works because each of the method pointers above returns a bool.
859     // Otherwise we'd have to do some template magic.
860     const auto& field_method_ptr = fn_and_name.first;
861     const auto& field_name = fn_and_name.second;
862     bool first_module_val =
863         (modules[0]->config().debug_options().*field_method_ptr)();
864     for (int64_t i = 0; i < modules.size(); ++i) {
865       bool cur_module_val =
866           (modules[i]->config().debug_options().*field_method_ptr)();
867       if (first_module_val != cur_module_val) {
868         return InvalidArgument(
869             "All HLO module configs must have the same value for %s, but "
870             "module 0 and %d have different values (%d vs %d).",
871             field_name, i, first_module_val, cur_module_val);
872       }
873     }
874   }
875 
876   if (aot_options.PlatformId() != se::host::kHostPlatformId) {
877     return InvalidArgument("Incompatible AOT compilation platform");
878   }
879   const CpuAotCompilationOptions& options =
880       static_cast<const CpuAotCompilationOptions&>(aot_options);
881   llvm::Triple triple(llvm::Triple::normalize(options.triple()));
882   std::string error;
883   const llvm::Target* target =
884       llvm::TargetRegistry::lookupTarget(triple.getTriple(), error);
885   if (target == nullptr) {
886     return InternalError("TargetRegistry::lookupTarget failed: %s", error);
887   }
888 
889   llvm::Reloc::Model reloc_model = llvm::Reloc::Static;
890   llvm::PICLevel::Level pic_level = llvm::PICLevel::NotPIC;
891   llvm::PIELevel::Level pie_level = llvm::PIELevel::Default;
892   switch (options.relocation_model()) {
893     case CpuAotCompilationOptions::RelocationModel::Static:
894       reloc_model = llvm::Reloc::Static;
895       pic_level = llvm::PICLevel::NotPIC;
896       pie_level = llvm::PIELevel::Default;
897       break;
898     case CpuAotCompilationOptions::RelocationModel::SmallPic:
899       reloc_model = llvm::Reloc::PIC_;
900       pic_level = llvm::PICLevel::SmallPIC;
901       pie_level = llvm::PIELevel::Default;
902       break;
903     case CpuAotCompilationOptions::RelocationModel::BigPic:
904       reloc_model = llvm::Reloc::PIC_;
905       pic_level = llvm::PICLevel::BigPIC;
906       pie_level = llvm::PIELevel::Default;
907       break;
908     case CpuAotCompilationOptions::RelocationModel::SmallPie:
909       reloc_model = llvm::Reloc::PIC_;
910       pic_level = llvm::PICLevel::SmallPIC;
911       pie_level = llvm::PIELevel::Small;
912       break;
913     case CpuAotCompilationOptions::RelocationModel::BigPie:
914       reloc_model = llvm::Reloc::PIC_;
915       pic_level = llvm::PICLevel::BigPIC;
916       pie_level = llvm::PIELevel::Large;
917       break;
918   }
919   llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config());
920   std::unique_ptr<llvm::TargetMachine> target_machine =
921       absl::WrapUnique(target->createTargetMachine(
922           triple.getTriple(), options.cpu_name(), options.features(),
923           CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::None,
924           opt_level));
925 
926   // Compile must be thread-safe so create a new LLVM context for the module.
927   mlir::MLIRContext mlir_context;
928   LoadMLIRDialects(mlir_context);
929   llvm::LLVMContext llvm_context;
930   llvm::Module llvm_module("__compute_module", llvm_context);
931   llvm_module.setDataLayout(target_machine->createDataLayout());
932   llvm_module.setTargetTriple(triple.getTriple());
933   if (pic_level != llvm::PICLevel::NotPIC) {
934     llvm_module.setPICLevel(pic_level);
935   }
936   if (pie_level != llvm::PIELevel::Default) {
937     llvm_module.setPIELevel(pie_level);
938   }
939 
940   std::vector<std::unique_ptr<AotCompilationResult>> results;
941   for (size_t i = 0; i < modules.size(); ++i) {
942     HloModule* module = modules[i].get();
943     VLOG(1) << "Compiling ahead-of-time: " << module->name();
944 
945     TF_RETURN_IF_ERROR(
946         RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get()));
947 
948     TF_ASSIGN_OR_RETURN(HloSchedule schedule,
949                         ScheduleModule(module, BufferSizeBytesFunction()));
950 
951     // Run buffer analysis on the HLO graph. This analysis figures out which
952     // temporary buffers are required to run the computation.
953     TF_ASSIGN_OR_RETURN(
954         std::unique_ptr<BufferAssignment> assignment,
955         BufferAssigner::Run(module,
956                             absl::make_unique<SequentialHloOrdering>(schedule),
957                             BufferSizeBytesFunction(), memory_alignment,
958                             /*allocate_buffers_for_constants=*/true));
959     // BufferAssignment::ToString() includes a header, so no need for us to
960     // print one ourselves.
961     if (DumpingEnabledForHloModule(*module)) {
962       DumpToFileInDirOrStdout(*module, "", "buffer_assignment",
963                               assignment->ToString());
964     }
965     DumpHloModuleIfEnabled(*module, *assignment, "cpu_after_optimizations");
966 
967     std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
968     std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
969     std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
970     std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
971 
972     if (module->config().hlo_profiling_enabled()) {
973       TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
974           *module, &instruction_to_profile_idx, &computation_to_profile_idx,
975           &hlo_profile_index_map, &hlo_profile_printer_data));
976     }
977 
978     LLVMTargetMachineFeatures target_machine_features(target_machine.get());
979     IrEmitter ir_emitter(&mlir_context, *module, *assignment, &llvm_module,
980                          std::move(instruction_to_profile_idx),
981                          std::move(computation_to_profile_idx),
982                          &target_machine_features,
983                          // TODO(b/66051036): Run full msan for AOT.
984                          /*emit_code_for_msan=*/false);
985 
986     TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
987 
988     HloComputation* computation = module->entry_computation();
989     for (auto embedded_computation :
990          computation->MakeEmbeddedComputationsList()) {
991       if (embedded_computation->IsFusionComputation()) {
992         continue;
993       }
994       TF_RETURN_IF_ERROR(
995           ir_emitter
996               .EmitComputation(
997                   embedded_computation, embedded_computation->name(),
998                   /*is_top_level_computation=*/false,
999                   schedule.sequence(embedded_computation).instructions())
1000               .status());
1001     }
1002     const string& entry_point_name = options.entry_point_name();
1003     TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
1004                         ir_emitter.EmitComputation(
1005                             computation, entry_point_name,
1006                             /*is_top_level_computation=*/true,
1007                             schedule.sequence(computation).instructions()));
1008 
1009     CHECK(entry_function->getName() == entry_point_name);
1010 
1011     ModuleHook pre_optimization_ir_hook;
1012     ModuleHook post_optimization_ir_hook;
1013     std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
1014         GetIRModuleHooks(*module, user_pre_optimization_hook_,
1015                          user_post_optimization_hook_);
1016 
1017     // Run the LLVM verifier over the unoptimized LLVM IR.  If it fails, run the
1018     // pre-optimization IR dump hook before returning.
1019     {
1020       Status verify_status = VerifyLlvmModule(llvm_module);
1021       if (!verify_status.ok() && pre_optimization_ir_hook) {
1022         pre_optimization_ir_hook(llvm_module);
1023       }
1024       TF_RETURN_IF_ERROR(verify_status);
1025     }
1026 
1027     auto post_codegen_hook = [&](const llvm::object::ObjectFile& obj_file) {
1028       if (!DumpingEnabledForHloModule(*module)) {
1029         return;
1030       }
1031       DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o",
1032                       absl::string_view(obj_file.getData().data(),
1033                                         obj_file.getData().size()));
1034     };
1035 
1036     CompilerFunctor compiler_functor(
1037         target_machine.get(), opt_level,
1038         options::OptimizeForSizeRequested(module->config()),
1039         module->config().debug_options().xla_llvm_disable_expensive_passes(),
1040         llvm_ir::GetCpuFastMathFlags(module->config()),
1041         pre_optimization_ir_hook, post_optimization_ir_hook, post_codegen_hook);
1042     std::unique_ptr<llvm::MemoryBuffer> object_file =
1043         cantFail(compiler_functor(llvm_module));
1044     ObjectFileData object_file_data(object_file->getBufferStart(),
1045                                     object_file->getBufferEnd());
1046 
1047     std::vector<BufferInfo> buffer_infos =
1048         CreateBufferInfosFromBufferAssignment(*assignment);
1049 
1050     TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
1051                         assignment->GetUniqueTopLevelOutputSlice());
1052 
1053     results.emplace_back(absl::make_unique<CpuAotCompilationResult>(
1054         std::move(object_file_data), std::move(buffer_infos),
1055         result_slice.index(), std::move(hlo_profile_printer_data)));
1056   }
1057 
1058   VLOG(1) << "Compilation finished";
1059   return std::move(results);
1060 }
1061 
PlatformId() const1062 se::Platform::Id CpuCompiler::PlatformId() const {
1063   return se::host::kHostPlatformId;
1064 }
1065 
ShapeSizeBytesFunction() const1066 HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const {
1067   return CpuExecutable::ShapeSizeBytes;
1068 }
1069 
1070 }  // namespace cpu
1071 }  // namespace xla
1072 
InitModule()1073 static bool InitModule() {
1074   xla::Compiler::RegisterCompilerFactory(
1075       stream_executor::host::kHostPlatformId,
1076       []() { return absl::make_unique<xla::cpu::CpuCompiler>(); });
1077   return true;
1078 }
1079 static bool module_initialized = InitModule();
1080