• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- mlir-spirv-cpu-runner.cpp - MLIR SPIR-V Execution on CPU -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Main entry point to a command line utility that executes an MLIR file on the
10 // CPU by translating MLIR GPU module and host part to LLVM IR before
11 // JIT-compiling and executing.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h"
16 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
17 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
18 #include "mlir/Dialect/GPU/Passes.h"
19 #include "mlir/Dialect/SPIRV/Passes.h"
20 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
21 #include "mlir/ExecutionEngine/JitRunner.h"
22 #include "mlir/ExecutionEngine/OptUtils.h"
23 #include "mlir/InitAllDialects.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Pass/PassManager.h"
26 #include "mlir/Target/LLVMIR.h"
27 
28 #include "llvm/IR/LLVMContext.h"
29 #include "llvm/IR/Module.h"
30 #include "llvm/Linker/Linker.h"
31 #include "llvm/Support/InitLLVM.h"
32 #include "llvm/Support/TargetSelect.h"
33 
34 using namespace mlir;
35 
36 /// A utility function that builds llvm::Module from two nested MLIR modules.
37 ///
38 /// module @main {
39 ///   module @kernel {
40 ///     // Some ops
41 ///   }
42 ///   // Some other ops
43 /// }
44 ///
45 /// Each of these two modules is translated to LLVM IR module, then they are
46 /// linked together and returned.
47 static std::unique_ptr<llvm::Module>
convertMLIRModule(ModuleOp module,llvm::LLVMContext & context)48 convertMLIRModule(ModuleOp module, llvm::LLVMContext &context) {
49   // Verify that there is only one nested module.
50   auto modules = module.getOps<ModuleOp>();
51   if (!llvm::hasSingleElement(modules)) {
52     module.emitError("The module must contain exactly one nested module");
53     return nullptr;
54   }
55 
56   // Translate nested module and erase it.
57   ModuleOp nested = *modules.begin();
58   std::unique_ptr<llvm::Module> kernelModule =
59       translateModuleToLLVMIR(nested, context);
60   nested.erase();
61 
62   std::unique_ptr<llvm::Module> mainModule =
63       translateModuleToLLVMIR(module, context);
64   llvm::Linker::linkModules(*mainModule, std::move(kernelModule));
65   return mainModule;
66 }
67 
runMLIRPasses(ModuleOp module)68 static LogicalResult runMLIRPasses(ModuleOp module) {
69   PassManager passManager(module.getContext());
70   applyPassManagerCLOptions(passManager);
71   passManager.addPass(createGpuKernelOutliningPass());
72   passManager.addPass(createConvertGPUToSPIRVPass());
73 
74   OpPassManager &nestedPM = passManager.nest<spirv::ModuleOp>();
75   nestedPM.addPass(spirv::createLowerABIAttributesPass());
76   nestedPM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
77   passManager.addPass(createLowerHostCodeToLLVMPass());
78   passManager.addPass(createConvertSPIRVToLLVMPass());
79   return passManager.run(module);
80 }
81 
main(int argc,char ** argv)82 int main(int argc, char **argv) {
83   llvm::InitLLVM y(argc, argv);
84 
85   llvm::InitializeNativeTarget();
86   llvm::InitializeNativeTargetAsmPrinter();
87   mlir::initializeLLVMPasses();
88 
89   mlir::JitRunnerConfig jitRunnerConfig;
90   jitRunnerConfig.mlirTransformer = runMLIRPasses;
91   jitRunnerConfig.llvmModuleBuilder = convertMLIRModule;
92 
93   return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
94 }
95