• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/runtime/execution_engine.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "llvm/ExecutionEngine/JITEventListener.h"
24 #include "llvm/ExecutionEngine/ObjectCache.h"
25 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
26 #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
27 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
28 #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
29 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
30 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
31 #include "llvm/IR/IRBuilder.h"
32 #include "llvm/Support/Error.h"
33 #include "llvm/Support/FormatVariadic.h"
34 #include "llvm/Support/MemoryBuffer.h"
35 #include "tensorflow/compiler/xla/runtime/errors.h"
36 
37 namespace xla {
38 namespace runtime {
39 
40 using llvm::cast;
41 
42 using llvm::Expected;
43 using llvm::MemoryBuffer;
44 using llvm::SectionMemoryManager;
45 using llvm::StringRef;
46 using llvm::Triple;
47 
48 using llvm::orc::DynamicLibrarySearchGenerator;
49 using llvm::orc::ExecutionSession;
50 using llvm::orc::ExecutorAddr;
51 using llvm::orc::IRCompileLayer;
52 using llvm::orc::JITTargetMachineBuilder;
53 using llvm::orc::RTDyldObjectLinkingLayer;
54 using llvm::orc::SymbolMap;
55 using llvm::orc::ThreadSafeModule;
56 using llvm::orc::TMOwningSimpleCompiler;
57 
ExecutionEngine(bool enable_gdb_listener,bool enable_perf_listener)58 ExecutionEngine::ExecutionEngine(bool enable_gdb_listener,
59                                  bool enable_perf_listener) {
60   if (enable_gdb_listener)
61     gdb_listener_ = llvm::JITEventListener::createGDBRegistrationListener();
62   if (enable_perf_listener)
63     perf_listener_ = llvm::JITEventListener::createPerfJITEventListener();
64 }
65 
BindAll(std::vector<SymbolsBinding> bindings)66 /*static*/ ExecutionEngine::SymbolsBinding ExecutionEngine::BindAll(
67     std::vector<SymbolsBinding> bindings) {
68   return [b = std::move(bindings)](llvm::orc::MangleAndInterner mangle) {
69     llvm::orc::SymbolMap symbol_map;
70 
71     for (const SymbolsBinding &binding : b) {
72       if (!binding) continue;
73       auto symbols = binding(mangle);
74       symbol_map.insert(symbols.begin(), symbols.end());
75     }
76 
77     return symbol_map;
78   };
79 }
80 
obj_file() const81 std::unique_ptr<MemoryBuffer> ExecutionEngine::obj_file() const {
82   return obj_file_ ? MemoryBuffer::getMemBuffer(obj_file_->getMemBufferRef())
83                    : nullptr;
84 }
85 
86 // -------------------------------------------------------------------------- //
87 
GetEntrypointName(StringRef name)88 static std::string GetEntrypointName(StringRef name) {
89   return llvm::formatv("__xla__{0}", name);
90 }
91 
92 // Converts entrypoint function to an interface function that wraps all the
93 // arguments of the original function into an i8** pointer to provide a function
94 // with trivial ABI.
SetUpEntrypointFunction(llvm::Module & module,StringRef entrypoint)95 static llvm::Error SetUpEntrypointFunction(llvm::Module &module,
96                                            StringRef entrypoint) {
97   llvm::IRBuilder<> builder(module.getContext());
98 
99   // Check that we have an entrypoint function with a valid type.
100   llvm::Function *func = module.getFunction(entrypoint);
101   if (!func)
102     return MakeStringError("entrypoint function not found: ", entrypoint);
103   if (!func->getReturnType()->isVoidTy())
104     return MakeStringError("entrypoint function must return void");
105 
106   // Add an XLA interface function for the entrypoint.
107   llvm::FunctionType *xla_runtime_type = llvm::FunctionType::get(
108       builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(),
109       /*isVarArg=*/false);
110 
111   llvm::FunctionCallee xla_runtime_func = module.getOrInsertFunction(
112       GetEntrypointName(func->getName()), xla_runtime_type);
113 
114   llvm::Function *callee = cast<llvm::Function>(xla_runtime_func.getCallee());
115   llvm::Value *packed_args = callee->arg_begin();
116 
117   // Load arguments from the type erased pointer array and cast them to the
118   // original type.
119   llvm::BasicBlock *bb = llvm::BasicBlock::Create(builder.getContext());
120   bb->insertInto(callee);
121   builder.SetInsertPoint(bb);
122 
123   llvm::SmallVector<llvm::Value *, 8> args;
124   args.reserve(llvm::size(func->args()));
125 
126   for (auto &indexed_arg : llvm::enumerate(func->args())) {
127     llvm::Value *arg_idx = llvm::Constant::getIntegerValue(
128         builder.getInt64Ty(), llvm::APInt(64, indexed_arg.index()));
129     llvm::Value *arg_ptr_ptr =
130         builder.CreateGEP(builder.getInt8PtrTy(), packed_args, arg_idx);
131     llvm::Value *arg_ptr =
132         builder.CreateLoad(builder.getInt8PtrTy(), arg_ptr_ptr);
133     llvm::Type *art_ty = indexed_arg.value().getType();
134     arg_ptr = builder.CreateBitCast(arg_ptr, art_ty->getPointerTo());
135     llvm::Value *arg = builder.CreateLoad(art_ty, arg_ptr);
136     args.push_back(arg);
137   }
138 
139   // Call the implementation function with the extracted arguments.
140   builder.CreateCall(func, args);
141   builder.CreateRetVoid();
142 
143   return llvm::Error::success();
144 }
145 
146 // -------------------------------------------------------------------------- //
147 
148 namespace {
149 // Intercept object compilation to save the object file corresponding to the
150 // XLA executable in the execution engine.
151 class ExecutionEngineObjectCache : public llvm::ObjectCache {
152  public:
153   void notifyObjectCompiled(const llvm::Module *m,
154                             llvm::MemoryBufferRef objBuffer) override;
155 
156   std::unique_ptr<llvm::MemoryBuffer> getObject(const llvm::Module *m) override;
157 
158   // Transfer memory buffer from the cache to the caller.
159   std::unique_ptr<llvm::MemoryBuffer> stealObject(const llvm::Module *m);
160 
161  private:
162   llvm::DenseMap<const llvm::Module *, std::unique_ptr<llvm::MemoryBuffer>>
163       objs_;
164 };
165 }  // namespace
166 
notifyObjectCompiled(const llvm::Module * m,llvm::MemoryBufferRef objBuffer)167 void ExecutionEngineObjectCache::notifyObjectCompiled(
168     const llvm::Module *m, llvm::MemoryBufferRef objBuffer) {
169   objs_[m] = llvm::MemoryBuffer::getMemBufferCopy(
170       objBuffer.getBuffer(), objBuffer.getBufferIdentifier());
171 }
172 
getObject(const llvm::Module * m)173 std::unique_ptr<llvm::MemoryBuffer> ExecutionEngineObjectCache::getObject(
174     const llvm::Module *m) {
175   auto it = objs_.find(m);
176   if (it == objs_.end()) return nullptr;
177   return llvm::MemoryBuffer::getMemBuffer(it->second->getMemBufferRef());
178 }
179 
stealObject(const llvm::Module * m)180 std::unique_ptr<llvm::MemoryBuffer> ExecutionEngineObjectCache::stealObject(
181     const llvm::Module *m) {
182   auto it = objs_.find(m);
183   if (it == objs_.end()) return nullptr;
184   return std::move(it->second);
185 }
186 
187 // -------------------------------------------------------------------------- //
188 
189 /*static*/ Expected<std::unique_ptr<ExecutionEngine>>
CreateFromModule(std::unique_ptr<llvm::LLVMContext> ctx,std::unique_ptr<llvm::Module> module,StringRef entrypoint,JitOptions options)190 ExecutionEngine::CreateFromModule(std::unique_ptr<llvm::LLVMContext> ctx,
191                                   std::unique_ptr<llvm::Module> module,
192                                   StringRef entrypoint, JitOptions options) {
193   auto engine = std::unique_ptr<ExecutionEngine>(new ExecutionEngine(
194       options.enable_gdb_listener, options.enable_perf_listener));
195 
196   // We'll need module pointer later to lookup object file in the cache.
197   llvm::Module *module_ptr = module.get();
198 
199   // Set up the target machine details.
200   if (!options.target_machine)
201     return MakeStringError("target machine was not provided");
202   module->setDataLayout(options.target_machine->createDataLayout());
203   module->setTargetTriple(options.target_machine->getTargetTriple().str());
204 
205   // Run an optimization pipeline over the LLVM module.
206   auto transformer = options.make_optimizing_transformer(
207       options.opt_level, /*sizeLevel=*/0, options.target_machine);
208   if (auto err = transformer(module_ptr))
209     return MakeStringError("failed to run optimization pipeline: ", err);
210 
211   // Set up the entry point function compatible with XLA ABI.
212   if (auto err = SetUpEntrypointFunction(*module, entrypoint))
213     return MakeStringError("failed to set up entrypoint ABI: ", err);
214 
215   // Callback to create the object layer with a user-provided section memory
216   // mapper and JIT event listeners.
217   auto obj_layer_creator = [&](ExecutionSession &session, const Triple &tt) {
218     auto obj_layer = std::make_unique<RTDyldObjectLinkingLayer>(
219         session, [section_memory_mapper = options.section_memory_mapper]() {
220           return std::make_unique<SectionMemoryManager>(section_memory_mapper);
221         });
222 
223     // Register JIT event listeners if they are enabled.
224     if (engine->gdb_listener_)
225       obj_layer->registerJITEventListener(*engine->gdb_listener_);
226     if (engine->perf_listener_)
227       obj_layer->registerJITEventListener(*engine->perf_listener_);
228 
229     return obj_layer;
230   };
231 
232   // Optionally enable cache for compiled object files.
233   std::unique_ptr<ExecutionEngineObjectCache> obj_cache =
234       options.save_compiled_obj_file
235           ? std::make_unique<ExecutionEngineObjectCache>()
236           : nullptr;
237 
238   // Callback to compile IR module on demand.
239   auto compile_function_creator = [&](JITTargetMachineBuilder jtmb)
240       -> Expected<std::unique_ptr<IRCompileLayer::IRCompiler>> {
241     jtmb.setCodeGenOptLevel(options.opt_level);
242     auto tm = jtmb.createTargetMachine();
243     if (!tm) return tm.takeError();
244     return std::make_unique<TMOwningSimpleCompiler>(std::move(*tm),
245                                                     obj_cache.get());
246   };
247 
248   // Construct the LLJIT with the given compiler and object linking layers.
249   auto jit = llvm::orc::LLJITBuilder()
250                  .setCompileFunctionCreator(compile_function_creator)
251                  .setObjectLinkingLayerCreator(obj_layer_creator)
252                  .create();
253   if (auto err = jit.takeError())
254     return MakeStringError("failed to construct LLJIT: ", err);
255 
256   // Register input module with the LLJIT.
257   ThreadSafeModule tsm(std::move(module), std::move(ctx));
258   if (auto err = (*jit)->addIRModule(std::move(tsm)))
259     return MakeStringError("failed to add source module: ", err);
260 
261   llvm::orc::JITDylib &main_jd = (*jit)->getMainJITDylib();
262   llvm::DataLayout data_layout = (*jit)->getDataLayout();
263 
264   // Register symbols that are statically linked in the current process.
265   auto generator = DynamicLibrarySearchGenerator::GetForCurrentProcess(
266       data_layout.getGlobalPrefix());
267   if (auto err = generator.takeError())
268     return MakeStringError("failed to construct DyLib search generator");
269   main_jd.addGenerator(std::move(*generator));
270 
271   // Register user-provided symbols.
272   if (options.symbols_binding) {
273     auto mangle = llvm::orc::MangleAndInterner(main_jd.getExecutionSession(),
274                                                data_layout);
275     auto symbols = absoluteSymbols(options.symbols_binding(mangle));
276     if (auto err = main_jd.define(symbols))
277       return MakeStringError("failed to add symbols bindings: ", err);
278   }
279 
280   // Trigger compilation by looking up the entrypoint function.
281   Expected<ExecutorAddr> addr = (*jit)->lookup(GetEntrypointName(entrypoint));
282   if (auto err = addr.takeError())
283     return MakeStringError("failed to compile the entrypoint: ", err);
284 
285   // Check that we found an address of an entrypoint function.
286   auto ptr = addr->toPtr<EntrypointFunctionPtr>();
287   if (!ptr) return MakeStringError("entrypoint function resolved to null");
288 
289   // Check that if we enabled object cache we have an object file for the
290   // compiled module.
291   std::unique_ptr<llvm::MemoryBuffer> obj_file =
292       options.save_compiled_obj_file ? obj_cache->stealObject(module_ptr)
293                                      : nullptr;
294   if (options.save_compiled_obj_file && !obj_file)
295     return MakeStringError("could not find object file for the XLA module");
296 
297   // Fill remaining fields and return constructed ExecutionEngine to the caller.
298   engine->jit_ = std::move(*jit);
299   engine->entrypoint_ptr_ = ptr;
300   engine->obj_file_ = std::move(obj_file);
301   return std::move(engine);
302 }
303 
304 /*static*/ Expected<std::unique_ptr<ExecutionEngine>>
CreateFromObjFile(std::unique_ptr<llvm::MemoryBuffer> obj_file,llvm::StringRef entrypoint,AotOptions options)305 ExecutionEngine::CreateFromObjFile(std::unique_ptr<llvm::MemoryBuffer> obj_file,
306                                    llvm::StringRef entrypoint,
307                                    AotOptions options) {
308   auto engine = std::unique_ptr<ExecutionEngine>(new ExecutionEngine(
309       options.enable_gdb_listener, options.enable_perf_listener));
310 
311   // Callback to create the object layer with a user-provided section memory
312   // mapper and JIT event listeners.
313   auto obj_layer_creator = [&](ExecutionSession &session, const Triple &tt) {
314     auto obj_layer = std::make_unique<RTDyldObjectLinkingLayer>(
315         session, [section_memory_mapper = options.section_memory_mapper]() {
316           return std::make_unique<SectionMemoryManager>(section_memory_mapper);
317         });
318 
319     // Register JIT event listeners if they are enabled.
320     if (engine->gdb_listener_)
321       obj_layer->registerJITEventListener(*engine->gdb_listener_);
322     if (engine->perf_listener_)
323       obj_layer->registerJITEventListener(*engine->perf_listener_);
324 
325     return obj_layer;
326   };
327 
328   // Construct the LLJIT with the given compiler and object linking layers.
329   auto jit = llvm::orc::LLJITBuilder()
330                  .setObjectLinkingLayerCreator(obj_layer_creator)
331                  .create();
332   if (auto err = jit.takeError())
333     return MakeStringError("failed to construct LLJIT: ", err);
334 
335   if (auto err = (*jit)->addObjectFile(std::move(obj_file)))
336     return MakeStringError("failed to add object file: ", err);
337 
338   llvm::orc::JITDylib &main_jd = (*jit)->getMainJITDylib();
339   llvm::DataLayout data_layout = (*jit)->getDataLayout();
340 
341   // Register symbols that are statically linked in the current process.
342   auto generator = DynamicLibrarySearchGenerator::GetForCurrentProcess(
343       data_layout.getGlobalPrefix());
344   if (auto err = generator.takeError())
345     return MakeStringError("failed to construct DyLib search generator");
346   main_jd.addGenerator(std::move(*generator));
347 
348   // Register user-provided symbols.
349   if (options.symbols_binding) {
350     auto mangle = llvm::orc::MangleAndInterner(main_jd.getExecutionSession(),
351                                                data_layout);
352     auto symbols = absoluteSymbols(options.symbols_binding(mangle));
353     if (auto err = main_jd.define(symbols))
354       return MakeStringError("failed to add symbols bindings: ", err);
355   }
356 
357   // Lookup entrypoint in the loaded object file.
358   Expected<ExecutorAddr> addr = (*jit)->lookup(GetEntrypointName(entrypoint));
359   if (auto err = addr.takeError())
360     return MakeStringError("failed to lookup the entrypoint: ", err);
361 
362   // Check that we found an address of an entrypoint function.
363   auto ptr = addr->toPtr<EntrypointFunctionPtr>();
364   if (!ptr) return MakeStringError("entrypoint function resolved to null");
365 
366   // Fill remaining fields and return constructed ExecutionEngine to the caller.
367   engine->jit_ = std::move(*jit);
368   engine->entrypoint_ptr_ = ptr;
369   return std::move(engine);
370 }
371 
372 }  // namespace runtime
373 }  // namespace xla
374