1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
17
18 #include <stddef.h>
19 #include <string.h>
20 #include <map>
21 #include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
22 #include <string>
23 #include <unordered_map>
24 #include <utility>
25 #include <vector>
26
27 // IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc"
28 // IWYU pragma: no_include "llvm/Config/Targets.def.inc"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/str_cat.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/ADT/Triple.h"
33 #include "llvm/IR/Function.h"
34 #include "llvm/IR/LLVMContext.h"
35 #include "llvm/IR/Mangler.h"
36 #include "llvm/IR/Module.h"
37 #include "llvm/IR/Verifier.h"
38 #include "llvm/Object/ObjectFile.h"
39 #include "llvm/Support/CommandLine.h"
40 #include "llvm/Support/TargetRegistry.h"
41 #include "llvm/Support/TargetSelect.h"
42 #include "llvm/Target/TargetMachine.h"
43 #include "llvm/Target/TargetOptions.h"
44 #include "tensorflow/compiler/xla/literal.h"
45 #include "tensorflow/compiler/xla/map_util.h"
46 #include "tensorflow/compiler/xla/protobuf_util.h"
47 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
48 #include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
49 #include "tensorflow/compiler/xla/service/batchnorm_expander.h"
50 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
51 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
52 #include "tensorflow/compiler/xla/service/call_inliner.h"
53 #include "tensorflow/compiler/xla/service/cholesky_expander.h"
54 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
55 #include "tensorflow/compiler/xla/service/convolution_group_converter.h"
56 #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
57 #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
58 #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
59 #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
60 #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
61 #include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
62 #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
63 #include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"
64 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
65 #include "tensorflow/compiler/xla/service/cpu/disassembler.h"
66 #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
67 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
68 #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
69 #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
70 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
71 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
72 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
73 #include "tensorflow/compiler/xla/service/dump.h"
74 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
75 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
76 #include "tensorflow/compiler/xla/service/hlo.pb.h"
77 #include "tensorflow/compiler/xla/service/hlo_computation.h"
78 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
79 #include "tensorflow/compiler/xla/service/hlo_cse.h"
80 #include "tensorflow/compiler/xla/service/hlo_dce.h"
81 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
82 #include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h"
83 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
84 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
85 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
86 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
87 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
88 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
89 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
90 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
91 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
92 #include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
93 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
94 #include "tensorflow/compiler/xla/service/map_inliner.h"
95 #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
96 #include "tensorflow/compiler/xla/service/reshape_mover.h"
97 #include "tensorflow/compiler/xla/service/scatter_expander.h"
98 #include "tensorflow/compiler/xla/service/sort_simplifier.h"
99 #include "tensorflow/compiler/xla/service/transpose_folding.h"
100 #include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
101 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
102 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
103 #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
104 #include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
105 #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
106 #include "tensorflow/compiler/xla/status_macros.h"
107 #include "tensorflow/compiler/xla/statusor.h"
108 #include "tensorflow/compiler/xla/types.h"
109 #include "tensorflow/compiler/xla/util.h"
110 #include "tensorflow/compiler/xla/xla_data.pb.h"
111 #include "tensorflow/core/platform/dynamic_annotations.h"
112
113 namespace xla {
114 namespace cpu {
115 using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo;
116
// Options for ahead-of-time compilation: the LLVM target triple, CPU name and
// feature string identify the machine to generate code for; entry_point_name
// is the symbol name given to the generated entry function. String arguments
// are taken by value and moved into the members.
CpuAotCompilationOptions::CpuAotCompilationOptions(
    string triple, string cpu_name, string features, string entry_point_name,
    RelocationModel relocation_model)
    : triple_(std::move(triple)),
      cpu_name_(std::move(cpu_name)),
      features_(std::move(features)),
      entry_point_name_(std::move(entry_point_name)),
      relocation_model_(relocation_model) {}
125
CpuAotCompilationOptions::~CpuAotCompilationOptions() = default;

// AOT compilation with the CPU backend always targets the host platform.
se::Platform::Id CpuAotCompilationOptions::PlatformId() const {
  return se::host::kHostPlatformId;
}
131
// Result of an ahead-of-time compilation: the object-file bytes plus the
// buffer metadata (buffer_infos, index of the result buffer) needed to invoke
// the compiled entry point, and optional HLO profiling printer data.
CpuAotCompilationResult::CpuAotCompilationResult(
    ObjectFileData object_file_data, std::vector<BufferInfo> buffer_infos,
    int64 result_buffer_index,
    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data)
    : object_file_data_(std::move(object_file_data)),
      buffer_infos_(std::move(buffer_infos)),
      result_buffer_index_(result_buffer_index),
      hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {}

CpuAotCompilationResult::~CpuAotCompilationResult() = default;
142
CpuCompiler()143 CpuCompiler::CpuCompiler() {
144 // Initialize LLVM the first time the CpuCompiler is initialized.
145 static bool llvm_initialized = []() {
146 InitializeLLVMTarget();
147 return true;
148 }();
149 (void)llvm_initialized;
150 }
151
// One-time registration of every LLVM target the CPU backend may emit code
// for. Called exactly once from the CpuCompiler constructor.
/* static */ void CpuCompiler::InitializeLLVMTarget() {
  // Initialize LLVM's MC layer for the native target.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  // x86 / x86-64.
  LLVMInitializeX86Target();
  LLVMInitializeX86TargetInfo();
  LLVMInitializeX86TargetMC();
  LLVMInitializeX86AsmPrinter();
  LLVMInitializeX86Disassembler();
  // 32-bit ARM.
  LLVMInitializeARMTarget();
  LLVMInitializeARMTargetInfo();
  LLVMInitializeARMTargetMC();
  LLVMInitializeARMAsmPrinter();
  LLVMInitializeARMDisassembler();
  // 64-bit ARM.
  LLVMInitializeAArch64Target();
  LLVMInitializeAArch64TargetInfo();
  LLVMInitializeAArch64TargetMC();
  LLVMInitializeAArch64AsmPrinter();
  LLVMInitializeAArch64Disassembler();
}
172
173 namespace {
174
// LLVM makes certain options configurable only through its command-line
// options; it provides the ParseCommandLineOptions function that lets us set
// flags at runtime. However, since these flags are global we want to avoid
// multiple invocations of the LLVM compilation pipeline with a different set of
// flags. Therefore, we only pass command-line flags to LLVM once, before the
// first module is compiled.
std::once_flag llvm_command_line_options_initialized;
182
// This visitor records which HLO instructions should have profiling information
// recorded.
class CollectProfileCandidates : public DfsHloVisitorWithDefault {
 public:
  // Walks `computation` (recursing into call and while subcomputations) and
  // returns a map from each profiled instruction to the profile index it was
  // given in `assigned_indices`.
  static StatusOr<std::unordered_map<const HloInstruction*, int64>>
  GetCandidatesForComputation(
      const HloComputation& computation,
      const std::unordered_map<const HloInstruction*, int64>&
          assigned_indices) {
    std::unordered_map<const HloInstruction*, int64> hlo_to_profile_idx;
    CollectProfileCandidates profile_candidates_for_computation(
        &hlo_to_profile_idx, assigned_indices);
    TF_RETURN_IF_ERROR(computation.Accept(&profile_candidates_for_computation));
    return hlo_to_profile_idx;
  }

 private:
  CollectProfileCandidates(
      std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx,
      const std::unordered_map<const HloInstruction*, int64>& assigned_indices)
      : hlo_to_profile_idx_(hlo_to_profile_idx),
        assigned_indices_(assigned_indices) {}

  // By default, every instruction is a profiling candidate; FindOrDie means a
  // missing entry in assigned_indices_ is a hard (programming) error.
  Status DefaultAction(HloInstruction* hlo_instruction) override {
    hlo_to_profile_idx_->insert(
        {hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)});
    return Status::OK();
  }

  // Recurse into the called computation so its instructions are profiled too.
  Status HandleCall(HloInstruction* call) override {
    TF_RETURN_IF_ERROR(DefaultAction(call));
    CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_,
                                                 assigned_indices_);
    TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call));
    return Status::OK();
  }

  // Skip constants, there is nothing to profile.
  Status HandleConstant(HloInstruction*) override { return Status::OK(); }
  // Skip parameters, they are a simple load.
  Status HandleParameter(HloInstruction*) override { return Status::OK(); }
  // It is important to recurse for "while" or else we risk overly coarse
  // profiling information.
  Status HandleWhile(HloInstruction* xla_while) override {
    TF_RETURN_IF_ERROR(DefaultAction(xla_while));

    CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_,
                                                      assigned_indices_);
    TF_RETURN_IF_ERROR(
        xla_while->while_condition()->Accept(&candidates_for_condition));

    CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_,
                                                 assigned_indices_);
    TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&candidates_for_body));

    return Status::OK();
  }

  // Output map from instruction to its profile index. Not owned.
  std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx_;
  // Pre-assigned profile indices to draw from. Not owned.
  const std::unordered_map<const HloInstruction*, int64>& assigned_indices_;
};
244
245 } // namespace
246
// Runs the layout-insensitive portion of the HLO pipeline on `module`:
// expansion passes, inlining, a fixed-point "simplification" sub-pipeline,
// transpose folding and CSE, ending with CPU layout assignment and
// instruction fusion. Pass ordering here is load-bearing; see the inline
// comments before reordering anything.
Status CpuCompiler::RunHloPassesThroughLayoutAssn(
    HloModule* module, bool /*is_aot_compile*/,
    LLVMTargetMachineFeatures* target_machine_features) {
  HloPassPipeline pipeline("HLO passes through layout assignment");
  pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                            /*allow_mixed_precision=*/false);
  pipeline.AddPass<DynamicIndexSplitter>();
  pipeline.AddPass<CpuHloSupportChecker>();

  ReducePrecisionInsertion::AddPasses(
      &pipeline, module->config().debug_options(),
      ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);

  pipeline.AddPass<MapInliner>();

  pipeline.AddPass<CholeskyExpander>();
  pipeline.AddPass<TriangularSolveExpander>();

  // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner
  // pass.
  pipeline.AddPass<CallInliner>();
  pipeline.AddPass<BatchDotSimplification>();
  pipeline.AddPass<DotDecomposer>(/*decompose_batch_dot=*/false);
  auto cost_model = [](HloInstruction* conv) {
    // We need a cost model for CPUs. Currently, do nothing.
    return false;
  };
  pipeline.AddPass<ConvolutionGroupConverter>(
      cost_model,
      /*convert_batch_groups_only=*/true);
  pipeline.AddPass<ConvolutionGroupConverter>(
      cost_model,
      /*convert_batch_groups_only=*/false);
  pipeline.AddPass<ConvCanonicalization>(target_machine_features);
  {
    // Fixed-point sub-pipeline: runs until no pass makes further changes.
    auto& pass =
        pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
    pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                          /*allow_mixed_precision=*/false);

    pass.AddPass<BatchNormExpander>(
        /*rewrite_training_op=*/true,
        /*rewrite_inference_op=*/true,
        /*rewrite_grad_op=*/true);
    // NOTE(review): this adds to the outer `pipeline` (so it runs once, after
    // the simplification fix-point), not to `pass` -- confirm this is
    // intentional and not a typo for `pass.AddPass<...>()`.
    pipeline.AddPass<HloGetDimensionSizeRewriter>();
    AlgebraicSimplifierOptions options;
    options.set_enable_dot_strength_reduction(false);
    pass.AddPass<AlgebraicSimplifier>(options);
    pass.AddPass<SortSimplifier>();
    pass.AddPass<HloDCE>();

    // BatchNormExpander can create zero-sized ops, so zero-sized HLO
    // elimination has to come after that pass.
    pass.AddPass<ZeroSizedHloElimination>();

    pass.AddPass<WhileLoopInvariantCodeMotion>();
    pass.AddPass<TupleSimplifier>();
    pass.AddPass<WhileLoopConstantSinking>();
    pass.AddPass<WhileLoopSimplifier>();
    pass.AddPass<HloDCE>();
    pass.AddPass<ReshapeMover>();
    pass.AddPass<HloConstantFolding>();
    pass.AddPass<ConditionalSimplifier>();
  }
  pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
  // Fold transposes into dots only when the CPU dot implementation can handle
  // the transposed operand directly; never fold into convolutions.
  pipeline.AddPass<TransposeFolding>(
      [&](const HloInstruction& dot,
          const TransposeFolding::OperandIndices& candidate_operands) {
        return DotImplementationCanHandleTranspose(dot,
                                                   *target_machine_features)
                   ? candidate_operands
                   : TransposeFolding::OperandIndices{};
      },
      TransposeFolding::NeverFoldTranspose);
  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);

  pipeline.AddPass<CpuLayoutAssignment>(
      module->mutable_entry_computation_layout(),
      LayoutAssignment::InstructionCanChangeLayout, target_machine_features);

  pipeline.AddPass<CpuInstructionFusion>();

  pipeline.AddPass<ScatterExpander>();

  ReducePrecisionInsertion::AddPasses(
      &pipeline, module->config().debug_options(),
      ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
  return pipeline.Run(module).status();
}
336
RunHloPassesAfterLayoutAssn(HloModule * module,bool is_aot_compile,LLVMTargetMachineFeatures * target_machine_features)337 Status CpuCompiler::RunHloPassesAfterLayoutAssn(
338 HloModule* module, bool is_aot_compile,
339 LLVMTargetMachineFeatures* target_machine_features) {
340 HloPassPipeline pipeline("HLO passes after layout assignment");
341 // After layout assignment, use a layout-sensitive verifier.
342 auto& after_layout_assn =
343 pipeline.AddPass<HloPassPipeline>("after layout assignment");
344 after_layout_assn.AddInvariantChecker<HloVerifier>(
345 /*layout_sensitive=*/true,
346 /*allow_mixed_precision=*/false);
347
348 // The LayoutAssignment pass may leave behind kCopy instructions which are
349 // duplicate or NOPs, so remove them with algebraic simplification and CSE.
350 {
351 auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
352 "simplification after layout assignement");
353 pass.AddInvariantChecker<HloVerifier>(
354 /*layout_sensitive=*/true,
355 /*allow_mixed_precision=*/false,
356 LayoutAssignment::InstructionCanChangeLayout);
357 AlgebraicSimplifierOptions options;
358 options.set_is_layout_sensitive(true);
359 options.set_enable_dot_strength_reduction(false);
360 pass.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
361 pass.AddPass<HloDCE>();
362 pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
363 }
364
365 pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
366
367 // Outline ops in the entry computation into calls to subcomputations.
368 const int max_parallelism =
369 module->config().intra_op_parallelism_threads() > 0
370 ? module->config().intra_op_parallelism_threads()
371 : tensorflow::port::NumSchedulableCPUs();
372 if (!is_aot_compile) {
373 // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module.
374 // Note this is not run for AOT because it would bring in thread pool
375 // and thread synchronization dependencies which would likely increase
376 // binary size (and most AOT applications are single-threaded).
377 // TODO(b/29630486) Support multi-threaded AOT.
378 pipeline.AddPass<ParallelTaskAssigner>(
379 max_parallelism, ShapeSizeBytesFunction(), target_machine_features);
380 }
381 // Copy insertion should be performed immediately before IR emission to
382 // avoid inserting unnecessary copies (later pass adds an instruction which
383 // materializes the value) or missing a necessary copy (later pass removes
384 // an instruction which materializes a value). DCE must be run immediately
385 // before (and sometime after) copy insertion, to avoid dead code from
386 // interfering with the rewrites.
387 pipeline.AddPass<HloDCE>();
388 pipeline.AddPass<FlattenCallGraph>();
389 pipeline.AddPass<CpuCopyInsertion>();
390 pipeline.AddPass<HloDCE>();
391 return pipeline.Run(module).status();
392 }
393
RunHloPasses(HloModule * module,bool is_aot_compile,llvm::TargetMachine * target_machine)394 Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
395 llvm::TargetMachine* target_machine) {
396 LLVMTargetMachineFeatures target_machine_features(target_machine);
397 TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile,
398 &target_machine_features));
399 return RunHloPassesAfterLayoutAssn(module, is_aot_compile,
400 &target_machine_features);
401 }
402
403 namespace {
404
// Align buffers to 16-byte boundaries.
constexpr int64 kMemoryAlignment = 16;
// Alignment callback handed to BufferAssigner: every logical buffer,
// regardless of its color, is aligned to kMemoryAlignment bytes.
auto memory_alignment = [](LogicalBuffer::Color) { return kMemoryAlignment; };
408
CompilerTargetOptions(const HloModuleConfig & module_config)409 llvm::TargetOptions CompilerTargetOptions(
410 const HloModuleConfig& module_config) {
411 llvm::TargetOptions target_options;
412 // In LLVM backend flags, UnsafeFPMath does not explicitly imply NoInfs, etc.
413 if (module_config.debug_options().xla_cpu_enable_fast_math()) {
414 target_options.UnsafeFPMath = true;
415 target_options.NoInfsFPMath =
416 module_config.debug_options().xla_cpu_fast_math_honor_infs();
417 target_options.NoNaNsFPMath =
418 module_config.debug_options().xla_cpu_fast_math_honor_nans();
419 target_options.NoSignedZerosFPMath = true;
420 } else {
421 target_options.UnsafeFPMath = false;
422 target_options.NoInfsFPMath = false;
423 target_options.NoNaNsFPMath = false;
424 target_options.NoSignedZerosFPMath = false;
425 }
426 return target_options;
427 }
428
CodeGenOptLevel(const HloModuleConfig & module_config)429 llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) {
430 VLOG(2) << "backend_optimization_level: "
431 << module_config.debug_options().xla_backend_optimization_level();
432 switch (module_config.debug_options().xla_backend_optimization_level()) {
433 case 1:
434 return llvm::CodeGenOpt::Less;
435 case 2:
436 return llvm::CodeGenOpt::Default;
437 case 3:
438 return llvm::CodeGenOpt::Aggressive;
439 default:
440 return llvm::CodeGenOpt::None;
441 }
442 }
443
// Builds the (pre-optimization, post-optimization) pair of LLVM IR hooks
// handed to the JIT for `hlo_module`.
std::pair<LLVMCompiler::ModuleHook, LLVMCompiler::ModuleHook> GetIRModuleHooks(
    const HloModule& hlo_module,
    const LLVMCompiler::ModuleHook& user_pre_optimization_hook,
    const LLVMCompiler::ModuleHook& user_post_optimization_hook) {
  // Create the IR hooks. If applicable, each IR hook does the following:
  //
  //  * Calls the user supplied module hook.
  //  * Writes out the IR to a file in the output directory designated by
  //    --xla_dump_to
  const HloModule* hlo_module_ptr = &hlo_module;
  // Everything is captured by value so the returned hooks stay valid after
  // this function returns; hlo_module_ptr must outlive them (it points at the
  // caller-owned HloModule).
  auto hook = [user_pre_optimization_hook, user_post_optimization_hook,
               hlo_module_ptr](bool optimized,
                               const llvm::Module& llvm_module) {
    const auto& user_hook =
        !optimized ? user_pre_optimization_hook : user_post_optimization_hook;
    if (user_hook) {
      user_hook(llvm_module);
    }
    llvm_ir::DumpIrIfEnabled(*hlo_module_ptr, llvm_module, optimized);
  };
  // The shared `hook` is specialized into the pre- and post-optimization
  // variants by binding `optimized`.
  return {[hook](const llvm::Module& llvm_module) {
            return hook(/*optimized=*/false, llvm_module);
          },
          [hook](const llvm::Module& llvm_module) {
            return hook(/*optimized=*/true, llvm_module);
          }};
}
471
VerifyLlvmModule(const llvm::Module & llvm_module)472 Status VerifyLlvmModule(const llvm::Module& llvm_module) {
473 XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier");
474
475 std::string err;
476 llvm::raw_string_ostream err_stream(err);
477
478 // verifyModule() returns true if the module is broken.
479 TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream))
480 << "Invalid LLVM IR before optimizations:\n"
481 << err_stream.str()
482 << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
483 "Rerun with --xla_dump_to to get the IR. ";
484 return Status::OK();
485 }
486
// Builds the HLO-profiling data structures for `module`:
//  - instruction_to_profile_idx / computation_to_profile_idx: map each HLO
//    instruction/computation to its slot in the profile counter array.
//  - hlo_profile_index_map / hlo_profile_printer_data: metadata used later to
//    pretty-print the collected counters.
// All four output parameters are overwritten on success.
Status CreateHloProfilingArtifacts(
    const HloModule& module,
    std::unordered_map<const HloInstruction*, int64>*
        instruction_to_profile_idx,
    std::unordered_map<const HloComputation*, int64>*
        computation_to_profile_idx,
    std::unique_ptr<HloProfileIndexMap>* hlo_profile_index_map,
    std::unique_ptr<HloProfilePrinterData>* hlo_profile_printer_data) {
  *hlo_profile_index_map = absl::make_unique<HloProfileIndexMap>(module);
  const HloComputation& entry_computation = *module.entry_computation();

  TF_ASSIGN_OR_RETURN(
      *instruction_to_profile_idx,
      CollectProfileCandidates::GetCandidatesForComputation(
          entry_computation,
          (*hlo_profile_index_map)->instruction_to_profile_idx()));

  auto shape_size_bytes = [](const Shape& shape) {
    // On the cpu, opaques are pointers.
    if (shape.IsOpaque()) {
      return static_cast<int64>(sizeof(void*));
    }
    return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
  };

  // Run cost analysis over the entry computation so the printer can show
  // per-instruction cost estimates next to the profile counters.
  HloCostAnalysis cost_analysis(shape_size_bytes);
  TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis));
  *hlo_profile_printer_data = CreateHloProfilePrinterData(
      **hlo_profile_index_map, cost_analysis, entry_computation.name());
  *computation_to_profile_idx =
      (*hlo_profile_index_map)->computation_to_profile_idx();

  return Status::OK();
}
521
522 } // namespace
523
RunHloPasses(std::unique_ptr<HloModule> module,se::StreamExecutor *,DeviceMemoryAllocator *)524 StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
525 std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/,
526 DeviceMemoryAllocator* /*device_allocator*/) {
527 std::unique_ptr<llvm::TargetMachine> jit_target_machine =
528 SimpleOrcJIT::InferTargetMachineForJIT(
529 CompilerTargetOptions(module->config()),
530 CodeGenOptLevel(module->config()));
531
532 TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false,
533 jit_target_machine.get()));
534 return std::move(module);
535 }
536
537 namespace {
538
// Post-compilation callback functor for use by SimpleOrcJIT.
//
// Dumps disassembled machine code if dumping is enabled for the module.
struct OrcJITPostCompilationHook {
  // Gets an std::function that implements this hook.
  static std::function<void(const llvm::object::ObjectFile& obj_file)> Create(
      const HloModule* module) {
    // This struct is not copyable, but std::functions must be. So to create an
    // std::function out of this struct, we have to wrap it in a shared_ptr.
    auto wrapped = std::make_shared<OrcJITPostCompilationHook>(module);
    return [wrapped](const llvm::object::ObjectFile& obj_file) {
      (*wrapped)(obj_file);
    };
  }

  // Constructor can't be private because we want to call it from
  // std::make_shared, but users should call Create() instead.
  explicit OrcJITPostCompilationHook(const HloModule* module)
      : module(module),
        target_machine(SimpleOrcJIT::InferTargetMachineForJIT(
            CompilerTargetOptions(module->config()),
            CodeGenOptLevel(module->config()))),
        disassembler(*target_machine) {}

 private:
  // Disassembles the JIT-compiled object file and dumps the text with file
  // suffix "s". A disassembly failure is recorded in the dump as an error
  // string rather than propagated to the caller.
  void operator()(const llvm::object::ObjectFile& obj_file) {
    if (!DumpingEnabledForHloModule(*module)) {
      return;
    }
    StatusOr<DisassemblerResult> disasm_or =
        disassembler.DisassembleObjectFile(obj_file);
    string text = disasm_or.ok() ? std::move(disasm_or).ValueOrDie().text
                                 : absl::StrCat("Error disassembling: ",
                                                disasm_or.status().ToString());
    DumpToFileInDirOrStdout(*module, /*file_suffix=*/"s", text);
  }

  const HloModule* module;  // Not owned.
  // disassembler keeps references to data inside of target_machine.
  std::unique_ptr<llvm::TargetMachine> target_machine;
  Disassembler disassembler;
};
581
582 } // namespace
583
// JIT-compiles the (already optimized) `module` into a CpuExecutable:
//  1. Set up a SimpleOrcJIT and an LLVM module targeting the JIT's triple.
//  2. Optionally build HLO profiling artifacts.
//  3. Schedule the module and run buffer assignment.
//  4. Emit LLVM IR for all computations, verify it, and hand it to the JIT.
StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
    std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
    DeviceMemoryAllocator* /*device_allocator*/) {
  VLOG(1) << "Compiling: " << module->name();
  // NOTE(review): absl::StrFormat needs "absl/strings/str_format.h"; only
  // str_cat.h is included above -- presumably pulled in transitively; confirm.
  XLA_SCOPED_LOGGING_TIMER(
      absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));

  TF_RET_CHECK(stream_exec != nullptr);
  // LLVM command-line flags are process-global; set them exactly once.
  std::call_once(llvm_command_line_options_initialized,
                 &llvm_ir::InitializeLLVMCommandLineOptions, module->config());

  ModuleHook pre_optimization_ir_hook;
  ModuleHook post_optimization_ir_hook;
  std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
      GetIRModuleHooks(*module, user_pre_optimization_hook_,
                       user_post_optimization_hook_);

  // Compile must be thread-safe so create a new LLVM context for the module.
  auto llvm_context = absl::make_unique<llvm::LLVMContext>();
  auto llvm_module =
      absl::make_unique<llvm::Module>("__compute_module", *llvm_context);

  auto jit = absl::make_unique<SimpleOrcJIT>(
      CompilerTargetOptions(module->config()),
      CodeGenOptLevel(module->config()),
      options::OptimizeForSizeRequested(module->config()),
      module->config().debug_options().xla_cpu_enable_fast_math(),
      module->config().debug_options().xla_llvm_disable_expensive_passes(),
      pre_optimization_ir_hook, post_optimization_ir_hook,
      OrcJITPostCompilationHook::Create(module.get()));
  llvm_module->setDataLayout(jit->data_layout());
  llvm_module->setTargetTriple(jit->target_triple().getTriple());

  HloComputation* entry_computation = module->entry_computation();
  std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
  std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
  std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
  std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
  if (module->config().hlo_profiling_enabled()) {
    TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
        *module, &instruction_to_profile_idx, &computation_to_profile_idx,
        &hlo_profile_index_map, &hlo_profile_printer_data));
  }

  std::unique_ptr<Executable> cpu_executable;

  // Cache these flags here since we'll want to access them after the module's
  // ownership is std::moved.
  const bool embed_ir_in_executable =
      module->config().debug_options().xla_embed_ir_in_executable();

  // Select an order for emitting the HLO instructions for each
  // computation. Using this sequence enables tighter buffer liveness analysis
  // and reduced memory usage (as compared to using DependencyHloOrdering).
  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      ScheduleModule(module.get(), BufferSizeBytesFunction(),
                                     DFSMemoryScheduler));

  // Run buffer allocation on the HLO graph.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<BufferAssignment> assignment,
      BufferAssigner::Run(module.get(),
                          absl::make_unique<SequentialHloOrdering>(schedule),
                          BufferSizeBytesFunction(), memory_alignment,
                          /*allow_input_output_aliasing=*/false,
                          /*allocate_buffers_for_constants=*/true));
  // BufferAssignment::ToString() includes a header, so no need for us to
  // print one ourselves.
  if (DumpingEnabledForHloModule(*module)) {
    DumpToFileInDirOrStdout(*module, "buffer_assignment",
                            assignment->ToString());
  }
  DumpHloModuleIfEnabled(*module, *assignment, "after_optimizations");

  // Each computation is a single function. Emit all embedded computations
  // before the entry computation. The order of computations returned from
  // GetEmbeddedComputations guarantees that a called computation occurs
  // before a caller computation.

  LLVMTargetMachineFeatures target_machine_features(jit->target_machine());
  IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
                       std::move(instruction_to_profile_idx),
                       std::move(computation_to_profile_idx),
                       &target_machine_features,
