1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/runtime/execution_engine.h"
17
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "llvm/ExecutionEngine/JITEventListener.h"
24 #include "llvm/ExecutionEngine/ObjectCache.h"
25 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
26 #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
27 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
28 #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
29 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
30 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
31 #include "llvm/IR/IRBuilder.h"
32 #include "llvm/Support/Error.h"
33 #include "llvm/Support/FormatVariadic.h"
34 #include "llvm/Support/MemoryBuffer.h"
35 #include "tensorflow/compiler/xla/runtime/errors.h"
36
37 namespace xla {
38 namespace runtime {
39
40 using llvm::cast;
41
42 using llvm::Expected;
43 using llvm::MemoryBuffer;
44 using llvm::SectionMemoryManager;
45 using llvm::StringRef;
46 using llvm::Triple;
47
48 using llvm::orc::DynamicLibrarySearchGenerator;
49 using llvm::orc::ExecutionSession;
50 using llvm::orc::ExecutorAddr;
51 using llvm::orc::IRCompileLayer;
52 using llvm::orc::JITTargetMachineBuilder;
53 using llvm::orc::RTDyldObjectLinkingLayer;
54 using llvm::orc::SymbolMap;
55 using llvm::orc::ThreadSafeModule;
56 using llvm::orc::TMOwningSimpleCompiler;
57
ExecutionEngine(bool enable_gdb_listener,bool enable_perf_listener)58 ExecutionEngine::ExecutionEngine(bool enable_gdb_listener,
59 bool enable_perf_listener) {
60 if (enable_gdb_listener)
61 gdb_listener_ = llvm::JITEventListener::createGDBRegistrationListener();
62 if (enable_perf_listener)
63 perf_listener_ = llvm::JITEventListener::createPerfJITEventListener();
64 }
65
BindAll(std::vector<SymbolsBinding> bindings)66 /*static*/ ExecutionEngine::SymbolsBinding ExecutionEngine::BindAll(
67 std::vector<SymbolsBinding> bindings) {
68 return [b = std::move(bindings)](llvm::orc::MangleAndInterner mangle) {
69 llvm::orc::SymbolMap symbol_map;
70
71 for (const SymbolsBinding &binding : b) {
72 if (!binding) continue;
73 auto symbols = binding(mangle);
74 symbol_map.insert(symbols.begin(), symbols.end());
75 }
76
77 return symbol_map;
78 };
79 }
80
obj_file() const81 std::unique_ptr<MemoryBuffer> ExecutionEngine::obj_file() const {
82 return obj_file_ ? MemoryBuffer::getMemBuffer(obj_file_->getMemBufferRef())
83 : nullptr;
84 }
85
86 // -------------------------------------------------------------------------- //
87
GetEntrypointName(StringRef name)88 static std::string GetEntrypointName(StringRef name) {
89 return llvm::formatv("__xla__{0}", name);
90 }
91
92 // Converts entrypoint function to an interface function that wraps all the
93 // arguments of the original function into an i8** pointer to provide a function
94 // with trivial ABI.
SetUpEntrypointFunction(llvm::Module & module,StringRef entrypoint)95 static llvm::Error SetUpEntrypointFunction(llvm::Module &module,
96 StringRef entrypoint) {
97 llvm::IRBuilder<> builder(module.getContext());
98
99 // Check that we have an entrypoint function with a valid type.
100 llvm::Function *func = module.getFunction(entrypoint);
101 if (!func)
102 return MakeStringError("entrypoint function not found: ", entrypoint);
103 if (!func->getReturnType()->isVoidTy())
104 return MakeStringError("entrypoint function must return void");
105
106 // Add an XLA interface function for the entrypoint.
107 llvm::FunctionType *xla_runtime_type = llvm::FunctionType::get(
108 builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(),
109 /*isVarArg=*/false);
110
111 llvm::FunctionCallee xla_runtime_func = module.getOrInsertFunction(
112 GetEntrypointName(func->getName()), xla_runtime_type);
113
114 llvm::Function *callee = cast<llvm::Function>(xla_runtime_func.getCallee());
115 llvm::Value *packed_args = callee->arg_begin();
116
117 // Load arguments from the type erased pointer array and cast them to the
118 // original type.
119 llvm::BasicBlock *bb = llvm::BasicBlock::Create(builder.getContext());
120 bb->insertInto(callee);
121 builder.SetInsertPoint(bb);
122
123 llvm::SmallVector<llvm::Value *, 8> args;
124 args.reserve(llvm::size(func->args()));
125
126 for (auto &indexed_arg : llvm::enumerate(func->args())) {
127 llvm::Value *arg_idx = llvm::Constant::getIntegerValue(
128 builder.getInt64Ty(), llvm::APInt(64, indexed_arg.index()));
129 llvm::Value *arg_ptr_ptr =
130 builder.CreateGEP(builder.getInt8PtrTy(), packed_args, arg_idx);
131 llvm::Value *arg_ptr =
132 builder.CreateLoad(builder.getInt8PtrTy(), arg_ptr_ptr);
133 llvm::Type *art_ty = indexed_arg.value().getType();
134 arg_ptr = builder.CreateBitCast(arg_ptr, art_ty->getPointerTo());
135 llvm::Value *arg = builder.CreateLoad(art_ty, arg_ptr);
136 args.push_back(arg);
137 }
138
139 // Call the implementation function with the extracted arguments.
140 builder.CreateCall(func, args);
141 builder.CreateRetVoid();
142
143 return llvm::Error::success();
144 }
145
146 // -------------------------------------------------------------------------- //
147
148 namespace {
149 // Intercept object compilation to save the object file corresponding to the
150 // XLA executable in the execution engine.
151 class ExecutionEngineObjectCache : public llvm::ObjectCache {
152 public:
153 void notifyObjectCompiled(const llvm::Module *m,
154 llvm::MemoryBufferRef objBuffer) override;
155
156 std::unique_ptr<llvm::MemoryBuffer> getObject(const llvm::Module *m) override;
157
158 // Transfer memory buffer from the cache to the caller.
159 std::unique_ptr<llvm::MemoryBuffer> stealObject(const llvm::Module *m);
160
161 private:
162 llvm::DenseMap<const llvm::Module *, std::unique_ptr<llvm::MemoryBuffer>>
163 objs_;
164 };
165 } // namespace
166
notifyObjectCompiled(const llvm::Module * m,llvm::MemoryBufferRef objBuffer)167 void ExecutionEngineObjectCache::notifyObjectCompiled(
168 const llvm::Module *m, llvm::MemoryBufferRef objBuffer) {
169 objs_[m] = llvm::MemoryBuffer::getMemBufferCopy(
170 objBuffer.getBuffer(), objBuffer.getBufferIdentifier());
171 }
172
getObject(const llvm::Module * m)173 std::unique_ptr<llvm::MemoryBuffer> ExecutionEngineObjectCache::getObject(
174 const llvm::Module *m) {
175 auto it = objs_.find(m);
176 if (it == objs_.end()) return nullptr;
177 return llvm::MemoryBuffer::getMemBuffer(it->second->getMemBufferRef());
178 }
179
stealObject(const llvm::Module * m)180 std::unique_ptr<llvm::MemoryBuffer> ExecutionEngineObjectCache::stealObject(
181 const llvm::Module *m) {
182 auto it = objs_.find(m);
183 if (it == objs_.end()) return nullptr;
184 return std::move(it->second);
185 }
186
187 // -------------------------------------------------------------------------- //
188
189 /*static*/ Expected<std::unique_ptr<ExecutionEngine>>
CreateFromModule(std::unique_ptr<llvm::LLVMContext> ctx,std::unique_ptr<llvm::Module> module,StringRef entrypoint,JitOptions options)190 ExecutionEngine::CreateFromModule(std::unique_ptr<llvm::LLVMContext> ctx,
191 std::unique_ptr<llvm::Module> module,
192 StringRef entrypoint, JitOptions options) {
193 auto engine = std::unique_ptr<ExecutionEngine>(new ExecutionEngine(
194 options.enable_gdb_listener, options.enable_perf_listener));
195
196 // We'll need module pointer later to lookup object file in the cache.
197 llvm::Module *module_ptr = module.get();
198
199 // Set up the target machine details.
200 if (!options.target_machine)
201 return MakeStringError("target machine was not provided");
202 module->setDataLayout(options.target_machine->createDataLayout());
203 module->setTargetTriple(options.target_machine->getTargetTriple().str());
204
205 // Run an optimization pipeline over the LLVM module.
206 auto transformer = options.make_optimizing_transformer(
207 options.opt_level, /*sizeLevel=*/0, options.target_machine);
208 if (auto err = transformer(module_ptr))
209 return MakeStringError("failed to run optimization pipeline: ", err);
210
211 // Set up the entry point function compatible with XLA ABI.
212 if (auto err = SetUpEntrypointFunction(*module, entrypoint))
213 return MakeStringError("failed to set up entrypoint ABI: ", err);
214
215 // Callback to create the object layer with a user-provided section memory
216 // mapper and JIT event listeners.
217 auto obj_layer_creator = [&](ExecutionSession &session, const Triple &tt) {
218 auto obj_layer = std::make_unique<RTDyldObjectLinkingLayer>(
219 session, [section_memory_mapper = options.section_memory_mapper]() {
220 return std::make_unique<SectionMemoryManager>(section_memory_mapper);
221 });
222
223 // Register JIT event listeners if they are enabled.
224 if (engine->gdb_listener_)
225 obj_layer->registerJITEventListener(*engine->gdb_listener_);
226 if (engine->perf_listener_)
227 obj_layer->registerJITEventListener(*engine->perf_listener_);
228
229 return obj_layer;
230 };
231
232 // Optionally enable cache for compiled object files.
233 std::unique_ptr<ExecutionEngineObjectCache> obj_cache =
234 options.save_compiled_obj_file
235 ? std::make_unique<ExecutionEngineObjectCache>()
236 : nullptr;
237
238 // Callback to compile IR module on demand.
239 auto compile_function_creator = [&](JITTargetMachineBuilder jtmb)
240 -> Expected<std::unique_ptr<IRCompileLayer::IRCompiler>> {
241 jtmb.setCodeGenOptLevel(options.opt_level);
242 auto tm = jtmb.createTargetMachine();
243 if (!tm) return tm.takeError();
244 return std::make_unique<TMOwningSimpleCompiler>(std::move(*tm),
245 obj_cache.get());
246 };
247
248 // Construct the LLJIT with the given compiler and object linking layers.
249 auto jit = llvm::orc::LLJITBuilder()
250 .setCompileFunctionCreator(compile_function_creator)
251 .setObjectLinkingLayerCreator(obj_layer_creator)
252 .create();
253 if (auto err = jit.takeError())
254 return MakeStringError("failed to construct LLJIT: ", err);
255
256 // Register input module with the LLJIT.
257 ThreadSafeModule tsm(std::move(module), std::move(ctx));
258 if (auto err = (*jit)->addIRModule(std::move(tsm)))
259 return MakeStringError("failed to add source module: ", err);
260
261 llvm::orc::JITDylib &main_jd = (*jit)->getMainJITDylib();
262 llvm::DataLayout data_layout = (*jit)->getDataLayout();
263
264 // Register symbols that are statically linked in the current process.
265 auto generator = DynamicLibrarySearchGenerator::GetForCurrentProcess(
266 data_layout.getGlobalPrefix());
267 if (auto err = generator.takeError())
268 return MakeStringError("failed to construct DyLib search generator");
269 main_jd.addGenerator(std::move(*generator));
270
271 // Register user-provided symbols.
272 if (options.symbols_binding) {
273 auto mangle = llvm::orc::MangleAndInterner(main_jd.getExecutionSession(),
274 data_layout);
275 auto symbols = absoluteSymbols(options.symbols_binding(mangle));
276 if (auto err = main_jd.define(symbols))
277 return MakeStringError("failed to add symbols bindings: ", err);
278 }
279
280 // Trigger compilation by looking up the entrypoint function.
281 Expected<ExecutorAddr> addr = (*jit)->lookup(GetEntrypointName(entrypoint));
282 if (auto err = addr.takeError())
283 return MakeStringError("failed to compile the entrypoint: ", err);
284
285 // Check that we found an address of an entrypoint function.
286 auto ptr = addr->toPtr<EntrypointFunctionPtr>();
287 if (!ptr) return MakeStringError("entrypoint function resolved to null");
288
289 // Check that if we enabled object cache we have an object file for the
290 // compiled module.
291 std::unique_ptr<llvm::MemoryBuffer> obj_file =
292 options.save_compiled_obj_file ? obj_cache->stealObject(module_ptr)
293 : nullptr;
294 if (options.save_compiled_obj_file && !obj_file)
295 return MakeStringError("could not find object file for the XLA module");
296
297 // Fill remaining fields and return constructed ExecutionEngine to the caller.
298 engine->jit_ = std::move(*jit);
299 engine->entrypoint_ptr_ = ptr;
300 engine->obj_file_ = std::move(obj_file);
301 return std::move(engine);
302 }
303
304 /*static*/ Expected<std::unique_ptr<ExecutionEngine>>
CreateFromObjFile(std::unique_ptr<llvm::MemoryBuffer> obj_file,llvm::StringRef entrypoint,AotOptions options)305 ExecutionEngine::CreateFromObjFile(std::unique_ptr<llvm::MemoryBuffer> obj_file,
306 llvm::StringRef entrypoint,
307 AotOptions options) {
308 auto engine = std::unique_ptr<ExecutionEngine>(new ExecutionEngine(
309 options.enable_gdb_listener, options.enable_perf_listener));
310
311 // Callback to create the object layer with a user-provided section memory
312 // mapper and JIT event listeners.
313 auto obj_layer_creator = [&](ExecutionSession &session, const Triple &tt) {
314 auto obj_layer = std::make_unique<RTDyldObjectLinkingLayer>(
315 session, [section_memory_mapper = options.section_memory_mapper]() {
316 return std::make_unique<SectionMemoryManager>(section_memory_mapper);
317 });
318
319 // Register JIT event listeners if they are enabled.
320 if (engine->gdb_listener_)
321 obj_layer->registerJITEventListener(*engine->gdb_listener_);
322 if (engine->perf_listener_)
323 obj_layer->registerJITEventListener(*engine->perf_listener_);
324
325 return obj_layer;
326 };
327
328 // Construct the LLJIT with the given compiler and object linking layers.
329 auto jit = llvm::orc::LLJITBuilder()
330 .setObjectLinkingLayerCreator(obj_layer_creator)
331 .create();
332 if (auto err = jit.takeError())
333 return MakeStringError("failed to construct LLJIT: ", err);
334
335 if (auto err = (*jit)->addObjectFile(std::move(obj_file)))
336 return MakeStringError("failed to add object file: ", err);
337
338 llvm::orc::JITDylib &main_jd = (*jit)->getMainJITDylib();
339 llvm::DataLayout data_layout = (*jit)->getDataLayout();
340
341 // Register symbols that are statically linked in the current process.
342 auto generator = DynamicLibrarySearchGenerator::GetForCurrentProcess(
343 data_layout.getGlobalPrefix());
344 if (auto err = generator.takeError())
345 return MakeStringError("failed to construct DyLib search generator");
346 main_jd.addGenerator(std::move(*generator));
347
348 // Register user-provided symbols.
349 if (options.symbols_binding) {
350 auto mangle = llvm::orc::MangleAndInterner(main_jd.getExecutionSession(),
351 data_layout);
352 auto symbols = absoluteSymbols(options.symbols_binding(mangle));
353 if (auto err = main_jd.define(symbols))
354 return MakeStringError("failed to add symbols bindings: ", err);
355 }
356
357 // Lookup entrypoint in the loaded object file.
358 Expected<ExecutorAddr> addr = (*jit)->lookup(GetEntrypointName(entrypoint));
359 if (auto err = addr.takeError())
360 return MakeStringError("failed to lookup the entrypoint: ", err);
361
362 // Check that we found an address of an entrypoint function.
363 auto ptr = addr->toPtr<EntrypointFunctionPtr>();
364 if (!ptr) return MakeStringError("entrypoint function resolved to null");
365
366 // Fill remaining fields and return constructed ExecutionEngine to the caller.
367 engine->jit_ = std::move(*jit);
368 engine->entrypoint_ptr_ = ptr;
369 return std::move(engine);
370 }
371
372 } // namespace runtime
373 } // namespace xla
374