1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
17
18 #include <algorithm>
19 #include <iterator>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/memory/memory.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Analysis/TargetLibraryInfo.h"
28 #include "llvm/Analysis/TargetTransformInfo.h"
29 #include "llvm/IR/LegacyPassManager.h"
30 #include "llvm/IR/Verifier.h"
31 #include "llvm/MC/MCContext.h"
32 #include "llvm/Object/ObjectFile.h"
33 #include "llvm/Support/SmallVectorMemoryBuffer.h"
34 #include "llvm/Support/raw_ostream.h"
35 #include "llvm/Target/TargetMachine.h"
36 #include "llvm/Transforms/IPO.h"
37 #include "llvm/Transforms/IPO/AlwaysInliner.h"
38 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
39 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
40 #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"
41 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
42 #include "tensorflow/compiler/xla/statusor.h"
43 #include "tensorflow/compiler/xla/types.h"
44 #include "tensorflow/compiler/xla/util.h"
45 #include "tensorflow/core/platform/logging.h"
46
47 namespace xla {
48 namespace cpu {
49
/* Create filtered versions of the LLVM pass managers to filter out some
   of the expensive passes.

   Profiling of
     learning/brain/google/xla/benchmarks:inception_cpu_benchmark
     learning/brain/google/xla/benchmarks:cifarnet
   pointed to LICM and IndVarSimplify as the hottest passes. LICM is known to
   exhibit O(n^2) time in the number of instructions. IndVarSimplify is slow
   due to SCEV, and it is not necessary if loops are emitted in canonical
   form. The original plan was to disable these passes as a starting point.

   NOTE: the filter below does not currently drop LICM or IndVarSimplify. It
   always drops the "Warn about non-applied transformations" pass, and it
   drops the "Unroll loops" pass when expensive passes are disabled.
*/
// TODO(b/64227304) Creating a custom pass pipeline will replace this.
62
63 namespace {
64 class FilteredPassManager : public llvm::legacy::PassManager {
65 public:
FilteredPassManager(bool disable_expensive_passes)66 explicit FilteredPassManager(bool disable_expensive_passes)
67 : disable_expensive_passes_(disable_expensive_passes) {}
add(llvm::Pass * p)68 void add(llvm::Pass* p) override {
69 llvm::StringRef PassName = p->getPassName();
70 if (PassName.contains("Warn about non-applied transformations")) {
71 delete p;
72 return;
73 }
74 if (disable_expensive_passes_) {
75 if (PassName.contains("Unroll loops")) {
76 delete p;
77 return;
78 }
79 }
80 llvm::legacy::PassManager::add(p);
81 }
82
83 private:
84 bool disable_expensive_passes_;
85 };
86 } // anonymous namespace
87
operator ()(llvm::Module & module) const88 std::unique_ptr<llvm::MemoryBuffer> CompilerFunctor::operator()(
89 llvm::Module& module) const {
90 FilteredPassManager module_passes(disable_expensive_passes_);
91 llvm::legacy::FunctionPassManager function_passes(&module);
92
93 VLOG(2) << "IR before optimizations";
94 XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module));
95
96 if (pre_optimization_hook_) {
97 pre_optimization_hook_(module);
98 }
99
100 // Add the appropriate TargetLibraryInfo and TargetTransformInfo.
101 AddTargetInfoPasses(&module_passes);
102
103 // Build up optimization pipeline.
104 if (optimize_for_size_) {
105 // Optimizing for size turns on -O2 level optimizations.
106 //
107 // TODO(b/64153864): Although the code generator supports size_level = 2 to
108 // turn on more aggressive code size optimizations than size_level = 1, we
109 // pass size_level = 1 because in many cases a size_level of 2 does
110 // worse. Investigate why.
111 AddOptimizationPasses(&module_passes, &function_passes, /*opt_level=*/2,
112 /*size_level=*/1);
113 } else {
114 AddOptimizationPasses(&module_passes, &function_passes,
115 /*opt_level=*/opt_level_, /*size_level=*/0);
116 }
117
118 // Run optimization passes on module.
119 function_passes.doInitialization();
120
121 CHECK(!llvm::verifyModule(module, &llvm::dbgs()));
122
123 for (auto func = module.begin(); func != module.end(); ++func) {
124 function_passes.run(*func);
125 }
126 function_passes.doFinalization();
127 module_passes.run(module);
128
129 CHECK(!llvm::verifyModule(module, &llvm::dbgs()));
130
131 runtime::RewriteIRRuntimeFunctions(&module, enable_fast_math_);
132
133 // Buffer for holding machine code prior to constructing the ObjectFile.
134 llvm::SmallVector<char, 0> stream_buffer;
135 llvm::raw_svector_ostream ostream(stream_buffer);
136
137 VLOG(2) << "IR after optimizations";
138 XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module));
139
140 if (post_optimization_hook_) {
141 post_optimization_hook_(module);
142 }
143
144 // Generate code.
145 llvm::MCContext* mc_context;
146 llvm::legacy::PassManager codegen_passes;
147 target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream);
148 codegen_passes.run(module);
149
150 std::unique_ptr<llvm::MemoryBuffer> memory_buffer(
151 new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer)));
152
153 if (post_codegen_hook_) {
154 llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> obj_file =
155 llvm::object::ObjectFile::createObjectFile(*memory_buffer);
156 if (obj_file) {
157 post_codegen_hook_(*obj_file.get());
158 } else {
159 LOG(WARNING) << "Could convert memory buffer to object file!";
160 }
161 }
162
163 return memory_buffer;
164 }
165
VectorFunctionsForTargetLibraryInfoImpl()166 static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
167 std::vector<llvm::VecDesc> result = {
168 {"tanhf", runtime::kTanhV4F32SymbolName, 4},
169 {"llvm.tanh.f32", runtime::kTanhV4F32SymbolName, 4},
170
171 {"tanhf", runtime::kTanhV8F32SymbolName, 8},
172 {"llvm.tanh.f32", runtime::kTanhV8F32SymbolName, 8},
173
174 {"expf", runtime::kExpV4F32SymbolName, 4},
175 {"llvm.exp.f32", runtime::kExpV4F32SymbolName, 4},
176
177 {"expf", runtime::kExpV8F32SymbolName, 8},
178 {"llvm.exp.f32", runtime::kExpV8F32SymbolName, 8},
179
180 {"logf", runtime::kLogV4F32SymbolName, 4},
181 {"llvm.log.f32", runtime::kLogV4F32SymbolName, 4},
182
183 {"logf", runtime::kLogV8F32SymbolName, 8},
184 {"llvm.log.f32", runtime::kLogV8F32SymbolName, 8},
185 };
186 return result;
187 }
188
AddTargetInfoPasses(llvm::legacy::PassManagerBase * passes) const189 void CompilerFunctor::AddTargetInfoPasses(
190 llvm::legacy::PassManagerBase* passes) const {
191 llvm::Triple target_triple(target_machine_->getTargetTriple());
192 auto target_library_info_impl =
193 absl::make_unique<llvm::TargetLibraryInfoImpl>(target_triple);
194 target_library_info_impl->addVectorizableFunctions(
195 VectorFunctionsForTargetLibraryInfoImpl());
196 passes->add(
197 new llvm::TargetLibraryInfoWrapperPass(*target_library_info_impl));
198 passes->add(createTargetTransformInfoWrapperPass(
199 target_machine_->getTargetIRAnalysis()));
200 }
201
AddOptimizationPasses(llvm::legacy::PassManagerBase * module_passes,llvm::legacy::FunctionPassManager * function_passes,unsigned opt_level,unsigned size_level) const202 void CompilerFunctor::AddOptimizationPasses(
203 llvm::legacy::PassManagerBase* module_passes,
204 llvm::legacy::FunctionPassManager* function_passes, unsigned opt_level,
205 unsigned size_level) const {
206 llvm::PassManagerBuilder builder;
207 builder.OptLevel = opt_level;
208 builder.SizeLevel = size_level;
209
210 if (opt_level > 1) {
211 builder.Inliner = llvm::createFunctionInliningPass();
212 } else {
213 // Only inline functions marked with "alwaysinline".
214 builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
215 }
216
217 builder.DisableUnitAtATime = false;
218 builder.DisableUnrollLoops = opt_level == 0;
219 builder.LoopVectorize = opt_level > 0 && size_level == 0;
220 builder.SLPVectorize = opt_level > 1 && size_level == 0;
221
222 builder.populateFunctionPassManager(*function_passes);
223 builder.populateModulePassManager(*module_passes);
224 }
225
226 } // namespace cpu
227 } // namespace xla
228