• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- KaleidoscopeJIT.h - A simple JIT for Kaleidoscope --------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Contains a simple JIT definition for use in the kaleidoscope tutorials.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
15 #define LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
16 
17 #include "RemoteJITUtils.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/Triple.h"
21 #include "llvm/ExecutionEngine/ExecutionEngine.h"
22 #include "llvm/ExecutionEngine/JITSymbol.h"
23 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
24 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
25 #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
26 #include "llvm/ExecutionEngine/Orc/IndirectionUtils.h"
27 #include "llvm/ExecutionEngine/Orc/LambdaResolver.h"
28 #include "llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h"
29 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
30 #include "llvm/IR/DataLayout.h"
31 #include "llvm/IR/LegacyPassManager.h"
32 #include "llvm/IR/Mangler.h"
33 #include "llvm/Support/DynamicLibrary.h"
34 #include "llvm/Support/Error.h"
35 #include "llvm/Support/raw_ostream.h"
36 #include "llvm/Target/TargetMachine.h"
37 #include "llvm/Transforms/InstCombine/InstCombine.h"
38 #include "llvm/Transforms/Scalar.h"
39 #include "llvm/Transforms/Scalar/GVN.h"
40 #include <algorithm>
41 #include <cassert>
42 #include <cstdlib>
43 #include <map>
44 #include <memory>
45 #include <string>
46 #include <vector>
47 
48 class PrototypeAST;
49 class ExprAST;
50 
51 /// FunctionAST - This class represents a function definition itself.
52 class FunctionAST {
53   std::unique_ptr<PrototypeAST> Proto;
54   std::unique_ptr<ExprAST> Body;
55 
56 public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,std::unique_ptr<ExprAST> Body)57   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
58               std::unique_ptr<ExprAST> Body)
59       : Proto(std::move(Proto)), Body(std::move(Body)) {}
60 
61   const PrototypeAST& getProto() const;
62   const std::string& getName() const;
63   llvm::Function *codegen();
64 };
65 
66 /// This will compile FnAST to IR, rename the function to add the given
67 /// suffix (needed to prevent a name-clash with the function's stub),
68 /// and then take ownership of the module that the function was compiled
69 /// into.
70 std::unique_ptr<llvm::Module>
71 irgenAndTakeOwnership(FunctionAST &FnAST, const std::string &Suffix);
72 
73 namespace llvm {
74 namespace orc {
75 
76 // Typedef the remote-client API.
77 using MyRemote = remote::OrcRemoteTargetClient;
78 
79 class KaleidoscopeJIT {
80 private:
81   ExecutionSession &ES;
82   std::shared_ptr<SymbolResolver> Resolver;
83   std::unique_ptr<TargetMachine> TM;
84   const DataLayout DL;
85   RTDyldObjectLinkingLayer ObjectLayer;
86   IRCompileLayer<decltype(ObjectLayer), SimpleCompiler> CompileLayer;
87 
88   using OptimizeFunction =
89       std::function<std::unique_ptr<Module>(std::unique_ptr<Module>)>;
90 
91   IRTransformLayer<decltype(CompileLayer), OptimizeFunction> OptimizeLayer;
92 
93   JITCompileCallbackManager *CompileCallbackMgr;
94   std::unique_ptr<IndirectStubsManager> IndirectStubsMgr;
95   MyRemote &Remote;
96 
97 public:
KaleidoscopeJIT(ExecutionSession & ES,MyRemote & Remote)98   KaleidoscopeJIT(ExecutionSession &ES, MyRemote &Remote)
99       : ES(ES),
100         Resolver(createLegacyLookupResolver(
101             ES,
102             [this](const std::string &Name) -> JITSymbol {
103               if (auto Sym = IndirectStubsMgr->findStub(Name, false))
104                 return Sym;
105               if (auto Sym = OptimizeLayer.findSymbol(Name, false))
106                 return Sym;
107               else if (auto Err = Sym.takeError())
108                 return std::move(Err);
109               if (auto Addr = cantFail(this->Remote.getSymbolAddress(Name)))
110                 return JITSymbol(Addr, JITSymbolFlags::Exported);
111               return nullptr;
112             },
113             [](Error Err) { cantFail(std::move(Err), "lookupFlags failed"); })),
114         TM(EngineBuilder().selectTarget(Triple(Remote.getTargetTriple()), "",
115                                         "", SmallVector<std::string, 0>())),
116         DL(TM->createDataLayout()),
117         ObjectLayer(ES,
118                     [this](VModuleKey K) {
119                       return RTDyldObjectLinkingLayer::Resources{
120                           cantFail(this->Remote.createRemoteMemoryManager()),
121                           Resolver};
122                     }),
123         CompileLayer(ObjectLayer, SimpleCompiler(*TM)),
124         OptimizeLayer(CompileLayer,
125                       [this](std::unique_ptr<Module> M) {
126                         return optimizeModule(std::move(M));
127                       }),
128         Remote(Remote) {
129     auto CCMgrOrErr = Remote.enableCompileCallbacks(0);
130     if (!CCMgrOrErr) {
131       logAllUnhandledErrors(CCMgrOrErr.takeError(), errs(),
132                             "Error enabling remote compile callbacks:");
133       exit(1);
134     }
135     CompileCallbackMgr = &*CCMgrOrErr;
136     IndirectStubsMgr = cantFail(Remote.createIndirectStubsManager());
137     llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
138   }
139 
getTargetMachine()140   TargetMachine &getTargetMachine() { return *TM; }
141 
addModule(std::unique_ptr<Module> M)142   VModuleKey addModule(std::unique_ptr<Module> M) {
143     // Add the module with a new VModuleKey.
144     auto K = ES.allocateVModule();
145     cantFail(OptimizeLayer.addModule(K, std::move(M)));
146     return K;
147   }
148 
addFunctionAST(std::unique_ptr<FunctionAST> FnAST)149   Error addFunctionAST(std::unique_ptr<FunctionAST> FnAST) {
150     // Move ownership of FnAST to a shared pointer - C++11 lambdas don't support
151     // capture-by-move, which is be required for unique_ptr.
152     auto SharedFnAST = std::shared_ptr<FunctionAST>(std::move(FnAST));
153 
154     // Set the action to compile our AST. This lambda will be run if/when
155     // execution hits the compile callback (via the stub).
156     //
157     // The steps to compile are:
158     // (1) IRGen the function.
159     // (2) Add the IR module to the JIT to make it executable like any other
160     //     module.
161     // (3) Use findSymbol to get the address of the compiled function.
162     // (4) Update the stub pointer to point at the implementation so that
163     ///    subsequent calls go directly to it and bypass the compiler.
164     // (5) Return the address of the implementation: this lambda will actually
165     //     be run inside an attempted call to the function, and we need to
166     //     continue on to the implementation to complete the attempted call.
167     //     The JIT runtime (the resolver block) will use the return address of
168     //     this function as the address to continue at once it has reset the
169     //     CPU state to what it was immediately before the call.
170     auto CompileAction = [this, SharedFnAST]() {
171       auto M = irgenAndTakeOwnership(*SharedFnAST, "$impl");
172       addModule(std::move(M));
173       auto Sym = findSymbol(SharedFnAST->getName() + "$impl");
174       assert(Sym && "Couldn't find compiled function?");
175       JITTargetAddress SymAddr = cantFail(Sym.getAddress());
176       if (auto Err = IndirectStubsMgr->updatePointer(
177               mangle(SharedFnAST->getName()), SymAddr)) {
178         logAllUnhandledErrors(std::move(Err), errs(),
179                               "Error updating function pointer: ");
180         exit(1);
181       }
182 
183       return SymAddr;
184     };
185 
186     // Create a CompileCallback suing the CompileAction - this is the re-entry
187     // point into the compiler for functions that haven't been compiled yet.
188     auto CCAddr = cantFail(
189         CompileCallbackMgr->getCompileCallback(std::move(CompileAction)));
190 
191     // Create an indirect stub. This serves as the functions "canonical
192     // definition" - an unchanging (constant address) entry point to the
193     // function implementation.
194     // Initially we point the stub's function-pointer at the compile callback
195     // that we just created. In the compile action for the callback we will
196     // update the stub's function pointer to point at the function
197     // implementation that we just implemented.
198     if (auto Err = IndirectStubsMgr->createStub(
199             mangle(SharedFnAST->getName()), CCAddr, JITSymbolFlags::Exported))
200       return Err;
201 
202     return Error::success();
203   }
204 
executeRemoteExpr(JITTargetAddress ExprAddr)205   Error executeRemoteExpr(JITTargetAddress ExprAddr) {
206     return Remote.callVoidVoid(ExprAddr);
207   }
208 
findSymbol(const std::string Name)209   JITSymbol findSymbol(const std::string Name) {
210     return OptimizeLayer.findSymbol(mangle(Name), true);
211   }
212 
removeModule(VModuleKey K)213   void removeModule(VModuleKey K) {
214     cantFail(OptimizeLayer.removeModule(K));
215   }
216 
217 private:
mangle(const std::string & Name)218   std::string mangle(const std::string &Name) {
219     std::string MangledName;
220     raw_string_ostream MangledNameStream(MangledName);
221     Mangler::getNameWithPrefix(MangledNameStream, Name, DL);
222     return MangledNameStream.str();
223   }
224 
optimizeModule(std::unique_ptr<Module> M)225   std::unique_ptr<Module> optimizeModule(std::unique_ptr<Module> M) {
226     // Create a function pass manager.
227     auto FPM = llvm::make_unique<legacy::FunctionPassManager>(M.get());
228 
229     // Add some optimizations.
230     FPM->add(createInstructionCombiningPass());
231     FPM->add(createReassociatePass());
232     FPM->add(createGVNPass());
233     FPM->add(createCFGSimplificationPass());
234     FPM->doInitialization();
235 
236     // Run the optimizations over all functions in the module being added to
237     // the JIT.
238     for (auto &F : *M)
239       FPM->run(F);
240 
241     return M;
242   }
243 };
244 
245 } // end namespace orc
246 } // end namespace llvm
247 
248 #endif // LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
249