• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/MCContext.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/SmallVectorMemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
46 
47 namespace xla {
48 namespace cpu {
49 
/* Create filtered versions of the LLVM pass managers to filter out some of
the expensive passes.
Profiling of:
   learning/brain/google/xla/benchmarks:inception_cpu_benchmark
   learning/brain/google/xla/benchmarks:cifarnet
pointed to LICM and IndVarSimplify as the hottest passes.
LICM is known to exhibit O(n^2) time in the number of instructions.
IndVarSimplify is slow due to SCEV; if loops are emitted in canonical form,
this pass is not necessary.
These passes were disabled as a starting point.
NOTE: FilteredPassManager below currently filters out only the loop-unroll
pass (any pass whose name contains "Unroll loops"), and only when
disable_expensive_passes is set; LICM and IndVarSimplify are not filtered.
*/
// TODO(b/64227304): Creating a custom pass pipeline will replace this.
62 
63 namespace {
64 class FilteredPassManager : public llvm::legacy::PassManager {
65  public:
FilteredPassManager(bool disable_expensive_passes)66   explicit FilteredPassManager(bool disable_expensive_passes)
67       : disable_expensive_passes_(disable_expensive_passes) {}
add(llvm::Pass * p)68   void add(llvm::Pass* p) override {
69     bool pass_disabled =
70         disable_expensive_passes_ && p->getPassName().contains("Unroll loops");
71     if (!pass_disabled) {
72       llvm::legacy::PassManager::add(p);
73     } else {
74       delete p;
75     }
76   }
77 
78  private:
79   bool disable_expensive_passes_;
80 };
81 }  // anonymous namespace
82 
operator ()(llvm::Module & module)83 llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> CompilerFunctor::operator()(
84     llvm::Module& module) {
85   FilteredPassManager module_passes(disable_expensive_passes_);
86   llvm::legacy::FunctionPassManager function_passes(&module);
87 
88   VLOG(2) << "IR before optimizations";
89   XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module));
90 
91   if (pre_optimization_hook_) {
92     pre_optimization_hook_(module);
93   }
94 
95   // Add the appropriate TargetLibraryInfo and TargetTransformInfo.
96   AddTargetInfoPasses(&module_passes);
97 
98   // Build up optimization pipeline.
99   if (optimize_for_size_) {
100     // Optimizing for size turns on -O2 level optimizations.
101     //
102     // TODO(b/64153864): Although the code generator supports size_level = 2 to
103     // turn on more aggressive code size optimizations than size_level = 1, we
104     // pass size_level = 1 because in many cases a size_level of 2 does
105     // worse. Investigate why.
106     AddOptimizationPasses(&module_passes, &function_passes, /*opt_level=*/2,
107                           /*size_level=*/1);
108   } else {
109     AddOptimizationPasses(&module_passes, &function_passes,
110                           /*opt_level=*/opt_level_, /*size_level=*/0);
111   }
112 
113   // Run optimization passes on module.
114   function_passes.doInitialization();
115 
116   CHECK(!llvm::verifyModule(module, &llvm::dbgs()));
117 
118   for (auto func = module.begin(); func != module.end(); ++func) {
119     function_passes.run(*func);
120   }
121   function_passes.doFinalization();
122   module_passes.run(module);
123 
124   CHECK(!llvm::verifyModule(module, &llvm::dbgs()));
125 
126   runtime::RewriteIRRuntimeFunctions(&module, fast_math_flags_);
127 
128   // Buffer for holding machine code prior to constructing the ObjectFile.
129   llvm::SmallVector<char, 0> stream_buffer;
130   llvm::raw_svector_ostream ostream(stream_buffer);
131 
132   VLOG(2) << "IR after optimizations";
133   XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module));
134 
135   if (post_optimization_hook_) {
136     post_optimization_hook_(module);
137   }
138 
139   // Generate code.
140   llvm::MCContext* mc_context;
141   llvm::legacy::PassManager codegen_passes;
142   target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream);
143   codegen_passes.run(module);
144 
145   std::unique_ptr<llvm::MemoryBuffer> memory_buffer(
146       new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer)));
147 
148   if (post_codegen_hook_) {
149     llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> obj_file =
150         llvm::object::ObjectFile::createObjectFile(*memory_buffer);
151     if (obj_file) {
152       post_codegen_hook_(*obj_file.get());
153     } else {
154       LOG(WARNING) << "Could convert memory buffer to object file!";
155     }
156   }
157 
158   return std::move(memory_buffer);
159 }
160 
VectorFunctionsForTargetLibraryInfoImpl()161 static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
162   std::vector<llvm::VecDesc> result = {
163       {"tanhf", runtime::kTanhV4F32SymbolName, llvm::ElementCount::getFixed(4)},
164       {"llvm.tanh.f32", runtime::kTanhV4F32SymbolName,
165        llvm::ElementCount::getFixed(4)},
166 
167       {"tanhf", runtime::kTanhV8F32SymbolName, llvm::ElementCount::getFixed(8)},
168       {"llvm.tanh.f32", runtime::kTanhV8F32SymbolName,
169        llvm::ElementCount::getFixed(8)},
170 
171       {"tanhf", runtime::kTanhV16F32SymbolName,
172        llvm::ElementCount::getFixed(16)},
173       {"llvm.tanh.f32", runtime::kTanhV16F32SymbolName,
174        llvm::ElementCount::getFixed(16)},
175 
176       {"expf", runtime::kExpV4F32SymbolName, llvm::ElementCount::getFixed(4)},
177       {"llvm.exp.f32", runtime::kExpV4F32SymbolName,
178        llvm::ElementCount::getFixed(4)},
179 
180       {"expf", runtime::kExpV8F32SymbolName, llvm::ElementCount::getFixed(8)},
181       {"llvm.exp.f32", runtime::kExpV8F32SymbolName,
182        llvm::ElementCount::getFixed(8)},
183 
184       {"expf", runtime::kExpV16F32SymbolName, llvm::ElementCount::getFixed(16)},
185       {"llvm.exp.f32", runtime::kExpV16F32SymbolName,
186        llvm::ElementCount::getFixed(16)},
187 
188       {"logf", runtime::kLogV4F32SymbolName, llvm::ElementCount::getFixed(4)},
189       {"llvm.log.f32", runtime::kLogV4F32SymbolName,
190        llvm::ElementCount::getFixed(4)},
191 
192       {"logf", runtime::kLogV8F32SymbolName, llvm::ElementCount::getFixed(8)},
193       {"llvm.log.f32", runtime::kLogV8F32SymbolName,
194        llvm::ElementCount::getFixed(8)},
195 
196       {"logf", runtime::kLogV16F32SymbolName, llvm::ElementCount::getFixed(16)},
197       {"llvm.log.f32", runtime::kLogV16F32SymbolName,
198        llvm::ElementCount::getFixed(16)},
199   };
200   return result;
201 }
202 
AddTargetInfoPasses(llvm::legacy::PassManagerBase * passes) const203 void CompilerFunctor::AddTargetInfoPasses(
204     llvm::legacy::PassManagerBase* passes) const {
205   llvm::Triple target_triple(target_machine_->getTargetTriple());
206   auto target_library_info_impl =
207       absl::make_unique<llvm::TargetLibraryInfoImpl>(target_triple);
208   target_library_info_impl->addVectorizableFunctions(
209       VectorFunctionsForTargetLibraryInfoImpl());
210 
211   passes->add(
212       new llvm::TargetLibraryInfoWrapperPass(*target_library_info_impl));
213   passes->add(createTargetTransformInfoWrapperPass(
214       target_machine_->getTargetIRAnalysis()));
215 }
216 
AddOptimizationPasses(llvm::legacy::PassManagerBase * module_passes,llvm::legacy::FunctionPassManager * function_passes,unsigned opt_level,unsigned size_level) const217 void CompilerFunctor::AddOptimizationPasses(
218     llvm::legacy::PassManagerBase* module_passes,
219     llvm::legacy::FunctionPassManager* function_passes, unsigned opt_level,
220     unsigned size_level) const {
221   llvm::PassManagerBuilder builder;
222   builder.OptLevel = opt_level;
223   builder.SizeLevel = size_level;
224 
225   if (opt_level > 1) {
226     builder.Inliner = llvm::createFunctionInliningPass();
227   } else {
228     // Only inline functions marked with "alwaysinline".
229     builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
230   }
231 
232   builder.DisableUnrollLoops = opt_level == 0;
233   builder.LoopVectorize = opt_level > 0 && size_level == 0;
234   builder.SLPVectorize = opt_level > 1 && size_level == 0;
235 
236   builder.populateFunctionPassManager(*function_passes);
237   builder.populateModulePassManager(*module_passes);
238 }
239 
240 }  // namespace cpu
241 }  // namespace xla
242