1 //===- mlir-spirv-cpu-runner.cpp - MLIR SPIR-V Execution on CPU -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Main entry point to a command-line utility that executes an MLIR file on
// the CPU by translating the MLIR GPU module and the host code to LLVM IR
// before JIT-compiling and executing the result.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h"
16 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
17 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
18 #include "mlir/Dialect/GPU/Passes.h"
19 #include "mlir/Dialect/SPIRV/Passes.h"
20 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
21 #include "mlir/ExecutionEngine/JitRunner.h"
22 #include "mlir/ExecutionEngine/OptUtils.h"
23 #include "mlir/InitAllDialects.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Pass/PassManager.h"
26 #include "mlir/Target/LLVMIR.h"
27
28 #include "llvm/IR/LLVMContext.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/Linker/Linker.h"
31 #include "llvm/Support/InitLLVM.h"
32 #include "llvm/Support/TargetSelect.h"
33
34 using namespace mlir;
35
36 /// A utility function that builds llvm::Module from two nested MLIR modules.
37 ///
38 /// module @main {
39 /// module @kernel {
40 /// // Some ops
41 /// }
42 /// // Some other ops
43 /// }
44 ///
45 /// Each of these two modules is translated to LLVM IR module, then they are
46 /// linked together and returned.
47 static std::unique_ptr<llvm::Module>
convertMLIRModule(ModuleOp module,llvm::LLVMContext & context)48 convertMLIRModule(ModuleOp module, llvm::LLVMContext &context) {
49 // Verify that there is only one nested module.
50 auto modules = module.getOps<ModuleOp>();
51 if (!llvm::hasSingleElement(modules)) {
52 module.emitError("The module must contain exactly one nested module");
53 return nullptr;
54 }
55
56 // Translate nested module and erase it.
57 ModuleOp nested = *modules.begin();
58 std::unique_ptr<llvm::Module> kernelModule =
59 translateModuleToLLVMIR(nested, context);
60 nested.erase();
61
62 std::unique_ptr<llvm::Module> mainModule =
63 translateModuleToLLVMIR(module, context);
64 llvm::Linker::linkModules(*mainModule, std::move(kernelModule));
65 return mainModule;
66 }
67
runMLIRPasses(ModuleOp module)68 static LogicalResult runMLIRPasses(ModuleOp module) {
69 PassManager passManager(module.getContext());
70 applyPassManagerCLOptions(passManager);
71 passManager.addPass(createGpuKernelOutliningPass());
72 passManager.addPass(createConvertGPUToSPIRVPass());
73
74 OpPassManager &nestedPM = passManager.nest<spirv::ModuleOp>();
75 nestedPM.addPass(spirv::createLowerABIAttributesPass());
76 nestedPM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
77 passManager.addPass(createLowerHostCodeToLLVMPass());
78 passManager.addPass(createConvertSPIRVToLLVMPass());
79 return passManager.run(module);
80 }
81
main(int argc,char ** argv)82 int main(int argc, char **argv) {
83 llvm::InitLLVM y(argc, argv);
84
85 llvm::InitializeNativeTarget();
86 llvm::InitializeNativeTargetAsmPrinter();
87 mlir::initializeLLVMPasses();
88
89 mlir::JitRunnerConfig jitRunnerConfig;
90 jitRunnerConfig.mlirTransformer = runMLIRPasses;
91 jitRunnerConfig.llvmModuleBuilder = convertMLIRModule;
92
93 return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
94 }
95