#ifdef MEMORY_SANITIZER
                       /*emit_code_for_msan=*/true
#else
                       /*emit_code_for_msan=*/false
#endif
  );

  TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());

  // Fusion computations are emitted inline by their fusion instruction, so
  // they are skipped here.
  for (auto embedded_computation :
       entry_computation->MakeEmbeddedComputationsList()) {
    if (embedded_computation->IsFusionComputation()) {
      continue;
    }
    TF_RETURN_IF_ERROR(
        ir_emitter
            .EmitComputation(
                embedded_computation, embedded_computation->name(),
                /*is_top_level_computation=*/false,
                schedule.sequence(embedded_computation).instructions())
            .status());
  }
  string function_name_prefix = entry_computation->name().empty()
                                    ? "__compute"
                                    : entry_computation->name();
  TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
                      ir_emitter.EmitComputation(
                          entry_computation, function_name_prefix,
                          /*is_top_level_computation=*/true,
                          schedule.sequence(entry_computation).instructions()));

  // Apply the target's name mangling so the symbol we look up in the JIT
  // matches what the object file actually exports.
  string function_name = [&]() {
    llvm::SmallVector<char, 40> function_name_vector;
    llvm::Mangler::getNameWithPrefix(
        function_name_vector, entry_function->getName(), jit->data_layout());
    return string(function_name_vector.begin(), function_name_vector.end());
  }();

  string ir_module_string;
  if (embed_ir_in_executable) {
    ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
  }

  TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));

  // JIT compile the LLVM IR module to in-memory machine code.
  jit->AddModule(std::move(llvm_module));
  cpu_executable.reset(new CpuExecutable(
      std::move(jit), std::move(assignment), std::move(module), function_name,
      std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));

  if (embed_ir_in_executable) {
    static_cast<CpuExecutable&>(*cpu_executable)
        .set_ir_module_string(ir_module_string);
  }

  VLOG(1) << "Compilation finished";
  return std::move(cpu_executable);
}
727
728 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,const AotCompilationOptions & aot_options)729 CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
730 const AotCompilationOptions& aot_options) {
731 TF_RET_CHECK(!module_group->empty());
732 std::vector<std::unique_ptr<HloModule>> modules =
733 module_group->ConsumeModules();
734
735 std::call_once(llvm_command_line_options_initialized,
736 &llvm_ir::InitializeLLVMCommandLineOptions,
737 modules[0]->config());
738
739 // We can pass just one llvm::TargetOptions when we compile the LLVM module,
740 // so we bail if the configs have conflicting flags. At the moment, the only
741 // flag that needs to be consistent is fast-math.
742 const bool fast_math_enabled =
743 modules[0]->config().debug_options().xla_cpu_enable_fast_math();
744 for (const auto& module : modules) {
745 if (module->config().debug_options().xla_cpu_enable_fast_math() !=
746 fast_math_enabled) {
747 return InvalidArgument(
748 "All HLO module configs must have the same value for "
749 "xla_enable_fast_math.");
750 }
751 }
752
753 if (aot_options.PlatformId() != se::host::kHostPlatformId) {
754 return InvalidArgument("Incompatible AOT compilation platform");
755 }
756 const CpuAotCompilationOptions& options =
757 static_cast<const CpuAotCompilationOptions&>(aot_options);
758 llvm::Triple triple(llvm::Triple::normalize(options.triple()));
759 std::string error;
760 const llvm::Target* target =
761 llvm::TargetRegistry::lookupTarget(triple.getTriple(), error);
762 if (target == nullptr) {
763 return InternalError("TargetRegistry::lookupTarget failed: %s", error);
764 }
765
766 llvm::Reloc::Model reloc_model = llvm::Reloc::Static;
767 llvm::PICLevel::Level pic_level = llvm::PICLevel::NotPIC;
768 llvm::PIELevel::Level pie_level = llvm::PIELevel::Default;
769 switch (options.relocation_model()) {
770 case CpuAotCompilationOptions::RelocationModel::Static:
771 reloc_model = llvm::Reloc::Static;
772 pic_level = llvm::PICLevel::NotPIC;
773 pie_level = llvm::PIELevel::Default;
774 break;
775 case CpuAotCompilationOptions::RelocationModel::SmallPic:
776 reloc_model = llvm::Reloc::PIC_;
777 pic_level = llvm::PICLevel::SmallPIC;
778 pie_level = llvm::PIELevel::Default;
779 break;
780 case CpuAotCompilationOptions::RelocationModel::BigPic:
781 reloc_model = llvm::Reloc::PIC_;
782 pic_level = llvm::PICLevel::BigPIC;
783 pie_level = llvm::PIELevel::Default;
784 break;
785 case CpuAotCompilationOptions::RelocationModel::SmallPie:
786 reloc_model = llvm::Reloc::PIC_;
787 pic_level = llvm::PICLevel::SmallPIC;
788 pie_level = llvm::PIELevel::Small;
789 break;
790 case CpuAotCompilationOptions::RelocationModel::BigPie:
791 reloc_model = llvm::Reloc::PIC_;
792 pic_level = llvm::PICLevel::BigPIC;
793 pie_level = llvm::PIELevel::Large;
794 break;
795 }
796 llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config());
797 std::unique_ptr<llvm::TargetMachine> target_machine =
798 absl::WrapUnique(target->createTargetMachine(
799 triple.getTriple(), options.cpu_name(), options.features(),
800 CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::None,
801 opt_level));
802
803 // Compile must be thread-safe so create a new LLVM context for the module.
804 llvm::LLVMContext llvm_context;
805 llvm::Module llvm_module("__compute_module", llvm_context);
806 llvm_module.setDataLayout(target_machine->createDataLayout());
807 llvm_module.setTargetTriple(triple.getTriple());
808 if (pic_level != llvm::PICLevel::NotPIC) {
809 llvm_module.setPICLevel(pic_level);
810 }
811 if (pie_level != llvm::PIELevel::Default) {
812 llvm_module.setPIELevel(pie_level);
813 }
814
815 std::vector<std::unique_ptr<AotCompilationResult>> results;
816 for (size_t i = 0; i < modules.size(); ++i) {
817 HloModule* module = modules[i].get();
818 VLOG(1) << "Compiling ahead-of-time: " << module->name();
819
820 TF_RETURN_IF_ERROR(
821 RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get()));
822
823 TF_ASSIGN_OR_RETURN(HloSchedule schedule,
824 ScheduleModule(module, BufferSizeBytesFunction()));
825
826 // Run buffer analysis on the HLO graph. This analysis figures out which
827 // temporary buffers are required to run the computation.
828 TF_ASSIGN_OR_RETURN(
829 std::unique_ptr<BufferAssignment> assignment,
830 BufferAssigner::Run(module,
831 absl::make_unique<SequentialHloOrdering>(schedule),
832 BufferSizeBytesFunction(), memory_alignment,
833 /*allow_input_output_aliasing=*/false,
834 /*allocate_buffers_for_constants=*/true));
835 // BufferAssignment::ToString() includes a header, so no need for us to
836 // print one ourselves.
837 if (DumpingEnabledForHloModule(*module)) {
838 DumpToFileInDirOrStdout(*module, "buffer_assignment",
839 assignment->ToString());
840 }
841 DumpHloModuleIfEnabled(*module, *assignment, "after_optimizations");
842
843 std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
844 std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
845 std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
846 std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
847
848 if (module->config().hlo_profiling_enabled()) {
849 TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
850 *module, &instruction_to_profile_idx, &computation_to_profile_idx,
851 &hlo_profile_index_map, &hlo_profile_printer_data));
852 }
853
854 LLVMTargetMachineFeatures target_machine_features(target_machine.get());
855 IrEmitter ir_emitter(*module, *assignment, &llvm_module,
856 std::move(instruction_to_profile_idx),
857 std::move(computation_to_profile_idx),
858 &target_machine_features,
859 // TODO(b/66051036): Run full msan for AOT.
860 /*emit_code_for_msan=*/false);
861
862 TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
863
864 HloComputation* computation = module->entry_computation();
865 for (auto embedded_computation :
866 computation->MakeEmbeddedComputationsList()) {
867 if (embedded_computation->IsFusionComputation()) {
868 continue;
869 }
870 TF_RETURN_IF_ERROR(
871 ir_emitter
872 .EmitComputation(
873 embedded_computation, embedded_computation->name(),
874 /*is_top_level_computation=*/false,
875 schedule.sequence(embedded_computation).instructions())
876 .status());
877 }
878 const string& entry_point_name = options.entry_point_name();
879 TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
880 ir_emitter.EmitComputation(
881 computation, entry_point_name,
882 /*is_top_level_computation=*/true,
883 schedule.sequence(computation).instructions()));
884
885 CHECK(entry_function->getName() == entry_point_name);
886
887 ModuleHook pre_optimization_ir_hook;
888 ModuleHook post_optimization_ir_hook;
889 std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
890 GetIRModuleHooks(*module, user_pre_optimization_hook_,
891 user_post_optimization_hook_);
892
893 // Run the LLVM verifier over the unoptimized LLVM IR. If it fails, run the
894 // pre-optimization IR dump hook before returning.
895 {
896 Status verify_status = VerifyLlvmModule(llvm_module);
897 if (!verify_status.ok() && pre_optimization_ir_hook) {
898 pre_optimization_ir_hook(llvm_module);
899 }
900 TF_RETURN_IF_ERROR(verify_status);
901 }
902
903 auto post_codegen_hook = [&](const llvm::object::ObjectFile& obj_file) {
904 if (!DumpingEnabledForHloModule(*module)) {
905 return;
906 }
907 StatusOr<DisassemblerResult> disasm_or =
908 Disassembler(*target_machine).DisassembleObjectFile(obj_file);
909 string text = disasm_or.ok()
910 ? std::move(disasm_or).ValueOrDie().text
911 : absl::StrCat("Error disassembling: ",
912 disasm_or.status().ToString());
913 DumpToFileInDirOrStdout(*module, /*file_suffix=*/"s", text);
914 };
915
916 CompilerFunctor compiler_functor(
917 target_machine.get(), opt_level,
918 options::OptimizeForSizeRequested(module->config()),
919 module->config().debug_options().xla_cpu_enable_fast_math(),
920 module->config().debug_options().xla_llvm_disable_expensive_passes(),
921 pre_optimization_ir_hook, post_optimization_ir_hook, post_codegen_hook);
922 std::unique_ptr<llvm::MemoryBuffer> object_file =
923 compiler_functor(llvm_module);
924 ObjectFileData object_file_data(object_file->getBufferStart(),
925 object_file->getBufferEnd());
926
927 std::vector<BufferInfo> buffer_infos =
928 CreateBufferInfosFromBufferAssignment(*assignment);
929
930 TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
931 assignment->GetUniqueTopLevelOutputSlice());
932
933 results.emplace_back(absl::make_unique<CpuAotCompilationResult>(
934 std::move(object_file_data), std::move(buffer_infos),
935 result_slice.index(), std::move(hlo_profile_printer_data)));
936 }
937
938 VLOG(1) << "Compilation finished";
939 return std::move(results);
940 }
941
PlatformId() const942 se::Platform::Id CpuCompiler::PlatformId() const {
943 return se::host::kHostPlatformId;
944 }
945
ShapeSizeBytesFunction() const946 HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const {
947 return CpuExecutable::ShapeSizeBytes;
948 }
949
950 } // namespace cpu
951 } // namespace xla
952
InitModule()953 static bool InitModule() {
954 xla::Compiler::RegisterCompilerFactory(
955 stream_executor::host::kHostPlatformId,
956 []() { return absl::make_unique<xla::cpu::CpuCompiler>(); });
957 return true;
958 }
959 static bool module_initialized = InitModule();
960