1 //===------ OrcTestCommon.h - Utilities for Orc Unit Tests ------*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Common utilities for the Orc unit tests. 11 // 12 //===----------------------------------------------------------------------===// 13 14 15 #ifndef LLVM_UNITTESTS_EXECUTIONENGINE_ORC_ORCTESTCOMMON_H 16 #define LLVM_UNITTESTS_EXECUTIONENGINE_ORC_ORCTESTCOMMON_H 17 18 #include "llvm/ExecutionEngine/ExecutionEngine.h" 19 #include "llvm/ExecutionEngine/JITSymbol.h" 20 #include "llvm/ExecutionEngine/Orc/IndirectionUtils.h" 21 #include "llvm/IR/Function.h" 22 #include "llvm/IR/IRBuilder.h" 23 #include "llvm/IR/LLVMContext.h" 24 #include "llvm/IR/Module.h" 25 #include "llvm/IR/TypeBuilder.h" 26 #include "llvm/Object/ObjectFile.h" 27 #include "llvm/Support/TargetRegistry.h" 28 #include "llvm/Support/TargetSelect.h" 29 #include "gtest/gtest.h" 30 31 #include <memory> 32 33 namespace llvm { 34 35 namespace orc { 36 // CoreAPIsStandardTest that saves a bunch of boilerplate by providing the 37 // following: 38 // 39 // (1) ES -- An ExecutionSession 40 // (2) Foo, Bar, Baz, Qux -- SymbolStringPtrs for strings "foo", "bar", "baz", 41 // and "qux" respectively. 42 // (3) FooAddr, BarAddr, BazAddr, QuxAddr -- Dummy addresses. Guaranteed 43 // distinct and non-null. 44 // (4) FooSym, BarSym, BazSym, QuxSym -- JITEvaluatedSymbols with FooAddr, 45 // BarAddr, BazAddr, and QuxAddr respectively. All with default strong, 46 // linkage and non-hidden visibility. 47 // (5) V -- A VSO associated with ES. 48 class CoreAPIsBasedStandardTest : public testing::Test { 49 public: 50 protected: 51 ExecutionSession ES; 52 VSO &V = ES.createVSO("V"); 53 SymbolStringPtr Foo = ES.getSymbolStringPool().intern("foo"); 54 SymbolStringPtr Bar = ES.getSymbolStringPool().intern("bar"); 55 SymbolStringPtr Baz = ES.getSymbolStringPool().intern("baz"); 56 SymbolStringPtr Qux = ES.getSymbolStringPool().intern("qux"); 57 static const JITTargetAddress FooAddr = 1U; 58 static const JITTargetAddress BarAddr = 2U; 59 static const JITTargetAddress BazAddr = 3U; 60 static const JITTargetAddress QuxAddr = 4U; 61 JITEvaluatedSymbol FooSym = 62 JITEvaluatedSymbol(FooAddr, JITSymbolFlags::Exported); 63 JITEvaluatedSymbol BarSym = 64 JITEvaluatedSymbol(BarAddr, JITSymbolFlags::Exported); 65 JITEvaluatedSymbol BazSym = 66 JITEvaluatedSymbol(BazAddr, JITSymbolFlags::Exported); 67 JITEvaluatedSymbol QuxSym = 68 JITEvaluatedSymbol(QuxAddr, JITSymbolFlags::Exported); 69 }; 70 71 } // end namespace orc 72 73 class OrcNativeTarget { 74 public: initialize()75 static void initialize() { 76 if (!NativeTargetInitialized) { 77 InitializeNativeTarget(); 78 InitializeNativeTargetAsmParser(); 79 InitializeNativeTargetAsmPrinter(); 80 NativeTargetInitialized = true; 81 } 82 } 83 84 private: 85 static bool NativeTargetInitialized; 86 }; 87 88 // Base class for Orc tests that will execute code. 89 class OrcExecutionTest { 90 public: 91 OrcExecutionTest()92 OrcExecutionTest() { 93 94 // Initialize the native target if it hasn't been done already. 95 OrcNativeTarget::initialize(); 96 97 // Try to select a TargetMachine for the host. 98 TM.reset(EngineBuilder().selectTarget()); 99 100 if (TM) { 101 // If we found a TargetMachine, check that it's one that Orc supports. 102 const Triple& TT = TM->getTargetTriple(); 103 104 // Bail out for windows platforms. We do not support these yet. 105 if ((TT.getArch() != Triple::x86_64 && TT.getArch() != Triple::x86) || 106 TT.isOSWindows()) 107 return; 108 109 // Target can JIT? 110 SupportsJIT = TM->getTarget().hasJIT(); 111 // Use ability to create callback manager to detect whether Orc 112 // has indirection support on this platform. This way the test 113 // and Orc code do not get out of sync. 114 SupportsIndirection = !!orc::createLocalCompileCallbackManager(TT, ES, 0); 115 } 116 }; 117 118 protected: 119 orc::ExecutionSession ES; 120 LLVMContext Context; 121 std::unique_ptr<TargetMachine> TM; 122 bool SupportsJIT = false; 123 bool SupportsIndirection = false; 124 }; 125 126 class ModuleBuilder { 127 public: 128 ModuleBuilder(LLVMContext &Context, StringRef Triple, 129 StringRef Name); 130 131 template <typename FuncType> createFunctionDecl(StringRef Name)132 Function* createFunctionDecl(StringRef Name) { 133 return Function::Create( 134 TypeBuilder<FuncType, false>::get(M->getContext()), 135 GlobalValue::ExternalLinkage, Name, M.get()); 136 } 137 getModule()138 Module* getModule() { return M.get(); } getModule()139 const Module* getModule() const { return M.get(); } takeModule()140 std::unique_ptr<Module> takeModule() { return std::move(M); } 141 142 private: 143 std::unique_ptr<Module> M; 144 }; 145 146 // Dummy struct type. 147 struct DummyStruct { 148 int X[256]; 149 }; 150 151 // TypeBuilder specialization for DummyStruct. 152 template <bool XCompile> 153 class TypeBuilder<DummyStruct, XCompile> { 154 public: get(LLVMContext & Context)155 static StructType *get(LLVMContext &Context) { 156 return StructType::get( 157 TypeBuilder<types::i<32>[256], XCompile>::get(Context)); 158 } 159 }; 160 161 template <typename HandleT, typename ModuleT> 162 class MockBaseLayer { 163 public: 164 165 using ModuleHandleT = HandleT; 166 167 using AddModuleSignature = 168 Expected<ModuleHandleT>(ModuleT M, 169 std::shared_ptr<JITSymbolResolver> R); 170 171 using RemoveModuleSignature = Error(ModuleHandleT H); 172 using FindSymbolSignature = JITSymbol(const std::string &Name, 173 bool ExportedSymbolsOnly); 174 using FindSymbolInSignature = JITSymbol(ModuleHandleT H, 175 const std::string &Name, 176 bool ExportedSymbolsONly); 177 using EmitAndFinalizeSignature = Error(ModuleHandleT H); 178 179 std::function<AddModuleSignature> addModuleImpl; 180 std::function<RemoveModuleSignature> removeModuleImpl; 181 std::function<FindSymbolSignature> findSymbolImpl; 182 std::function<FindSymbolInSignature> findSymbolInImpl; 183 std::function<EmitAndFinalizeSignature> emitAndFinalizeImpl; 184 addModule(ModuleT M,std::shared_ptr<JITSymbolResolver> R)185 Expected<ModuleHandleT> addModule(ModuleT M, 186 std::shared_ptr<JITSymbolResolver> R) { 187 assert(addModuleImpl && 188 "addModule called, but no mock implementation was provided"); 189 return addModuleImpl(std::move(M), std::move(R)); 190 } 191 removeModule(ModuleHandleT H)192 Error removeModule(ModuleHandleT H) { 193 assert(removeModuleImpl && 194 "removeModule called, but no mock implementation was provided"); 195 return removeModuleImpl(H); 196 } 197 findSymbol(const std::string & Name,bool ExportedSymbolsOnly)198 JITSymbol findSymbol(const std::string &Name, bool ExportedSymbolsOnly) { 199 assert(findSymbolImpl && 200 "findSymbol called, but no mock implementation was provided"); 201 return findSymbolImpl(Name, ExportedSymbolsOnly); 202 } 203 findSymbolIn(ModuleHandleT H,const std::string & Name,bool ExportedSymbolsOnly)204 JITSymbol findSymbolIn(ModuleHandleT H, const std::string &Name, 205 bool ExportedSymbolsOnly) { 206 assert(findSymbolInImpl && 207 "findSymbolIn called, but no mock implementation was provided"); 208 return findSymbolInImpl(H, Name, ExportedSymbolsOnly); 209 } 210 emitAndFinaliez(ModuleHandleT H)211 Error emitAndFinaliez(ModuleHandleT H) { 212 assert(emitAndFinalizeImpl && 213 "emitAndFinalize called, but no mock implementation was provided"); 214 return emitAndFinalizeImpl(H); 215 } 216 }; 217 218 class ReturnNullJITSymbol { 219 public: 220 template <typename... Args> operator()221 JITSymbol operator()(Args...) const { 222 return nullptr; 223 } 224 }; 225 226 template <typename ReturnT> 227 class DoNothingAndReturn { 228 public: DoNothingAndReturn(ReturnT Ret)229 DoNothingAndReturn(ReturnT Ret) : Ret(std::move(Ret)) {} 230 231 template <typename... Args> operator()232 void operator()(Args...) const { return Ret; } 233 private: 234 ReturnT Ret; 235 }; 236 237 template <> 238 class DoNothingAndReturn<void> { 239 public: 240 template <typename... Args> operator()241 void operator()(Args...) const { } 242 }; 243 244 } // namespace llvm 245 246 #endif 247