1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a library that provides a shared implementation for command line
10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
11 // IR before JIT-compiling and executing the latter.
12 //
13 // The translation can be customized by providing an MLIR to MLIR
14 // transformation.
15 //===----------------------------------------------------------------------===//
16
17 #include "mlir/ExecutionEngine/JitRunner.h"
18
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/ExecutionEngine/ExecutionEngine.h"
21 #include "mlir/ExecutionEngine/OptUtils.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/InitAllDialects.h"
25 #include "mlir/Parser.h"
26 #include "mlir/Support/FileUtilities.h"
27
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/LLVMContext.h"
32 #include "llvm/IR/LegacyPassNameParser.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/FileUtilities.h"
35 #include "llvm/Support/SourceMgr.h"
36 #include "llvm/Support/StringSaver.h"
37 #include "llvm/Support/ToolOutputFile.h"
38 #include <cstdint>
39 #include <numeric>
40
41 using namespace mlir;
42 using llvm::Error;
43
44 namespace {
45 /// This options struct prevents the need for global static initializers, and
46 /// is only initialized if the JITRunner is invoked.
47 struct Options {
48 llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
49 llvm::cl::desc("<input file>"),
50 llvm::cl::init("-")};
51 llvm::cl::opt<std::string> mainFuncName{
52 "e", llvm::cl::desc("The function to be called"),
53 llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
54 llvm::cl::opt<std::string> mainFuncType{
55 "entry-point-result",
56 llvm::cl::desc("Textual description of the function type to be called"),
57 llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};
58
59 llvm::cl::OptionCategory optFlags{"opt-like flags"};
60
61 // CLI list of pass information
62 llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser> llvmPasses{
63 llvm::cl::desc("LLVM optimizing passes to run"), llvm::cl::cat(optFlags)};
64
65 // CLI variables for -On options.
66 llvm::cl::opt<bool> optO0{"O0",
67 llvm::cl::desc("Run opt passes and codegen at O0"),
68 llvm::cl::cat(optFlags)};
69 llvm::cl::opt<bool> optO1{"O1",
70 llvm::cl::desc("Run opt passes and codegen at O1"),
71 llvm::cl::cat(optFlags)};
72 llvm::cl::opt<bool> optO2{"O2",
73 llvm::cl::desc("Run opt passes and codegen at O2"),
74 llvm::cl::cat(optFlags)};
75 llvm::cl::opt<bool> optO3{"O3",
76 llvm::cl::desc("Run opt passes and codegen at O3"),
77 llvm::cl::cat(optFlags)};
78
79 llvm::cl::OptionCategory clOptionsCategory{"linking options"};
80 llvm::cl::list<std::string> clSharedLibs{
81 "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
82 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
83 llvm::cl::cat(clOptionsCategory)};
84
85 /// CLI variables for debugging.
86 llvm::cl::opt<bool> dumpObjectFile{
87 "dump-object-file",
88 llvm::cl::desc("Dump JITted-compiled object to file specified with "
89 "-object-filename (<input file>.o by default).")};
90
91 llvm::cl::opt<std::string> objectFilename{
92 "object-filename",
93 llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
94 };
95
96 struct CompileAndExecuteConfig {
97 /// LLVM module transformer that is passed to ExecutionEngine.
98 llvm::function_ref<llvm::Error(llvm::Module *)> transformer;
99
100 /// A custom function that is passed to ExecutionEngine. It processes MLIR
101 /// module and creates LLVM IR module.
102 llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
103 llvm::LLVMContext &)>
104 llvmModuleBuilder;
105
106 /// A custom function that is passed to ExecutinEngine to register symbols at
107 /// runtime.
108 llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
109 runtimeSymbolMap;
110 };
111
112 } // end anonymous namespace
113
parseMLIRInput(StringRef inputFilename,MLIRContext * context)114 static OwningModuleRef parseMLIRInput(StringRef inputFilename,
115 MLIRContext *context) {
116 // Set up the input file.
117 std::string errorMessage;
118 auto file = openInputFile(inputFilename, &errorMessage);
119 if (!file) {
120 llvm::errs() << errorMessage << "\n";
121 return nullptr;
122 }
123
124 llvm::SourceMgr sourceMgr;
125 sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
126 return OwningModuleRef(parseSourceFile(sourceMgr, context));
127 }
128
make_string_error(const Twine & message)129 static inline Error make_string_error(const Twine &message) {
130 return llvm::make_error<llvm::StringError>(message.str(),
131 llvm::inconvertibleErrorCode());
132 }
133
getCommandLineOptLevel(Options & options)134 static Optional<unsigned> getCommandLineOptLevel(Options &options) {
135 Optional<unsigned> optLevel;
136 SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
137 options.optO0, options.optO1, options.optO2, options.optO3};
138
139 // Determine if there is an optimization flag present.
140 for (unsigned j = 0; j < 4; ++j) {
141 auto &flag = optFlags[j].get();
142 if (flag) {
143 optLevel = j;
144 break;
145 }
146 }
147 return optLevel;
148 }
149
150 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
compileAndExecute(Options & options,ModuleOp module,StringRef entryPoint,CompileAndExecuteConfig config,void ** args)151 static Error compileAndExecute(Options &options, ModuleOp module,
152 StringRef entryPoint,
153 CompileAndExecuteConfig config, void **args) {
154 Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
155 if (auto clOptLevel = getCommandLineOptLevel(options))
156 jitCodeGenOptLevel =
157 static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
158 SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
159 options.clSharedLibs.end());
160 auto expectedEngine = mlir::ExecutionEngine::create(
161 module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel,
162 libs);
163 if (!expectedEngine)
164 return expectedEngine.takeError();
165
166 auto engine = std::move(*expectedEngine);
167 if (config.runtimeSymbolMap)
168 engine->registerSymbols(config.runtimeSymbolMap);
169
170 auto expectedFPtr = engine->lookup(entryPoint);
171 if (!expectedFPtr)
172 return expectedFPtr.takeError();
173
174 if (options.dumpObjectFile)
175 engine->dumpToObjectFile(options.objectFilename.empty()
176 ? options.inputFilename + ".o"
177 : options.objectFilename);
178
179 void (*fptr)(void **) = *expectedFPtr;
180 (*fptr)(args);
181
182 return Error::success();
183 }
184
compileAndExecuteVoidFunction(Options & options,ModuleOp module,StringRef entryPoint,CompileAndExecuteConfig config)185 static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
186 StringRef entryPoint,
187 CompileAndExecuteConfig config) {
188 auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
189 if (!mainFunction || mainFunction.empty())
190 return make_string_error("entry point not found");
191 void *empty = nullptr;
192 return compileAndExecute(options, module, entryPoint, config, &empty);
193 }
194
195 template <typename Type>
196 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
197 template <>
checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction)198 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
199 if (!mainFunction.getType().getFunctionResultType().isIntegerTy(32))
200 return make_string_error("only single llvm.i32 function result supported");
201 return Error::success();
202 }
203 template <>
checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction)204 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
205 if (!mainFunction.getType().getFunctionResultType().isIntegerTy(64))
206 return make_string_error("only single llvm.i64 function result supported");
207 return Error::success();
208 }
209 template <>
checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction)210 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
211 if (!mainFunction.getType().getFunctionResultType().isFloatTy())
212 return make_string_error("only single llvm.f32 function result supported");
213 return Error::success();
214 }
215 template <typename Type>
compileAndExecuteSingleReturnFunction(Options & options,ModuleOp module,StringRef entryPoint,CompileAndExecuteConfig config)216 Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
217 StringRef entryPoint,
218 CompileAndExecuteConfig config) {
219 auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
220 if (!mainFunction || mainFunction.isExternal())
221 return make_string_error("entry point not found");
222
223 if (mainFunction.getType().getFunctionNumParams() != 0)
224 return make_string_error("function inputs not supported");
225
226 if (Error error = checkCompatibleReturnType<Type>(mainFunction))
227 return error;
228
229 Type res;
230 struct {
231 void *data;
232 } data;
233 data.data = &res;
234 if (auto error = compileAndExecute(options, module, entryPoint, config,
235 (void **)&data))
236 return error;
237
238 // Intentional printing of the output so we can test.
239 llvm::outs() << res << '\n';
240
241 return Error::success();
242 }
243
244 /// Entry point for all CPU runners. Expects the common argc/argv arguments for
245 /// standard C++ main functions.
JitRunnerMain(int argc,char ** argv,JitRunnerConfig config)246 int mlir::JitRunnerMain(int argc, char **argv, JitRunnerConfig config) {
247 // Create the options struct containing the command line options for the
248 // runner. This must come before the command line options are parsed.
249 Options options;
250 llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
251
252 Optional<unsigned> optLevel = getCommandLineOptLevel(options);
253 SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
254 options.optO0, options.optO1, options.optO2, options.optO3};
255 unsigned optCLIPosition = 0;
256 // Determine if there is an optimization flag present, and its CLI position
257 // (optCLIPosition).
258 for (unsigned j = 0; j < 4; ++j) {
259 auto &flag = optFlags[j].get();
260 if (flag) {
261 optCLIPosition = flag.getPosition();
262 break;
263 }
264 }
265 // Generate vector of pass information, plus the index at which we should
266 // insert any optimization passes in that vector (optPosition).
267 SmallVector<const llvm::PassInfo *, 4> passes;
268 unsigned optPosition = 0;
269 for (unsigned i = 0, e = options.llvmPasses.size(); i < e; ++i) {
270 passes.push_back(options.llvmPasses[i]);
271 if (optCLIPosition < options.llvmPasses.getPosition(i)) {
272 optPosition = i;
273 optCLIPosition = UINT_MAX; // To ensure we never insert again
274 }
275 }
276
277 MLIRContext context;
278 registerAllDialects(context.getDialectRegistry());
279
280 auto m = parseMLIRInput(options.inputFilename, &context);
281 if (!m) {
282 llvm::errs() << "could not parse the input IR\n";
283 return 1;
284 }
285
286 if (config.mlirTransformer)
287 if (failed(config.mlirTransformer(m.get())))
288 return EXIT_FAILURE;
289
290 auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
291 if (!tmBuilderOrError) {
292 llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
293 return EXIT_FAILURE;
294 }
295 auto tmOrError = tmBuilderOrError->createTargetMachine();
296 if (!tmOrError) {
297 llvm::errs() << "Failed to create a TargetMachine for the host\n";
298 return EXIT_FAILURE;
299 }
300
301 auto transformer = mlir::makeLLVMPassesTransformer(
302 passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
303
304 CompileAndExecuteConfig compileAndExecuteConfig;
305 compileAndExecuteConfig.transformer = transformer;
306 compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
307 compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
308
309 // Get the function used to compile and execute the module.
310 using CompileAndExecuteFnT =
311 Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig);
312 auto compileAndExecuteFn =
313 StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
314 .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
315 .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
316 .Case("f32", compileAndExecuteSingleReturnFunction<float>)
317 .Case("void", compileAndExecuteVoidFunction)
318 .Default(nullptr);
319
320 Error error = compileAndExecuteFn
321 ? compileAndExecuteFn(options, m.get(),
322 options.mainFuncName.getValue(),
323 compileAndExecuteConfig)
324 : make_string_error("unsupported function type");
325
326 int exitCode = EXIT_SUCCESS;
327 llvm::handleAllErrors(std::move(error),
328 [&exitCode](const llvm::ErrorInfoBase &info) {
329 llvm::errs() << "Error: ";
330 info.log(llvm::errs());
331 llvm::errs() << '\n';
332 exitCode = EXIT_FAILURE;
333 });
334
335 return exitCode;
336 }
337