/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"

#include <stddef.h>
#include <string.h>

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc"
// IWYU pragma: no_include "llvm/Config/Targets.def.inc"
#include "absl/base/call_once.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Mangler.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"  // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/Vector/VectorOps.h"  // from @llvm-project
#include "mlir/InitAllDialects.h"  // from @llvm-project
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
#include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/comparison_expander.h"
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/conditional_to_select.h"
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/dynamic_padder.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/logistic_expander.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/operand_upcaster.h"
#include "tensorflow/compiler/xla/service/qr_expander.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
#include "tensorflow/compiler/xla/service/rng_expander.h"
#include "tensorflow/compiler/xla/service/scatter_expander.h"
#include "tensorflow/compiler/xla/service/slice_sinker.h"
#include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
#include "tensorflow/compiler/xla/service/sort_simplifier.h"
#include "tensorflow/compiler/xla/service/topk_rewriter.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tree_reduction_rewriter.h"
#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/dynamic_annotations.h"

namespace {

// We need to explicitly load all the dialects that will be involved in
// emitting the IR. This is only needed because of how MLIR is bolted into XLA
// and does not make use of the MLIR infrastructure (like using a proper pass
// pipeline). Hopefully this will all go away at some point in favor of a
// better integration.
void LoadMLIRDialects(mlir::MLIRContext& context) {
  context.loadDialect<mlir::linalg::LinalgDialect, mlir::scf::SCFDialect,
                      mlir::vector::VectorDialect, mlir::StandardOpsDialect,
                      mlir::AffineDialect>();
}

}  // namespace

namespace xla {
namespace cpu {
using BufferInfo = cpu_function_runtime::BufferInfo;

CpuAotCompilationOptions::CpuAotCompilationOptions(
    string triple, string cpu_name, string features, string entry_point_name,
    RelocationModel relocation_model)
    : triple_(std::move(triple)),
      cpu_name_(std::move(cpu_name)),
      features_(std::move(features)),
      entry_point_name_(std::move(entry_point_name)),
      relocation_model_(relocation_model) {}
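
// For illustration only: a hedged sketch of how a client (such as tfcompile)
// might construct these AOT options. The triple, entry point name, and
// relocation model below are hypothetical example values, not values this
// file mandates:
//
//   CpuAotCompilationOptions aot_options(
//       /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"",
//       /*entry_point_name=*/"entry",
//       CpuAotCompilationOptions::RelocationModel::Static);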

CpuAotCompilationOptions::~CpuAotCompilationOptions() = default;

se::Platform::Id CpuAotCompilationOptions::PlatformId() const {
  return se::host::kHostPlatformId;
}

CpuAotCompilationResult::CpuAotCompilationResult(
    ObjectFileData object_file_data, std::vector<BufferInfo> buffer_infos,
    int64 result_buffer_index,
    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data)
    : object_file_data_(std::move(object_file_data)),
      buffer_infos_(std::move(buffer_infos)),
      result_buffer_index_(result_buffer_index),
      hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {}

CpuAotCompilationResult::~CpuAotCompilationResult() = default;

CpuCompiler::CpuCompiler() {
  // Initialize LLVM the first time the CpuCompiler is initialized.
  static bool llvm_initialized = []() {
    InitializeLLVMTarget();
    return true;
  }();
  (void)llvm_initialized;
}

StatusOr<std::vector<std::unique_ptr<Executable>>> CpuCompiler::Compile(
    std::unique_ptr<HloModuleGroup> module_group,
    std::vector<std::vector<se::StreamExecutor*>> stream_execs,
    const CompileOptions& options) {
  for (const std::vector<se::StreamExecutor*>& se_vector : stream_execs) {
    if (se_vector.size() != 1) {
      return Unimplemented(
          "Model partitioning not implemented for the CPU compiler");
    }
  }
  return LLVMCompiler::Compile(std::move(module_group), stream_execs, options);
}

/* static */ void CpuCompiler::InitializeLLVMTarget() {
  // Initialize LLVM's MC layer for the native target.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
}

namespace {

// LLVM makes certain options configurable only through its command-line
// options; it provides the ParseCommandLineOptions function that lets us set
// flags at runtime. However, since these flags are global we want to avoid
// multiple invocations of the LLVM compilation pipeline with a different set
// of flags. Therefore, we only pass command-line flags to LLVM once, before
// the first module is compiled.
absl::once_flag llvm_command_line_options_initialized;

// This visitor records which HLO instructions should have profiling
// information recorded.
class CollectProfileCandidates : public DfsHloVisitorWithDefault {
 public:
  static StatusOr<std::unordered_map<const HloInstruction*, int64>>
  GetCandidatesForComputation(
      const HloComputation& computation,
      const std::unordered_map<const HloInstruction*, int64>&
          assigned_indices) {
    std::unordered_map<const HloInstruction*, int64> hlo_to_profile_idx;
    CollectProfileCandidates profile_candidates_for_computation(
        &hlo_to_profile_idx, assigned_indices);
    TF_RETURN_IF_ERROR(
        computation.Accept(&profile_candidates_for_computation));
    return hlo_to_profile_idx;
  }

 private:
  CollectProfileCandidates(
      std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx,
      const std::unordered_map<const HloInstruction*, int64>& assigned_indices)
      : hlo_to_profile_idx_(hlo_to_profile_idx),
        assigned_indices_(assigned_indices) {}

  Status DefaultAction(HloInstruction* hlo_instruction) override {
    hlo_to_profile_idx_->insert(
        {hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)});
    return Status::OK();
  }

  Status HandleCall(HloInstruction* call) override {
    TF_RETURN_IF_ERROR(DefaultAction(call));
    CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_,
                                                 assigned_indices_);
    TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call));
    return Status::OK();
  }

  // Skip constants, there is nothing to profile.
  Status HandleConstant(HloInstruction*) override { return Status::OK(); }
  // Skip parameters, they are a simple load.
  Status HandleParameter(HloInstruction*) override { return Status::OK(); }
  // It is important to recurse for "while" or else we risk overly coarse
  // profiling information.
  Status HandleWhile(HloInstruction* xla_while) override {
    TF_RETURN_IF_ERROR(DefaultAction(xla_while));

    CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_,
                                                      assigned_indices_);
    TF_RETURN_IF_ERROR(
        xla_while->while_condition()->Accept(&candidates_for_condition));

    CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_,
                                                 assigned_indices_);
    TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&candidates_for_body));

    return Status::OK();
  }

  std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx_;
  const std::unordered_map<const HloInstruction*, int64>& assigned_indices_;
};

}  // namespace

Status CpuCompiler::RunHloPassesThroughLayoutAssn(
    HloModule* module, bool /*is_aot_compile*/,
    LLVMTargetMachineFeatures* target_machine_features) {
  HloPassPipeline pipeline("HLO passes through layout assignment");
  pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                            /*allow_mixed_precision=*/false);

  pipeline.AddPass<OperandUpcaster>();

  // Expand random number generation.
  pipeline.AddPass<RngExpander>();
  pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);

  // Remove zero-sized HLO from the input so that other passes don't have to
  // handle it.
  pipeline.AddPass<ZeroSizedHloElimination>();

  pipeline.AddPass<DynamicIndexSplitter>();

  pipeline.AddPass<ConditionalToSelect>();
  pipeline.AddPass<MapInliner>();

  pipeline.AddPass<ComparisonExpander>();
  pipeline.AddPass<CholeskyExpander>();
  pipeline.AddPass<QrExpander>();
  pipeline.AddPass<TriangularSolveExpander>();
  pipeline.AddPass<AllGatherDecomposer>();
  pipeline.AddPass<AllToAllDecomposer>();

  // Inline computations with a single call site.
  pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
  pipeline.AddPass<BatchDotSimplification>();
  pipeline.AddPass<DotDecomposer>();
  // Convert BF16 operations to F32 operations so that the CPU backend can
  // support BF16 operations without directly implementing a BF16 lowering for
  // most ops.
  pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
  // After canonicalization, there may be more batch dots that can be
  // simplified.
  pipeline.AddPass<BatchDotSimplification>();
  auto cost_model = [](HloInstruction* conv) {
    // We need a cost model for CPUs. Currently, do nothing.
    return false;
  };
  pipeline.AddPass<ConvolutionGroupConverter>(
      cost_model,
      /*convert_batch_groups_only=*/true);
  pipeline.AddPass<ConvolutionGroupConverter>(
      cost_model,
      /*convert_batch_groups_only=*/false);
  pipeline.AddPass<BatchNormExpander>(
      /*rewrite_training_op=*/true,
      /*rewrite_inference_op=*/true,
      /*rewrite_grad_op=*/true);
  pipeline.AddPass<LogisticExpander>(
      /*expansion_type=*/LogisticExpansionType::kExp);
  pipeline.AddPass<ConditionalCanonicalizer>();
  pipeline.AddPass<DynamicPadder>();
  pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
  pipeline.AddPass<ConvCanonicalization>(target_machine_features);
  {
    auto& pass =
        pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
    pass.AddInvariantCheckerDebug<HloVerifier>(/*layout_sensitive=*/false,
                                               /*allow_mixed_precision=*/false);

    pass.AddPass<TreeReductionRewriter>();
    AlgebraicSimplifierOptions options;
    options.set_enable_dot_strength_reduction(false);
    pass.AddPass<AlgebraicSimplifier>(options);
    pass.AddPass<SortSimplifier>();
    pass.AddPass<HloDCE>();
    pass.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);

    // BatchNormExpander can create zero-sized ops, so zero-sized HLO
    // elimination has to come after that pass.
    pass.AddPass<ZeroSizedHloElimination>();

    pass.AddPass<WhileLoopInvariantCodeMotion>();
    pass.AddPass<TupleSimplifier>();
    pass.AddPass<WhileLoopConstantSinking>();
    pass.AddPass<WhileLoopSimplifier>();

    // TODO(b/134075051): Re-enable after b/134075051 is fixed.
    // pass.AddPass<SliceSinker>();

    pass.AddPass<HloDCE>();
    pass.AddPass<ReshapeMover>();
    pass.AddPass<HloConstantFolding>();
    pass.AddPass<ConditionalSimplifier>();
  }
  pipeline.AddPass<TopkRewriter>([](const HloSortInstruction* sort, int64) {
    return sort->operand(0)->shape().element_type() == F32;
  });
  pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
  pipeline.AddPass<TransposeFolding>(
      [&](const HloInstruction& dot,
          const TransposeFolding::OperandIndices& candidate_operands) {
        return DotImplementationCanHandleTranspose(dot,
                                                   *target_machine_features)
                   ? candidate_operands
                   : TransposeFolding::OperandIndices{};
      },
      TransposeFolding::NeverFoldTranspose);
  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);

  // Layout assignment uses alias analysis, which requires the call graph to be
  // flattened.
  pipeline.AddPass<FlattenCallGraph>();
  pipeline.AddPass<CpuLayoutAssignment>(
      module->mutable_entry_computation_layout(),
      LayoutAssignment::InstructionCanChangeLayout, target_machine_features);

  pipeline.AddPass<CpuInstructionFusion>();

  return pipeline.Run(module).status();
}

Status CpuCompiler::RunHloPassesAfterLayoutAssn(
    HloModule* module, bool is_aot_compile,
    LLVMTargetMachineFeatures* target_machine_features) {
  HloPassPipeline pipeline("HLO passes after layout assignment");

  // After layout assignment, use a layout-sensitive verifier.
  pipeline.AddPass<HloPassPipeline>("after layout assignment")
      .AddInvariantCheckerDebug<HloVerifier>(
          /*layout_sensitive=*/true,
          /*allow_mixed_precision=*/false);

  // The LayoutAssignment pass may leave behind kCopy instructions which are
  // duplicates or no-ops, so remove them with algebraic simplification and
  // CSE.
  {
    auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
        "simplification after layout assignment");
    pass.AddInvariantCheckerDebug<HloVerifier>(
        /*layout_sensitive=*/true,
        /*allow_mixed_precision=*/false,
        LayoutAssignment::InstructionCanChangeLayout);
    AlgebraicSimplifierOptions options;
    options.set_is_layout_sensitive(true);
    options.set_enable_dot_strength_reduction(false);
    pass.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
    pass.AddPass<HloDCE>();
    pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
  }

  // Outline ops in the entry computation into calls to subcomputations.
  const int max_parallelism =
      module->config().intra_op_parallelism_threads() > 0
          ? module->config().intra_op_parallelism_threads()
          : tensorflow::port::NumSchedulableCPUs();
  if (!is_aot_compile) {
    // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module.
    // Note this is not run for AOT because it would bring in thread pool
    // and thread synchronization dependencies which would likely increase
    // binary size (and most AOT applications are single-threaded).
    // TODO(b/29630486) Support multi-threaded AOT.
    pipeline.AddPass<ParallelTaskAssigner>(
        max_parallelism, ShapeSizeBytesFunction(), target_machine_features);
  }
  // Copy insertion should be performed immediately before IR emission to
  // avoid inserting unnecessary copies (later pass adds an instruction which
  // materializes the value) or missing a necessary copy (later pass removes
  // an instruction which materializes a value). DCE must be run immediately
  // before (and sometime after) copy insertion, to keep dead code from
  // interfering with the rewrites.
  pipeline.AddPass<HloDCE>();
  pipeline.AddPass<CopyInsertion>();
  pipeline.AddPass<HloDCE>();
  return pipeline.Run(module).status();
}

Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
                                 llvm::TargetMachine* target_machine) {
  LLVMTargetMachineFeatures target_machine_features(target_machine);
  TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile,
                                                   &target_machine_features));
  return RunHloPassesAfterLayoutAssn(module, is_aot_compile,
                                     &target_machine_features);
}

namespace {

// Align buffers to 16-byte boundaries.
int64 memory_alignment(LogicalBuffer::Color) {
  return cpu_function_runtime::kMinAlign;
}

llvm::TargetOptions CompilerTargetOptions(
    const HloModuleConfig& module_config) {
  llvm::TargetOptions target_options;
  // Always allow FMA fusion. This increases precision instead of decreasing
  // it.
  target_options.AllowFPOpFusion = llvm::FPOpFusion::Fast;
  return target_options;
}

llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) {
  VLOG(2) << "backend_optimization_level: "
          << module_config.debug_options().xla_backend_optimization_level();
  switch (module_config.debug_options().xla_backend_optimization_level()) {
    case 1:
      return llvm::CodeGenOpt::Less;
    case 2:
      return llvm::CodeGenOpt::Default;
    case 3:
      return llvm::CodeGenOpt::Aggressive;
    default:
      return llvm::CodeGenOpt::None;
  }
}

std::pair<LLVMCompiler::ModuleHook, LLVMCompiler::ModuleHook> GetIRModuleHooks(
    const HloModule& hlo_module,
    const LLVMCompiler::ModuleHook& user_pre_optimization_hook,
    const LLVMCompiler::ModuleHook& user_post_optimization_hook) {
  // Create the IR hooks. If applicable, each IR hook does the following:
  //
  //  * Calls the user supplied module hook.
  //  * Writes out the IR to a file in the output directory designated by
  //    --xla_dump_to
  const HloModule* hlo_module_ptr = &hlo_module;
  auto hook = [user_pre_optimization_hook, user_post_optimization_hook,
               hlo_module_ptr](bool optimized,
                               const llvm::Module& llvm_module) {
    const auto& user_hook =
        !optimized ? user_pre_optimization_hook : user_post_optimization_hook;
    if (user_hook) {
      user_hook(llvm_module);
    }
    llvm_ir::DumpIrIfEnabled(*hlo_module_ptr, llvm_module, optimized);
  };
  return {[hook](const llvm::Module& llvm_module) {
            return hook(/*optimized=*/false, llvm_module);
          },
          [hook](const llvm::Module& llvm_module) {
            return hook(/*optimized=*/true, llvm_module);
          }};
}
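
// For illustration only: a hedged sketch of how a client might install a
// user IR hook that the pair returned above wraps. The setter comes from the
// LLVMCompiler base class; the lambda body is a hypothetical example, not
// part of this file's contract:
//
//   xla::cpu::CpuCompiler compiler;
//   compiler.SetPreOptimizationHook([](const llvm::Module& m) {
//     llvm::errs() << "pre-optimization IR for: " << m.getName() << "\n";
//   });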

Status VerifyLlvmModule(const llvm::Module& llvm_module) {
  XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier");

  std::string err;
  llvm::raw_string_ostream err_stream(err);

  // verifyModule() returns true if the module is broken.
  TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream))
      << "Invalid LLVM IR before optimizations:\n"
      << err_stream.str()
      << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
         "Rerun with --xla_dump_to to get the IR. ";
  return Status::OK();
}

Status CreateHloProfilingArtifacts(
    const HloModule& module,
    std::unordered_map<const HloInstruction*, int64>*
        instruction_to_profile_idx,
    std::unordered_map<const HloComputation*, int64>*
        computation_to_profile_idx,
    std::unique_ptr<HloProfileIndexMap>* hlo_profile_index_map,
    std::unique_ptr<HloProfilePrinterData>* hlo_profile_printer_data) {
  *hlo_profile_index_map = absl::make_unique<HloProfileIndexMap>(module);
  const HloComputation& entry_computation = *module.entry_computation();

  TF_ASSIGN_OR_RETURN(
      *instruction_to_profile_idx,
      CollectProfileCandidates::GetCandidatesForComputation(
          entry_computation,
          (*hlo_profile_index_map)->instruction_to_profile_idx()));

  auto shape_size_bytes = [](const Shape& shape) {
    // On the CPU, opaques are pointers.
    if (shape.IsOpaque()) {
      return static_cast<int64>(sizeof(void*));
    }
    return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
  };

  HloCostAnalysis cost_analysis(shape_size_bytes);
  TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis));
  *hlo_profile_printer_data = CreateHloProfilePrinterData(
      **hlo_profile_index_map, cost_analysis, entry_computation.name());
  *computation_to_profile_idx =
      (*hlo_profile_index_map)->computation_to_profile_idx();

  return Status::OK();
}

}  // namespace

StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
    std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/,
    const CompileOptions& /*options*/) {
  std::unique_ptr<llvm::TargetMachine> jit_target_machine =
      SimpleOrcJIT::InferTargetMachineForJIT(
          CompilerTargetOptions(module->config()),
          CodeGenOptLevel(module->config()));

  TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false,
                                  jit_target_machine.get()));
  return std::move(module);
}

StatusOr<
    std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
CpuCompiler::RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
                                              se::StreamExecutor* executor,
                                              bool optimize,
                                              const CompileOptions& options) {
  if (optimize) {
    TF_ASSIGN_OR_RETURN(module,
                        RunHloPasses(std::move(module), executor, options));
  }

  // Select an order for emitting the HLO instructions for each computation.
  // Using this sequence enables tighter buffer liveness analysis and reduced
  // memory usage (as compared to using DependencyHloOrdering).
  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      ScheduleModule(module.get(), BufferSizeBytesFunction(),
                                     ComputationSchedulerToModuleScheduler(
                                         DFSMemoryScheduler)));

  // Run buffer allocation on the HLO graph.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<BufferAssignment> assignment,
      BufferAssigner::Run(module.get(),
                          absl::make_unique<SequentialHloOrdering>(schedule),
                          BufferSizeBytesFunction(), memory_alignment,
                          /*allocate_buffers_for_constants=*/true));

  return std::make_tuple(std::move(module), std::move(assignment));
}

namespace {

// Post-compilation callback functor for use by SimpleOrcJIT.
//
// Dumps machine code if dumping is enabled for the module.
struct OrcJITPostCompilationHook {
  // Gets an std::function that implements this hook.
  static std::function<void(const llvm::object::ObjectFile& obj_file)> Create(
      const HloModule* module) {
    // This struct is not copyable, but std::functions must be. So to create an
    // std::function out of this struct, we have to wrap it in a shared_ptr.
    auto wrapped = std::make_shared<OrcJITPostCompilationHook>(module);
    return [wrapped](const llvm::object::ObjectFile& obj_file) {
      (*wrapped)(obj_file);
    };
  }

  // Constructor can't be private because we want to call it from
  // std::make_shared, but users should call Create() instead.
  explicit OrcJITPostCompilationHook(const HloModule* module)
      : module(module) {}

 private:
  void operator()(const llvm::object::ObjectFile& obj_file) {
    if (!DumpingEnabledForHloModule(*module)) {
      return;
    }
    DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o",
                    absl::string_view(obj_file.getData().data(),
                                      obj_file.getData().size()));
  }

  const HloModule* module;
};

}  // namespace

StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    const CompileOptions& options) {
  VLOG(1) << "Compiling: " << module->name();
  XLA_SCOPED_LOGGING_TIMER(
      absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));
  std::string slow_compilation_msg =
      absl::StrCat("Compiling module ", module->name());
  auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg);

  TF_RET_CHECK(stream_exec != nullptr);
  absl::call_once(llvm_command_line_options_initialized,
                  &llvm_ir::InitializeLLVMCommandLineOptions, module->config());

  ModuleHook pre_optimization_ir_hook;
  ModuleHook post_optimization_ir_hook;
  std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
      GetIRModuleHooks(*module, user_pre_optimization_hook_,
                       user_post_optimization_hook_);

  // Compile must be thread-safe so create a new LLVM context for the module.
  mlir::MLIRContext mlir_context;
  LoadMLIRDialects(mlir_context);
  auto llvm_context = std::make_unique<llvm::LLVMContext>();
  auto llvm_module =
      absl::make_unique<llvm::Module>("__compute_module", *llvm_context);

  auto jit = SimpleOrcJIT::Create(
      CompilerTargetOptions(module->config()),
      CodeGenOptLevel(module->config()),
      options::OptimizeForSizeRequested(module->config()),
      module->config().debug_options().xla_llvm_disable_expensive_passes(),
      llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook,
      post_optimization_ir_hook,
      OrcJITPostCompilationHook::Create(module.get()));
  if (!jit) {
    return InternalError("Creating JIT failed: %s",
                         llvm::toString(jit.takeError()));
  }
  llvm_module->setDataLayout((*jit)->data_layout());
  llvm_module->setTargetTriple((*jit)->target_triple().getTriple());

  HloComputation* entry_computation = module->entry_computation();
  std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
  std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
  std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
  std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
  if (module->config().hlo_profiling_enabled()) {
    TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
        *module, &instruction_to_profile_idx, &computation_to_profile_idx,
        &hlo_profile_index_map, &hlo_profile_printer_data));
  }

  std::unique_ptr<Executable> cpu_executable;

  // Cache these flags here since we'll want to access them after the module's
  // ownership is std::moved.
  const bool embed_ir_in_executable =
      module->config().debug_options().xla_embed_ir_in_executable();

  // Select an order for emitting the HLO instructions for each
  // computation. Using this sequence enables tighter buffer liveness analysis
  // and reduced memory usage (as compared to using DependencyHloOrdering).
  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      ScheduleModule(module.get(), BufferSizeBytesFunction(),
                                     ComputationSchedulerToModuleScheduler(
                                         DFSMemoryScheduler)));

  // Run buffer allocation on the HLO graph.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<BufferAssignment> assignment,
      BufferAssigner::Run(module.get(),
                          absl::make_unique<SequentialHloOrdering>(schedule),
                          BufferSizeBytesFunction(), memory_alignment,
                          /*allocate_buffers_for_constants=*/true));
  DumpHloModuleIfEnabled(*module, *assignment, "after_optimizations");

  // Each computation is a single function. Emit all embedded computations
  // before the entry computation. The order of computations returned from
  // GetEmbeddedComputations guarantees that a called computation occurs
  // before a caller computation.

  LLVMTargetMachineFeatures target_machine_features((*jit)->target_machine());
  IrEmitter ir_emitter(&mlir_context, *module, *assignment, llvm_module.get(),
                       std::move(instruction_to_profile_idx),
                       std::move(computation_to_profile_idx),
                       &target_machine_features,
#ifdef MEMORY_SANITIZER
                       /*emit_code_for_msan=*/true
#else
                       /*emit_code_for_msan=*/false
#endif
  );

  TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());

  for (auto embedded_computation :
       entry_computation->MakeEmbeddedComputationsList()) {
    if (embedded_computation->IsFusionComputation()) {
      continue;
    }
    TF_RETURN_IF_ERROR(
        ir_emitter
            .EmitComputation(
                embedded_computation, embedded_computation->name(),
                /*is_top_level_computation=*/false,
                schedule.sequence(embedded_computation).instructions())
            .status());
  }
  string function_name_prefix = entry_computation->name().empty()
                                    ? "__compute"
                                    : entry_computation->name();
  TF_ASSIGN_OR_RETURN(
      llvm::Function * entry_function,
      ir_emitter.EmitComputation(
          entry_computation, function_name_prefix,
          /*is_top_level_computation=*/true,
          schedule.sequence(entry_computation).instructions()));

  string function_name = [&]() {
    llvm::SmallVector<char, 40> function_name_vector;
    llvm::Mangler::getNameWithPrefix(
        function_name_vector, entry_function->getName(),
        (*jit)->data_layout());
    return string(function_name_vector.begin(), function_name_vector.end());
  }();

  string ir_module_string;
  if (embed_ir_in_executable) {
    ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
  }

  TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));

  // JIT compile the LLVM IR module to in-memory machine code.
  llvm::orc::ThreadSafeModule thread_safe_module(std::move(llvm_module),
                                                 std::move(llvm_context));
  cantFail((*jit)->AddModule(std::move(thread_safe_module)));
  cpu_executable.reset(new CpuExecutable(
      std::move(*jit), std::move(assignment), std::move(module), function_name,
      std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));

  if (embed_ir_in_executable) {
    static_cast<CpuExecutable&>(*cpu_executable)
        .set_ir_module_string(ir_module_string);
  }

  VLOG(1) << "Compilation finished";
  return std::move(cpu_executable);
}

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                                const AotCompilationOptions& aot_options) {
  TF_RET_CHECK(!module_group->empty());
  std::vector<std::unique_ptr<HloModule>> modules =
      module_group->ConsumeModules();

  absl::call_once(llvm_command_line_options_initialized,
                  &llvm_ir::InitializeLLVMCommandLineOptions,
                  modules[0]->config());

  // We can pass just one llvm::TargetOptions when we compile the LLVM module,
  // so we bail if the configs have conflicting flags. At the moment, the only
  // flags that need to be consistent are for fast-math.
  for (const auto& fn_and_name :
       {std::make_pair(&DebugOptions::xla_cpu_enable_fast_math,
                       "xla_cpu_enable_fast_math"),
        std::make_pair(&DebugOptions::xla_cpu_fast_math_honor_infs,
                       "xla_cpu_fast_math_honor_infs"),
        std::make_pair(&DebugOptions::xla_cpu_fast_math_honor_nans,
                       "xla_cpu_fast_math_honor_nans")}) {
    // This only works because each of the method pointers above returns a
    // bool. Otherwise we'd have to do some template magic.
    const auto& field_method_ptr = fn_and_name.first;
    const auto& field_name = fn_and_name.second;
    bool first_module_val =
        (modules[0]->config().debug_options().*field_method_ptr)();
    for (int64 i = 0; i < modules.size(); ++i) {
      bool cur_module_val =
          (modules[i]->config().debug_options().*field_method_ptr)();
      if (first_module_val != cur_module_val) {
        return InvalidArgument(
            "All HLO module configs must have the same value for %s, but "
            "module 0 and %d have different values (%d vs %d).",
            field_name, i, first_module_val, cur_module_val);
      }
    }
  }

  if (aot_options.PlatformId() != se::host::kHostPlatformId) {
    return InvalidArgument("Incompatible AOT compilation platform");
  }
  const CpuAotCompilationOptions& options =
      static_cast<const CpuAotCompilationOptions&>(aot_options);
  llvm::Triple triple(llvm::Triple::normalize(options.triple()));
  std::string error;
  const llvm::Target* target =
      llvm::TargetRegistry::lookupTarget(triple.getTriple(), error);
  if (target == nullptr) {
    return InternalError("TargetRegistry::lookupTarget failed: %s", error);
  }

  llvm::Reloc::Model reloc_model = llvm::Reloc::Static;
  llvm::PICLevel::Level pic_level = llvm::PICLevel::NotPIC;
  llvm::PIELevel::Level pie_level = llvm::PIELevel::Default;
  switch (options.relocation_model()) {
    case CpuAotCompilationOptions::RelocationModel::Static:
      reloc_model = llvm::Reloc::Static;
      pic_level = llvm::PICLevel::NotPIC;
      pie_level = llvm::PIELevel::Default;
      break;
    case CpuAotCompilationOptions::RelocationModel::SmallPic:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::SmallPIC;
      pie_level = llvm::PIELevel::Default;
      break;
    case CpuAotCompilationOptions::RelocationModel::BigPic:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::BigPIC;
      pie_level = llvm::PIELevel::Default;
      break;
    case CpuAotCompilationOptions::RelocationModel::SmallPie:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::SmallPIC;
      pie_level = llvm::PIELevel::Small;
      break;
    case CpuAotCompilationOptions::RelocationModel::BigPie:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::BigPIC;
      pie_level = llvm::PIELevel::Large;
      break;
  }
  llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config());
  std::unique_ptr<llvm::TargetMachine> target_machine =
      absl::WrapUnique(target->createTargetMachine(
          triple.getTriple(), options.cpu_name(), options.features(),
          CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::None,
          opt_level));

  // Compile must be thread-safe so create a new LLVM context for the module.
  mlir::MLIRContext mlir_context;
  LoadMLIRDialects(mlir_context);
  llvm::LLVMContext llvm_context;
  llvm::Module llvm_module("__compute_module", llvm_context);
  llvm_module.setDataLayout(target_machine->createDataLayout());
  llvm_module.setTargetTriple(triple.getTriple());
  if (pic_level != llvm::PICLevel::NotPIC) {
    llvm_module.setPICLevel(pic_level);
  }
  if (pie_level != llvm::PIELevel::Default) {
    llvm_module.setPIELevel(pie_level);
  }

  std::vector<std::unique_ptr<AotCompilationResult>> results;
  for (size_t i = 0; i < modules.size(); ++i) {
    HloModule* module = modules[i].get();
    VLOG(1) << "Compiling ahead-of-time: " << module->name();

    TF_RETURN_IF_ERROR(
        RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get()));

    TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                        ScheduleModule(module, BufferSizeBytesFunction()));

    // Run buffer analysis on the HLO graph. This analysis figures out which
    // temporary buffers are required to run the computation.
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<BufferAssignment> assignment,
        BufferAssigner::Run(module,
                            absl::make_unique<SequentialHloOrdering>(schedule),
                            BufferSizeBytesFunction(), memory_alignment,
                            /*allocate_buffers_for_constants=*/true));
    // BufferAssignment::ToString() includes a header, so no need for us to
    // print one ourselves.
    if (DumpingEnabledForHloModule(*module)) {
      DumpToFileInDirOrStdout(*module, "", "buffer_assignment",
                              assignment->ToString());
    }
    DumpHloModuleIfEnabled(*module, *assignment, "after_optimizations");

    std::unordered_map<const HloInstruction*, int64>
        instruction_to_profile_idx;
    std::unordered_map<const HloComputation*, int64>
        computation_to_profile_idx;
    std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;

    if (module->config().hlo_profiling_enabled()) {
      TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
          *module, &instruction_to_profile_idx, &computation_to_profile_idx,
          &hlo_profile_index_map, &hlo_profile_printer_data));
    }

    LLVMTargetMachineFeatures target_machine_features(target_machine.get());
    IrEmitter ir_emitter(&mlir_context, *module, *assignment, &llvm_module,
                         std::move(instruction_to_profile_idx),
                         std::move(computation_to_profile_idx),
                         &target_machine_features,
                         // TODO(b/66051036): Run full msan for AOT.
                         /*emit_code_for_msan=*/false);

    TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());

    HloComputation* computation = module->entry_computation();
    for (auto embedded_computation :
         computation->MakeEmbeddedComputationsList()) {
      if (embedded_computation->IsFusionComputation()) {
        continue;
      }
      TF_RETURN_IF_ERROR(
          ir_emitter
              .EmitComputation(
                  embedded_computation, embedded_computation->name(),
                  /*is_top_level_computation=*/false,
                  schedule.sequence(embedded_computation).instructions())
              .status());
    }
    const string& entry_point_name = options.entry_point_name();
    TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
                        ir_emitter.EmitComputation(
                            computation, entry_point_name,
                            /*is_top_level_computation=*/true,
                            schedule.sequence(computation).instructions()));

    CHECK(entry_function->getName() == entry_point_name);

    ModuleHook pre_optimization_ir_hook;
    ModuleHook post_optimization_ir_hook;
    std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
        GetIRModuleHooks(*module, user_pre_optimization_hook_,
                         user_post_optimization_hook_);

    // Run the LLVM verifier over the unoptimized LLVM IR. If it fails, run
    // the pre-optimization IR dump hook before returning.
    {
      Status verify_status = VerifyLlvmModule(llvm_module);
      if (!verify_status.ok() && pre_optimization_ir_hook) {
        pre_optimization_ir_hook(llvm_module);
      }
      TF_RETURN_IF_ERROR(verify_status);
    }

    auto post_codegen_hook = [&](const llvm::object::ObjectFile& obj_file) {
      if (!DumpingEnabledForHloModule(*module)) {
        return;
      }
      DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o",
                      absl::string_view(obj_file.getData().data(),
                                        obj_file.getData().size()));
    };

    CompilerFunctor compiler_functor(
        target_machine.get(), opt_level,
        options::OptimizeForSizeRequested(module->config()),
        module->config().debug_options().xla_llvm_disable_expensive_passes(),
        llvm_ir::GetCpuFastMathFlags(module->config()),
        pre_optimization_ir_hook, post_optimization_ir_hook,
        post_codegen_hook);
    std::unique_ptr<llvm::MemoryBuffer> object_file =
        cantFail(compiler_functor(llvm_module));
    ObjectFileData object_file_data(object_file->getBufferStart(),
                                    object_file->getBufferEnd());

    std::vector<BufferInfo> buffer_infos =
        CreateBufferInfosFromBufferAssignment(*assignment);

    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
                        assignment->GetUniqueTopLevelOutputSlice());

    results.emplace_back(absl::make_unique<CpuAotCompilationResult>(
        std::move(object_file_data), std::move(buffer_infos),
        result_slice.index(), std::move(hlo_profile_printer_data)));
  }

  VLOG(1) << "Compilation finished";
  return std::move(results);
}

se::Platform::Id CpuCompiler::PlatformId() const {
  return se::host::kHostPlatformId;
}

HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const {
  return CpuExecutable::ShapeSizeBytes;
}

}  // namespace cpu
}  // namespace xla

static bool InitModule() {
  xla::Compiler::RegisterCompilerFactory(
      stream_executor::host::kHostPlatformId,
      []() { return absl::make_unique<xla::cpu::CpuCompiler>(); });
  return true;
}
static bool module_initialized = InitModule();