1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
17
18 #include <stddef.h>
19 #include <string.h>
20
21 #include <map>
22 #include <memory>
23 #include <string>
24 #include <unordered_map>
25 #include <utility>
26 #include <vector>
27
28 // IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc"
29 // IWYU pragma: no_include "llvm/Config/Targets.def.inc"
30 #include "absl/base/call_once.h"
31 #include "absl/memory/memory.h"
32 #include "absl/strings/str_cat.h"
33 #include "llvm/ADT/StringRef.h"
34 #include "llvm/ADT/Triple.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/LLVMContext.h"
37 #include "llvm/IR/Mangler.h"
38 #include "llvm/IR/Module.h"
39 #include "llvm/IR/Verifier.h"
40 #include "llvm/Object/ObjectFile.h"
41 #include "llvm/Support/CommandLine.h"
42 #include "llvm/Support/Error.h"
43 #include "llvm/Support/TargetRegistry.h"
44 #include "llvm/Support/TargetSelect.h"
45 #include "llvm/Target/TargetMachine.h"
46 #include "llvm/Target/TargetOptions.h"
47 #include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
48 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
49 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // from @llvm-project
50 #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
51 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
52 #include "mlir/Dialect/Vector/VectorOps.h" // from @llvm-project
53 #include "mlir/InitAllDialects.h" // from @llvm-project
54 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
55 #include "tensorflow/compiler/xla/literal.h"
56 #include "tensorflow/compiler/xla/map_util.h"
57 #include "tensorflow/compiler/xla/protobuf_util.h"
58 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
59 #include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
60 #include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
61 #include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
62 #include "tensorflow/compiler/xla/service/batchnorm_expander.h"
63 #include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
64 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
65 #include "tensorflow/compiler/xla/service/call_inliner.h"
66 #include "tensorflow/compiler/xla/service/cholesky_expander.h"
67 #include "tensorflow/compiler/xla/service/comparison_expander.h"
68 #include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
69 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
70 #include "tensorflow/compiler/xla/service/conditional_to_select.h"
71 #include "tensorflow/compiler/xla/service/convolution_group_converter.h"
72 #include "tensorflow/compiler/xla/service/copy_insertion.h"
73 #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
74 #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
75 #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
76 #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
77 #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
78 #include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"
79 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
80 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
81 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
82 #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
83 #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
84 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
85 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
86 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
87 #include "tensorflow/compiler/xla/service/dump.h"
88 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
89 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
90 #include "tensorflow/compiler/xla/service/eigh_expander.h"
91 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
92 #include "tensorflow/compiler/xla/service/gather_expander.h"
93 #include "tensorflow/compiler/xla/service/hlo.pb.h"
94 #include "tensorflow/compiler/xla/service/hlo_computation.h"
95 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
96 #include "tensorflow/compiler/xla/service/hlo_cse.h"
97 #include "tensorflow/compiler/xla/service/hlo_dce.h"
98 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
99 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
100 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
101 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
102 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
103 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
104 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
105 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
106 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
107 #include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
108 #include "tensorflow/compiler/xla/service/llvm_compiler.h"
109 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
110 #include "tensorflow/compiler/xla/service/logistic_expander.h"
111 #include "tensorflow/compiler/xla/service/map_inliner.h"
112 #include "tensorflow/compiler/xla/service/operand_upcaster.h"
113 #include "tensorflow/compiler/xla/service/qr_expander.h"
114 #include "tensorflow/compiler/xla/service/reduce_scatter_decomposer.h"
115 #include "tensorflow/compiler/xla/service/reshape_mover.h"
116 #include "tensorflow/compiler/xla/service/result_caster.h"
117 #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
118 #include "tensorflow/compiler/xla/service/rng_expander.h"
119 #include "tensorflow/compiler/xla/service/scatter_expander.h"
120 #include "tensorflow/compiler/xla/service/slice_sinker.h"
121 #include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
122 #include "tensorflow/compiler/xla/service/sort_simplifier.h"
123 #include "tensorflow/compiler/xla/service/topk_rewriter.h"
124 #include "tensorflow/compiler/xla/service/transpose_folding.h"
125 #include "tensorflow/compiler/xla/service/tree_reduction_rewriter.h"
126 #include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
127 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
128 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
129 #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
130 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
131 #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
132 #include "tensorflow/compiler/xla/status_macros.h"
133 #include "tensorflow/compiler/xla/statusor.h"
134 #include "tensorflow/compiler/xla/types.h"
135 #include "tensorflow/compiler/xla/util.h"
136 #include "tensorflow/compiler/xla/xla_data.pb.h"
137 #include "tensorflow/core/platform/dynamic_annotations.h"
138
139 namespace {
140
141 // We need to explicitly load all the dialects we will involved in emitting the
142 // IR. This is only needed because of how MLIR is bolted into XLA and does not
143 // make use of the MLIR infrastructure (like using a proper pass pipeline).
144 // Hopefully this will all go away at some point in favor of a better
145 // integration.
LoadMLIRDialects(mlir::MLIRContext & context)146 void LoadMLIRDialects(mlir::MLIRContext& context) {
147 context.loadDialect<mlir::linalg::LinalgDialect, mlir::scf::SCFDialect,
148 mlir::vector::VectorDialect, mlir::StandardOpsDialect,
149 mlir::AffineDialect>();
150 }
151
152 } // namespace
153
154 namespace xla {
155 namespace cpu {
156 using BufferInfo = cpu_function_runtime::BufferInfo;
157
CpuAotCompilationOptions(string triple,string cpu_name,string features,string entry_point_name,RelocationModel relocation_model)158 CpuAotCompilationOptions::CpuAotCompilationOptions(
159 string triple, string cpu_name, string features, string entry_point_name,
160 RelocationModel relocation_model)
161 : triple_(std::move(triple)),
162 cpu_name_(std::move(cpu_name)),
163 features_(std::move(features)),
164 entry_point_name_(std::move(entry_point_name)),
165 relocation_model_(relocation_model) {}
166
167 CpuAotCompilationOptions::~CpuAotCompilationOptions() = default;
168
PlatformId() const169 se::Platform::Id CpuAotCompilationOptions::PlatformId() const {
170 return se::host::kHostPlatformId;
171 }
172
// Bundles the output of an AOT compilation: the object file bytes, the buffer
// infos, the result buffer index (presumably an index into `buffer_infos` --
// TODO confirm against the header), and the HLO profile printer data.
// All movable arguments are moved into the members.
CpuAotCompilationResult::CpuAotCompilationResult(
    ObjectFileData object_file_data, std::vector<BufferInfo> buffer_infos,
    int64_t result_buffer_index,
    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data)
    : object_file_data_(std::move(object_file_data)),
      buffer_infos_(std::move(buffer_infos)),
      result_buffer_index_(result_buffer_index),
      hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {}

// Out-of-line defaulted destructor.
CpuAotCompilationResult::~CpuAotCompilationResult() = default;
183
CpuCompiler()184 CpuCompiler::CpuCompiler() {
185 // Initialize LLVM the first time the CpuCompiler is initialized.
186 static bool llvm_initialized = []() {
187 InitializeLLVMTarget();
188 return true;
189 }();
190 (void)llvm_initialized;
191 }
192
Compile(std::unique_ptr<HloModuleGroup> module_group,std::vector<std::vector<se::StreamExecutor * >> stream_execs,const CompileOptions & options)193 StatusOr<std::vector<std::unique_ptr<Executable>>> CpuCompiler::Compile(
194 std::unique_ptr<HloModuleGroup> module_group,
195 std::vector<std::vector<se::StreamExecutor*>> stream_execs,
196 const CompileOptions& options) {
197 for (const std::vector<se::StreamExecutor*>& se_vector : stream_execs) {
198 if (se_vector.size() != 1) {
199 return Unimplemented(
200 "Model partitioning not implemented for the CPU compiler");
201 }
202 }
203 return LLVMCompiler::Compile(std::move(module_group), stream_execs, options);
204 }
205
InitializeLLVMTarget()206 /* static */ void CpuCompiler::InitializeLLVMTarget() {
207 // Initialize LLVM's MC layer for the native target.
208 llvm::InitializeNativeTarget();
209 llvm::InitializeNativeTargetAsmPrinter();
210 }
211
212 namespace {
213
// LLVM makes certain options configurable only through its command-line
// options; it provides the ParseCommandLineOptions function that lets us set
// flags at runtime. However, since these flags are global, we want to avoid
// running the LLVM compilation pipeline multiple times with different sets of
// flags. Therefore, we only pass command-line flags to LLVM once, before the
// first module is compiled.
//
// Used with absl::call_once around llvm_ir::InitializeLLVMCommandLineOptions.
absl::once_flag llvm_command_line_options_initialized;
221
222 // This visitor records which HLO instructions should have profiling information
223 // recorded.
224 class CollectProfileCandidates : public DfsHloVisitorWithDefault {
225 public:
226 static StatusOr<std::unordered_map<const HloInstruction*, int64>>
GetCandidatesForComputation(const HloComputation & computation,const std::unordered_map<const HloInstruction *,int64> & assigned_indices)227 GetCandidatesForComputation(
228 const HloComputation& computation,
229 const std::unordered_map<const HloInstruction*, int64>&
230 assigned_indices) {
231 std::unordered_map<const HloInstruction*, int64> hlo_to_profile_idx;
232 CollectProfileCandidates profile_candidates_for_computation(
233 &hlo_to_profile_idx, assigned_indices);
234 TF_RETURN_IF_ERROR(computation.Accept(&profile_candidates_for_computation));
235 return hlo_to_profile_idx;
236 }
237
238 private:
CollectProfileCandidates(std::unordered_map<const HloInstruction *,int64> * hlo_to_profile_idx,const std::unordered_map<const HloInstruction *,int64> & assigned_indices)239 CollectProfileCandidates(
240 std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx,
241 const std::unordered_map<const HloInstruction*, int64>& assigned_indices)
242 : hlo_to_profile_idx_(hlo_to_profile_idx),
243 assigned_indices_(assigned_indices) {}
244
DefaultAction(HloInstruction * hlo_instruction)245 Status DefaultAction(HloInstruction* hlo_instruction) override {
246 hlo_to_profile_idx_->insert(
247 {hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)});
248 return Status::OK();
249 }
250
HandleCall(HloInstruction * call)251 Status HandleCall(HloInstruction* call) override {
252 TF_RETURN_IF_ERROR(DefaultAction(call));
253 CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_,
254 assigned_indices_);
255 TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call));
256 return Status::OK();
257 }
258 // Recurse into "conditional" so we can profile inside of it.
HandleConditional(HloInstruction * conditional)259 Status HandleConditional(HloInstruction* conditional) override {
260 TF_RETURN_IF_ERROR(DefaultAction(conditional));
261
262 CollectProfileCandidates candidates_for_true(hlo_to_profile_idx_,
263 assigned_indices_);
264 TF_RETURN_IF_ERROR(
265 conditional->true_computation()->Accept(&candidates_for_true));
266
267 CollectProfileCandidates candidates_for_false(hlo_to_profile_idx_,
268 assigned_indices_);
269 TF_RETURN_IF_ERROR(
270 conditional->false_computation()->Accept(&candidates_for_false));
271
272 return Status::OK();
273 }
274
275 // Skip constants, there is nothing to profile.
HandleConstant(HloInstruction *)276 Status HandleConstant(HloInstruction*) override { return Status::OK(); }
277 // Skip parameters, they are a simple load.
HandleParameter(HloInstruction *)278 Status HandleParameter(HloInstruction*) override { return Status::OK(); }
279 // It is important to recurse for "while" or else we risk overly coarse
280 // profiling information.
HandleWhile(HloInstruction * xla_while)281 Status HandleWhile(HloInstruction* xla_while) override {
282 TF_RETURN_IF_ERROR(DefaultAction(xla_while));
283
284 CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_,
285 assigned_indices_);
286 TF_RETURN_IF_ERROR(
287 xla_while->while_condition()->Accept(&candidates_for_condition));
288
289 CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_,
290 assigned_indices_);
291 TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&candidates_for_body));
292
293 return Status::OK();
294 }
295
296 std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx_;
297 const std::unordered_map<const HloInstruction*, int64>& assigned_indices_;
298 };
299
300 } // namespace
301
// Runs the HLO passes that precede layout assignment, then layout assignment
// itself, and finally CPU instruction fusion, which is added right after it.
// Passes run in the order they are added below. The order matters: several
// passes rely on earlier ones having run (see the per-pass comments).
//
// `is_aot_compile` is unused in this phase. `target_machine_features` is
// consulted by ConvCanonicalization, the TransposeFolding callback and
// CpuLayoutAssignment.
Status CpuCompiler::RunHloPassesThroughLayoutAssn(
    HloModule* module, bool /*is_aot_compile*/,
    LLVMTargetMachineFeatures* target_machine_features) {
  HloPassPipeline pipeline("HLO passes through layout assignment");
  pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                            /*allow_mixed_precision=*/false);

  pipeline.AddPass<OperandUpcaster>();
  pipeline.AddPass<ResultCaster>();

  // Expand random number generation.
  pipeline.AddPass<RngExpander>();
  pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);

  // Remove zero-sized HLO from the input so that other passes don't have to
  // handle it.
  pipeline.AddPass<ZeroSizedHloElimination>();

  pipeline.AddPass<DynamicIndexSplitter>();

  pipeline.AddPass<ConditionalToSelect>();
  pipeline.AddPass<MapInliner>();

  // Expander/decomposer passes. Presumably these rewrite ops that the CPU
  // emitter does not lower directly into simpler HLO (inferred from the pass
  // names; confirm).
  pipeline.AddPass<ComparisonExpander>();
  pipeline.AddPass<CholeskyExpander>();
  pipeline.AddPass<QrExpander>();
  pipeline.AddPass<EighExpander>();
  pipeline.AddPass<TriangularSolveExpander>();
  pipeline.AddPass<AllGatherDecomposer>();
  pipeline.AddPass<AllToAllDecomposer>();
  pipeline.AddPass<ReduceScatterDecomposer>();

  // Inline computations with a single call site.
  pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
  pipeline.AddPass<BatchDotSimplification>();
  pipeline.AddPass<DotDecomposer>();
  // Convert BF16 operations to F32 operations so that the CPU backend can
  // support BF16 operations without directly implementing a BF16 lowering for
  // most ops.
  BFloat16Support bf16;
  pipeline.AddPass<BFloat16Normalization>(&bf16);
  // After canonicalization, there may be more batch dots that can be
  // simplified.
  pipeline.AddPass<BatchDotSimplification>();
  // Cost model handed to both ConvolutionGroupConverter passes below. It
  // always returns false.
  auto cost_model = [](HloInstruction* conv) {
    // We need a cost model for CPUs. Currently, do nothing.
    return false;
  };
  pipeline.AddPass<ConvolutionGroupConverter>(
      cost_model,
      /*convert_batch_groups_only=*/true);
  pipeline.AddPass<ConvolutionGroupConverter>(
      cost_model,
      /*convert_batch_groups_only=*/false);
  pipeline.AddPass<BatchNormExpander>(
      /*rewrite_training_op=*/true,
      /*rewrite_inference_op=*/true,
      /*rewrite_grad_op=*/true);
  pipeline.AddPass<LogisticExpander>(
      /*expansion_type=*/LogisticExpansionType::kExp);
  pipeline.AddPass<ConditionalCanonicalizer>();
  pipeline.AddPass<DynamicPadder>();
  pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
  pipeline.AddPass<ConvCanonicalization>(target_machine_features);

  // Run the following passes to a fixed point.
  // The immediately-invoked lambda's init-capture binds `pipeline` to the new
  // HloPassFix sub-pipeline and shadows the outer `pipeline`. Every AddPass
  // inside the lambda therefore adds to the fixed-point sub-pipeline.
  [&pipeline =
       pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification")] {
    pipeline.AddInvariantCheckerDebug<HloVerifier>(
        /*layout_sensitive=*/false,
        /*allow_mixed_precision=*/false);

    pipeline.AddPass<TreeReductionRewriter>();
    AlgebraicSimplifierOptions options;
    options.set_enable_dot_strength_reduction(false);
    pipeline.AddPass<AlgebraicSimplifier>(options);
    pipeline.AddPass<SortSimplifier>();
    pipeline.AddPass<HloDCE>();
    pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);

    // BatchNormExpander can create zero-sized ops, so zero-sized HLO
    // elimination has to come after that pass.
    pipeline.AddPass<ZeroSizedHloElimination>();

    pipeline.AddPass<WhileLoopInvariantCodeMotion>();
    pipeline.AddPass<TupleSimplifier>();
    pipeline.AddPass<WhileLoopConstantSinking>();
    pipeline.AddPass<WhileLoopSimplifier>();

    // TODO(b/134075051): Re-enable after b/134075051 is fixed.
    // pipeline.AddPass<SliceSinker>();

    pipeline.AddPass<HloDCE>();
    pipeline.AddPass<ReshapeMover>();
    pipeline.AddPass<HloConstantFolding>();
    pipeline.AddPass<ConditionalSimplifier>();
  }();

  // Only sorts whose first operand has element type F32 are offered to the
  // top-k rewriter.
  pipeline.AddPass<TopkRewriter>([](const HloSortInstruction* sort, int64_t) {
    return sort->operand(0)->shape().element_type() == F32;
  });
  pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
  // Transposes are folded into a dot's operands only when the CPU dot
  // implementation can handle the transposed form. The second callback
  // (NeverFoldTranspose; presumably the convolution hook, confirm) folds
  // nothing.
  pipeline.AddPass<TransposeFolding>(
      [&](const HloInstruction& dot,
          const TransposeFolding::OperandIndices& candidate_operands) {
        return DotImplementationCanHandleTranspose(dot,
                                                   *target_machine_features)
                   ? candidate_operands
                   : TransposeFolding::OperandIndices{};
      },
      TransposeFolding::NeverFoldTranspose);
  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);

  // Layout assignment uses alias analysis, which requires the call graph to be
  // flattened.
  pipeline.AddPass<FlattenCallGraph>();
  pipeline.AddPass<CpuLayoutAssignment>(
      module->mutable_entry_computation_layout(), target_machine_features);

  pipeline.AddPass<CpuInstructionFusion>();

  // Only the status is returned; the "changed" flag of the run is discarded.
  return pipeline.Run(module).status();
}
425
// Runs the HLO passes that follow layout assignment: layout-sensitive
// simplification to a fixed point, parallel task assignment (JIT only), and
// finally copy insertion bracketed by DCE. Passes run in the order they are
// added, and that order is significant.
Status CpuCompiler::RunHloPassesAfterLayoutAssn(
    HloModule* module, bool is_aot_compile,
    LLVMTargetMachineFeatures* target_machine_features) {
  HloPassPipeline pipeline("HLO passes after layout assignment");
  // After layout assignment, use a layout-sensitive verifier.

  // This adds an otherwise empty sub-pipeline whose only job is to run the
  // layout-sensitive verifier as an invariant checker at this point
  // (AddInvariantCheckerDebug; presumably only active in debug builds).
  pipeline.AddPass<HloPassPipeline>("after layout assignment")
      .AddInvariantCheckerDebug<HloVerifier>(
          /*layout_sensitive=*/true,
          /*allow_mixed_precision=*/false);

  // The LayoutAssignment pass may leave behind kCopy instructions which are
  // duplicate or NOPs, so remove them with algebraic simplification and CSE.
  // Run this to a fixed point.
  // The init-capture shadows the outer `pipeline` with the fixed-point
  // sub-pipeline, so the passes below are added to that sub-pipeline.
  [&pipeline = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
       "simplification after layout assignment")] {
    pipeline.AddInvariantCheckerDebug<HloVerifier>(
        /*layout_sensitive=*/true,
        /*allow_mixed_precision=*/false,
        LayoutAssignment::InstructionCanChangeLayout);
    AlgebraicSimplifierOptions options;
    options.set_is_layout_sensitive(true);
    options.set_enable_dot_strength_reduction(false);
    pipeline.AddPass<AlgebraicSimplifier>(options);
    pipeline.AddPass<HloDCE>();
    pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
  }();

  // Outline ops in the entry computation into calls to subcomputations.
  // A non-positive intra_op_parallelism_threads setting falls back to the
  // number of schedulable CPUs.
  const int max_parallelism =
      module->config().intra_op_parallelism_threads() > 0
          ? module->config().intra_op_parallelism_threads()
          : tensorflow::port::NumSchedulableCPUs();
  if (!is_aot_compile) {
    // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module.
    // Note this is not run for AOT because it would bring in thread pool
    // and thread synchronization dependencies which would likely increase
    // binary size (and most AOT applications are single-threaded).
    // TODO(b/29630486) Support multi-threaded AOT.
    pipeline.AddPass<ParallelTaskAssigner>(
        max_parallelism, ShapeSizeBytesFunction(), target_machine_features);
  }
  // Copy insertion should be performed immediately before IR emission to
  // avoid inserting unnecessary copies (a later pass adds an instruction which
  // materializes the value) or missing a necessary copy (a later pass removes
  // an instruction which materializes a value). DCE must be run immediately
  // before (and sometimes after) copy insertion, to avoid dead code from
  // interfering with the rewrites.
  pipeline.AddPass<HloDCE>();
  pipeline.AddPass<CopyInsertion>();
  pipeline.AddPass<HloDCE>();
  // Only the status is returned; the "changed" flag of the run is discarded.
  return pipeline.Run(module).status();
}
479
RunHloPasses(HloModule * module,bool is_aot_compile,llvm::TargetMachine * target_machine)480 Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
481 llvm::TargetMachine* target_machine) {
482 LLVMTargetMachineFeatures target_machine_features(target_machine);
483 TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile,
484 &target_machine_features));
485 return RunHloPassesAfterLayoutAssn(module, is_aot_compile,
486 &target_machine_features);
487 }
488
489 namespace {
490
491 // Align buffers to 16-byte boundaries.
memory_alignment(LogicalBuffer::Color)492 int64 memory_alignment(LogicalBuffer::Color) {
493 return cpu_function_runtime::kMinAlign;
494 }
495
CompilerTargetOptions(const HloModuleConfig & module_config)496 llvm::TargetOptions CompilerTargetOptions(
497 const HloModuleConfig& module_config) {
498 llvm::TargetOptions target_options;
499 // Always allow FMA fusion. This increases precision instead of decreasing it.
500 target_options.AllowFPOpFusion = llvm::FPOpFusion::Fast;
501 return target_options;
502 }
503
CodeGenOptLevel(const HloModuleConfig & module_config)504 llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) {
505 VLOG(2) << "backend_optimization_level: "
506 << module_config.debug_options().xla_backend_optimization_level();
507 switch (module_config.debug_options().xla_backend_optimization_level()) {
508 case 1:
509 return llvm::CodeGenOpt::Less;
510 case 2:
511 return llvm::CodeGenOpt::Default;
512 case 3:
513 return llvm::CodeGenOpt::Aggressive;
514 default:
515 return llvm::CodeGenOpt::None;
516 }
517 }
518
GetIRModuleHooks(const HloModule & hlo_module,const LLVMCompiler::ModuleHook & user_pre_optimization_hook,const LLVMCompiler::ModuleHook & user_post_optimization_hook)519 std::pair<LLVMCompiler::ModuleHook, LLVMCompiler::ModuleHook> GetIRModuleHooks(
520 const HloModule& hlo_module,
521 const LLVMCompiler::ModuleHook& user_pre_optimization_hook,
522 const LLVMCompiler::ModuleHook& user_post_optimization_hook) {
523 // Create the IR hooks. If applicable, each IR hook does the following:
524 //
525 // * Calls the user supplied module hook.
526 // * Writes out the IR to a file in the output directory designated by
527 // --xla_dump_to
528 const HloModule* hlo_module_ptr = &hlo_module;
529 auto hook = [user_pre_optimization_hook, user_post_optimization_hook,
530 hlo_module_ptr](bool optimized,
531 const llvm::Module& llvm_module) {
532 const auto& user_hook =
533 !optimized ? user_pre_optimization_hook : user_post_optimization_hook;
534 if (user_hook) {
535 user_hook(llvm_module);
536 }
537 llvm_ir::DumpIrIfEnabled(*hlo_module_ptr, llvm_module, optimized);
538 };
539 return {[hook](const llvm::Module& llvm_module) {
540 return hook(/*optimized=*/false, llvm_module);
541 },
542 [hook](const llvm::Module& llvm_module) {
543 return hook(/*optimized=*/true, llvm_module);
544 }};
545 }
546
VerifyLlvmModule(const llvm::Module & llvm_module)547 Status VerifyLlvmModule(const llvm::Module& llvm_module) {
548 XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier");
549
550 std::string err;
551 llvm::raw_string_ostream err_stream(err);
552
553 // verifyModule() returns true if the module is broken.
554 TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream))
555 << "Invalid LLVM IR before optimizations:\n"
556 << err_stream.str()
557 << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
558 "Rerun with --xla_dump_to to get the IR. ";
559 return Status::OK();
560 }
561
CreateHloProfilingArtifacts(const HloModule & module,std::unordered_map<const HloInstruction *,int64> * instruction_to_profile_idx,std::unordered_map<const HloComputation *,int64> * computation_to_profile_idx,std::unique_ptr<HloProfileIndexMap> * hlo_profile_index_map,std::unique_ptr<HloProfilePrinterData> * hlo_profile_printer_data)562 Status CreateHloProfilingArtifacts(
563 const HloModule& module,
564 std::unordered_map<const HloInstruction*, int64>*
565 instruction_to_profile_idx,
566 std::unordered_map<const HloComputation*, int64>*
567 computation_to_profile_idx,
568 std::unique_ptr<HloProfileIndexMap>* hlo_profile_index_map,
569 std::unique_ptr<HloProfilePrinterData>* hlo_profile_printer_data) {
570 *hlo_profile_index_map = absl::make_unique<HloProfileIndexMap>(module);
571 const HloComputation& entry_computation = *module.entry_computation();
572
573 TF_ASSIGN_OR_RETURN(
574 *instruction_to_profile_idx,
575 CollectProfileCandidates::GetCandidatesForComputation(
576 entry_computation,
577 (*hlo_profile_index_map)->instruction_to_profile_idx()));
578
579 auto shape_size_bytes = [](const Shape& shape) {
580 // On the cpu, opaques are pointers.
581 if (shape.IsOpaque()) {
582 return static_cast<int64>(sizeof(void*));
583 }
584 return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
585 };
586
587 HloCostAnalysis cost_analysis(shape_size_bytes);
588 TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis));
589 *hlo_profile_printer_data = CreateHloProfilePrinterData(
590 **hlo_profile_index_map, cost_analysis, entry_computation.name());
591 *computation_to_profile_idx =
592 (*hlo_profile_index_map)->computation_to_profile_idx();
593
594 return Status::OK();
595 }
596
597 } // namespace
598
RunHloPasses(std::unique_ptr<HloModule> module,se::StreamExecutor *,const CompileOptions &)599 StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
600 std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/,
601 const CompileOptions& /*options*/) {
602 std::unique_ptr<llvm::TargetMachine> jit_target_machine =
603 SimpleOrcJIT::InferTargetMachineForJIT(
604 CompilerTargetOptions(module->config()),
605 CodeGenOptLevel(module->config()));
606
607 TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false,
608 jit_target_machine.get()));
609 return std::move(module);
610 }
611
612 StatusOr<
613 std::tuple<std::unique_ptr<HloModule>, std::unique_ptr<BufferAssignment>>>
RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,se::StreamExecutor * executor,bool optimize,const CompileOptions & options)614 CpuCompiler::RunHloPassesAndBufferAssignement(std::unique_ptr<HloModule> module,
615 se::StreamExecutor* executor,
616 bool optimize,
617 const CompileOptions& options) {
618 if (optimize) {
619 TF_ASSIGN_OR_RETURN(module,
620 RunHloPasses(std::move(module), executor, options));
621 }
622
623 // Select an order for emitting the HLO instructions for each computation.
624 // Using this sequence enables tighter buffer liveness analysis and reduced
625 // memory usage (as compared to using DependencyHloOrdering).
626 TF_ASSIGN_OR_RETURN(HloSchedule schedule,
627 ScheduleModule(module.get(), BufferSizeBytesFunction(),
628 ComputationSchedulerToModuleScheduler(
629 DFSMemoryScheduler)));
630
631 // Run buffer allocation on the HLO graph.
632 TF_ASSIGN_OR_RETURN(
633 std::unique_ptr<BufferAssignment> assignment,
634 BufferAssigner::Run(module.get(),
635 absl::make_unique<SequentialHloOrdering>(schedule),
636 BufferSizeBytesFunction(), memory_alignment,
637 /*allocate_buffers_for_constants=*/true));
638
639 return std::make_tuple(std::move(module), std::move(assignment));
640 }
641
642 namespace {
643
644 // Post-compilation callback functor for use by SimpleOrcJIT.
645 //
646 // Dumps machine code if dumping is enabled for the module.
647 struct OrcJITPostCompilationHook {
648 // Gets an std::function that implements this hook.
Createxla::cpu::__anond5cc82ac0e11::OrcJITPostCompilationHook649 static std::function<void(const llvm::object::ObjectFile& obj_file)> Create(
650 const HloModule* module) {
651 // This struct is not copyable, but std::functions must be. So to create an
652 // std::function out of this struct, we have to wrap it in a shared_ptr.
653 auto wrapped = std::make_shared<OrcJITPostCompilationHook>(module);
654 return [wrapped](const llvm::object::ObjectFile& obj_file) {
655 (*wrapped)(obj_file);
656 };
657 }
658
659 // Constructor can't be private because we want to call it from
660 // std::make_shared, but users should call Create() instead.
OrcJITPostCompilationHookxla::cpu::__anond5cc82ac0e11::OrcJITPostCompilationHook661 explicit OrcJITPostCompilationHook(const HloModule* module)
662 : module(module) {}
663
664 private:
operator ()xla::cpu::__anond5cc82ac0e11::OrcJITPostCompilationHook665 void operator()(const llvm::object::ObjectFile& obj_file) {
666 if (!DumpingEnabledForHloModule(*module)) {
667 return;
668 }
669 DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o",
670 absl::string_view(obj_file.getData().data(),
671 obj_file.getData().size()));
672 }
673
674 const HloModule* module;
675 };
676
677 } // namespace
678
RunBackend(std::unique_ptr<HloModule> module,se::StreamExecutor * stream_exec,const CompileOptions & options)679 StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
680 std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
681 const CompileOptions& options) {
682 VLOG(1) << "Compiling: " << module->name();
683 XLA_SCOPED_LOGGING_TIMER(
684 absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));
685 std::string slow_compilation_msg =
686 absl::StrCat("Compiling module ", module->name());
687 auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg);
688
689 absl::call_once(llvm_command_line_options_initialized,
690 &llvm_ir::InitializeLLVMCommandLineOptions, module->config());
691
692 ModuleHook pre_optimization_ir_hook;
693 ModuleHook post_optimization_ir_hook;
694 std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
695 GetIRModuleHooks(*module, user_pre_optimization_hook_,
696 user_post_optimization_hook_);
697
698 // Compile must be thread-safe so create a new LLVM context for the module.
699 mlir::MLIRContext mlir_context;
700 LoadMLIRDialects(mlir_context);
701 auto llvm_context = std::make_unique<llvm::LLVMContext>();
702 auto llvm_module =
703 absl::make_unique<llvm::Module>("__compute_module", *llvm_context);
704
705 auto jit = SimpleOrcJIT::Create(
706 CompilerTargetOptions(module->config()),
707 CodeGenOptLevel(module->config()),
708 options::OptimizeForSizeRequested(module->config()),
709 module->config().debug_options().xla_llvm_disable_expensive_passes(),
710 llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook,
711 post_optimization_ir_hook,
712 OrcJITPostCompilationHook::Create(module.get()));
713 if (!jit) {
714 return InternalError("Creating JIT failed: %s",
715 llvm::toString(jit.takeError()));
716 }
717 llvm_module->setDataLayout((*jit)->data_layout());
718 llvm_module->setTargetTriple((*jit)->target_triple().getTriple());
719
720 HloComputation* entry_computation = module->entry_computation();
721 std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
722 std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
723 std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
724 std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
725 if (module->config().hlo_profiling_enabled()) {
726 TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
727 *module, &instruction_to_profile_idx, &computation_to_profile_idx,
728 &hlo_profile_index_map, &hlo_profile_printer_data));
729 }
730
731 // Cache these flags here since we'll want to access them after the module's
732 // ownership is std::moved.
733 const bool embed_ir_in_executable =
734 module->config().debug_options().xla_embed_ir_in_executable();
735
736 // Select an order for emitting the HLO instructions for each
737 // computation. Using this sequence enables tighter buffer liveness analysis
738 // and reduced memory usage (as compared to using DependencyHloOrdering).
739 TF_ASSIGN_OR_RETURN(HloSchedule schedule,
740 ScheduleModule(module.get(), BufferSizeBytesFunction(),
741 ComputationSchedulerToModuleScheduler(
742 DFSMemoryScheduler)));
743
744 // Run buffer allocation on the HLO graph.
745 TF_ASSIGN_OR_RETURN(
746 std::unique_ptr<BufferAssignment> assignment,
747 BufferAssigner::Run(module.get(),
748 absl::make_unique<SequentialHloOrdering>(schedule),
749 BufferSizeBytesFunction(), memory_alignment,
750 /*allocate_buffers_for_constants=*/true));
751 DumpHloModuleIfEnabled(*module, *assignment, "cpu_after_optimizations");
752
753 // Each computation is a single function. Emit all embedded computations
754 // before the entry computation. The order of computations returned from
755 // GetEmbeddedComputations guarantees that a called computation occurs
756 // before a caller computation.
757
758 LLVMTargetMachineFeatures target_machine_features((*jit)->target_machine());
759 IrEmitter ir_emitter(&mlir_context, *module, *assignment, llvm_module.get(),
760 std::move(instruction_to_profile_idx),
761 std::move(computation_to_profile_idx),
762 &target_machine_features,
763 #ifdef MEMORY_SANITIZER
764 /*emit_code_for_msan=*/true
765 #else
766 /*emit_code_for_msan=*/false
767 #endif
768 );
769
770 TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
771
772 for (auto embedded_computation :
773 entry_computation->MakeEmbeddedComputationsList()) {
774 if (embedded_computation->IsFusionComputation()) {
775 continue;
776 }
777 TF_RETURN_IF_ERROR(
778 ir_emitter
779 .EmitComputation(
780 embedded_computation, embedded_computation->name(),
781 /*is_top_level_computation=*/false,
782 schedule.sequence(embedded_computation).instructions())
783 .status());
784 }
785 string function_name_prefix = entry_computation->name().empty()
786 ? "__compute"
787 : entry_computation->name();
788 TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
789 ir_emitter.EmitComputation(
790 entry_computation, function_name_prefix,
791 /*is_top_level_computation=*/true,
792 schedule.sequence(entry_computation).instructions()));
793
794 string function_name = [&]() {
795 llvm::SmallVector<char, 40> function_name_vector;
796 llvm::Mangler::getNameWithPrefix(
797 function_name_vector, entry_function->getName(), (*jit)->data_layout());
798 return string(function_name_vector.begin(), function_name_vector.end());
799 }();
800
801 string ir_module_string;
802 if (embed_ir_in_executable) {
803 ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
804 }
805
806 TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));
807
808 // JIT compile the LLVM IR module to in-memory machine code.
809 llvm::orc::ThreadSafeModule thread_safe_module(std::move(llvm_module),
810 std::move(llvm_context));
811 cantFail((*jit)->AddModule(std::move(thread_safe_module)));
812
813 auto cpu_executable = absl::make_unique<CpuExecutable>(
814 std::move(*jit), std::move(assignment), std::move(module), function_name,
815 std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map));
816
817 if (embed_ir_in_executable) {
818 cpu_executable->set_ir_module_string(ir_module_string);
819 }
820
821 // Dump computation proto state and buffer assignment for debug and test, if
822 // dump is enabled.
823 if (DumpingEnabledForHloModule(cpu_executable->module())) {
824 auto hlo_proto = absl::make_unique<HloProto>();
825 *hlo_proto->mutable_hlo_module() = cpu_executable->module().ToProto();
826 *hlo_proto->mutable_buffer_assignment() =
827 cpu_executable->buffer_assignment().ToProto();
828 cpu_executable->set_hlo_proto(std::move(hlo_proto));
829 }
830
831 cpu_executable->set_debug_info(
832 cpu_executable->buffer_assignment().GetStats().ToString());
833 VLOG(1) << "Compilation finished";
834 return std::unique_ptr<Executable>(std::move(cpu_executable));
835 }
836
837 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,const AotCompilationOptions & aot_options)838 CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
839 const AotCompilationOptions& aot_options) {
840 TF_RET_CHECK(!module_group->empty());
841 std::vector<std::unique_ptr<HloModule>> modules =
842 module_group->ConsumeModules();
843
844 absl::call_once(llvm_command_line_options_initialized,
845 &llvm_ir::InitializeLLVMCommandLineOptions,
846 modules[0]->config());
847
848 // We can pass just one llvm::TargetOptions when we compile the LLVM module,
849 // so we bail if the configs have conflicting flags. At the moment, the only
850 // flags that need to be consistent are for fast-math.
851 for (const auto& fn_and_name :
852 {std::make_pair(&DebugOptions::xla_cpu_enable_fast_math,
853 "xla_cpu_enable_fast_math"),
854 std::make_pair(&DebugOptions::xla_cpu_fast_math_honor_infs,
855 "xla_cpu_fast_math_honor_infs"),
856 std::make_pair(&DebugOptions::xla_cpu_fast_math_honor_nans,
857 "xla_cpu_fast_math_honor_nans")}) {
858 // This only works because each of the method pointers above returns a bool.
859 // Otherwise we'd have to do some template magic.
860 const auto& field_method_ptr = fn_and_name.first;
861 const auto& field_name = fn_and_name.second;
862 bool first_module_val =
863 (modules[0]->config().debug_options().*field_method_ptr)();
864 for (int64_t i = 0; i < modules.size(); ++i) {
865 bool cur_module_val =
866 (modules[i]->config().debug_options().*field_method_ptr)();
867 if (first_module_val != cur_module_val) {
868 return InvalidArgument(
869 "All HLO module configs must have the same value for %s, but "
870 "module 0 and %d have different values (%d vs %d).",
871 field_name, i, first_module_val, cur_module_val);
872 }
873 }
874 }
875
876 if (aot_options.PlatformId() != se::host::kHostPlatformId) {
877 return InvalidArgument("Incompatible AOT compilation platform");
878 }
879 const CpuAotCompilationOptions& options =
880 static_cast<const CpuAotCompilationOptions&>(aot_options);
881 llvm::Triple triple(llvm::Triple::normalize(options.triple()));
882 std::string error;
883 const llvm::Target* target =
884 llvm::TargetRegistry::lookupTarget(triple.getTriple(), error);
885 if (target == nullptr) {
886 return InternalError("TargetRegistry::lookupTarget failed: %s", error);
887 }
888
889 llvm::Reloc::Model reloc_model = llvm::Reloc::Static;
890 llvm::PICLevel::Level pic_level = llvm::PICLevel::NotPIC;
891 llvm::PIELevel::Level pie_level = llvm::PIELevel::Default;
892 switch (options.relocation_model()) {
893 case CpuAotCompilationOptions::RelocationModel::Static:
894 reloc_model = llvm::Reloc::Static;
895 pic_level = llvm::PICLevel::NotPIC;
896 pie_level = llvm::PIELevel::Default;
897 break;
898 case CpuAotCompilationOptions::RelocationModel::SmallPic:
899 reloc_model = llvm::Reloc::PIC_;
900 pic_level = llvm::PICLevel::SmallPIC;
901 pie_level = llvm::PIELevel::Default;
902 break;
903 case CpuAotCompilationOptions::RelocationModel::BigPic:
904 reloc_model = llvm::Reloc::PIC_;
905 pic_level = llvm::PICLevel::BigPIC;
906 pie_level = llvm::PIELevel::Default;
907 break;
908 case CpuAotCompilationOptions::RelocationModel::SmallPie:
909 reloc_model = llvm::Reloc::PIC_;
910 pic_level = llvm::PICLevel::SmallPIC;
911 pie_level = llvm::PIELevel::Small;
912 break;
913 case CpuAotCompilationOptions::RelocationModel::BigPie:
914 reloc_model = llvm::Reloc::PIC_;
915 pic_level = llvm::PICLevel::BigPIC;
916 pie_level = llvm::PIELevel::Large;
917 break;
918 }
919 llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config());
920 std::unique_ptr<llvm::TargetMachine> target_machine =
921 absl::WrapUnique(target->createTargetMachine(
922 triple.getTriple(), options.cpu_name(), options.features(),
923 CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::None,
924 opt_level));
925
926 // Compile must be thread-safe so create a new LLVM context for the module.
927 mlir::MLIRContext mlir_context;
928 LoadMLIRDialects(mlir_context);
929 llvm::LLVMContext llvm_context;
930 llvm::Module llvm_module("__compute_module", llvm_context);
931 llvm_module.setDataLayout(target_machine->createDataLayout());
932 llvm_module.setTargetTriple(triple.getTriple());
933 if (pic_level != llvm::PICLevel::NotPIC) {
934 llvm_module.setPICLevel(pic_level);
935 }
936 if (pie_level != llvm::PIELevel::Default) {
937 llvm_module.setPIELevel(pie_level);
938 }
939
940 std::vector<std::unique_ptr<AotCompilationResult>> results;
941 for (size_t i = 0; i < modules.size(); ++i) {
942 HloModule* module = modules[i].get();
943 VLOG(1) << "Compiling ahead-of-time: " << module->name();
944
945 TF_RETURN_IF_ERROR(
946 RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get()));
947
948 TF_ASSIGN_OR_RETURN(HloSchedule schedule,
949 ScheduleModule(module, BufferSizeBytesFunction()));
950
951 // Run buffer analysis on the HLO graph. This analysis figures out which
952 // temporary buffers are required to run the computation.
953 TF_ASSIGN_OR_RETURN(
954 std::unique_ptr<BufferAssignment> assignment,
955 BufferAssigner::Run(module,
956 absl::make_unique<SequentialHloOrdering>(schedule),
957 BufferSizeBytesFunction(), memory_alignment,
958 /*allocate_buffers_for_constants=*/true));
959 // BufferAssignment::ToString() includes a header, so no need for us to
960 // print one ourselves.
961 if (DumpingEnabledForHloModule(*module)) {
962 DumpToFileInDirOrStdout(*module, "", "buffer_assignment",
963 assignment->ToString());
964 }
965 DumpHloModuleIfEnabled(*module, *assignment, "cpu_after_optimizations");
966
967 std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
968 std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
969 std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
970 std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
971
972 if (module->config().hlo_profiling_enabled()) {
973 TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
974 *module, &instruction_to_profile_idx, &computation_to_profile_idx,
975 &hlo_profile_index_map, &hlo_profile_printer_data));
976 }
977
978 LLVMTargetMachineFeatures target_machine_features(target_machine.get());
979 IrEmitter ir_emitter(&mlir_context, *module, *assignment, &llvm_module,
980 std::move(instruction_to_profile_idx),
981 std::move(computation_to_profile_idx),
982 &target_machine_features,
983 // TODO(b/66051036): Run full msan for AOT.
984 /*emit_code_for_msan=*/false);
985
986 TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
987
988 HloComputation* computation = module->entry_computation();
989 for (auto embedded_computation :
990 computation->MakeEmbeddedComputationsList()) {
991 if (embedded_computation->IsFusionComputation()) {
992 continue;
993 }
994 TF_RETURN_IF_ERROR(
995 ir_emitter
996 .EmitComputation(
997 embedded_computation, embedded_computation->name(),
998 /*is_top_level_computation=*/false,
999 schedule.sequence(embedded_computation).instructions())
1000 .status());
1001 }
1002 const string& entry_point_name = options.entry_point_name();
1003 TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
1004 ir_emitter.EmitComputation(
1005 computation, entry_point_name,
1006 /*is_top_level_computation=*/true,
1007 schedule.sequence(computation).instructions()));
1008
1009 CHECK(entry_function->getName() == entry_point_name);
1010
1011 ModuleHook pre_optimization_ir_hook;
1012 ModuleHook post_optimization_ir_hook;
1013 std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
1014 GetIRModuleHooks(*module, user_pre_optimization_hook_,
1015 user_post_optimization_hook_);
1016
1017 // Run the LLVM verifier over the unoptimized LLVM IR. If it fails, run the
1018 // pre-optimization IR dump hook before returning.
1019 {
1020 Status verify_status = VerifyLlvmModule(llvm_module);
1021 if (!verify_status.ok() && pre_optimization_ir_hook) {
1022 pre_optimization_ir_hook(llvm_module);
1023 }
1024 TF_RETURN_IF_ERROR(verify_status);
1025 }
1026
1027 auto post_codegen_hook = [&](const llvm::object::ObjectFile& obj_file) {
1028 if (!DumpingEnabledForHloModule(*module)) {
1029 return;
1030 }
1031 DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o",
1032 absl::string_view(obj_file.getData().data(),
1033 obj_file.getData().size()));
1034 };
1035
1036 CompilerFunctor compiler_functor(
1037 target_machine.get(), opt_level,
1038 options::OptimizeForSizeRequested(module->config()),
1039 module->config().debug_options().xla_llvm_disable_expensive_passes(),
1040 llvm_ir::GetCpuFastMathFlags(module->config()),
1041 pre_optimization_ir_hook, post_optimization_ir_hook, post_codegen_hook);
1042 std::unique_ptr<llvm::MemoryBuffer> object_file =
1043 cantFail(compiler_functor(llvm_module));
1044 ObjectFileData object_file_data(object_file->getBufferStart(),
1045 object_file->getBufferEnd());
1046
1047 std::vector<BufferInfo> buffer_infos =
1048 CreateBufferInfosFromBufferAssignment(*assignment);
1049
1050 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
1051 assignment->GetUniqueTopLevelOutputSlice());
1052
1053 results.emplace_back(absl::make_unique<CpuAotCompilationResult>(
1054 std::move(object_file_data), std::move(buffer_infos),
1055 result_slice.index(), std::move(hlo_profile_printer_data)));
1056 }
1057
1058 VLOG(1) << "Compilation finished";
1059 return std::move(results);
1060 }
1061
PlatformId() const1062 se::Platform::Id CpuCompiler::PlatformId() const {
1063 return se::host::kHostPlatformId;
1064 }
1065
ShapeSizeBytesFunction() const1066 HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const {
1067 return CpuExecutable::ShapeSizeBytes;
1068 }
1069
1070 } // namespace cpu
1071 } // namespace xla
1072
InitModule()1073 static bool InitModule() {
1074 xla::Compiler::RegisterCompilerFactory(
1075 stream_executor::host::kHostPlatformId,
1076 []() { return absl::make_unique<xla::cpu::CpuCompiler>(); });
1077 return true;
1078 }
1079 static bool module_initialized = InitModule();
1080