1 //===- mlir-rocm-runner.cpp - MLIR ROCM Execution Driver-------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a command line utility that executes an MLIR file on the GPU by
10 // translating MLIR to ROCDL/LLVM IR before JIT-compiling and executing the
11 // latter.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "llvm/ADT/STLExtras.h"
16
17 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
18 #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
19 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
20 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
21 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
22 #include "mlir/Dialect/GPU/GPUDialect.h"
23 #include "mlir/Dialect/GPU/Passes.h"
24 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
25 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
26 #include "mlir/ExecutionEngine/JitRunner.h"
27 #include "mlir/ExecutionEngine/OptUtils.h"
28 #include "mlir/IR/BuiltinOps.h"
29 #include "mlir/InitAllDialects.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Pass/PassManager.h"
32 #include "mlir/Support/FileUtilities.h"
33 #include "mlir/Target/ROCDLIR.h"
34 #include "mlir/Transforms/DialectConversion.h"
35 #include "mlir/Transforms/Passes.h"
36 #include "llvm/Support/ErrorOr.h"
37 #include "llvm/Support/FileUtilities.h"
38 #include "llvm/Support/InitLLVM.h"
39 #include "llvm/Support/LineIterator.h"
40 #include "llvm/Support/Program.h"
41 #include "llvm/Support/SourceMgr.h"
42 #include "llvm/Support/TargetRegistry.h"
43 #include "llvm/Support/TargetSelect.h"
44
45 // MC headers.
46 #include "llvm/MC/MCAsmBackend.h"
47 #include "llvm/MC/MCAsmInfo.h"
48 #include "llvm/MC/MCCodeEmitter.h"
49 #include "llvm/MC/MCContext.h"
50 #include "llvm/MC/MCInstPrinter.h"
51 #include "llvm/MC/MCInstrInfo.h"
52 #include "llvm/MC/MCObjectFileInfo.h"
53 #include "llvm/MC/MCObjectWriter.h"
54 #include "llvm/MC/MCParser/AsmLexer.h"
55 #include "llvm/MC/MCParser/MCTargetAsmParser.h"
56 #include "llvm/MC/MCRegisterInfo.h"
57 #include "llvm/MC/MCStreamer.h"
58 #include "llvm/MC/MCSubtargetInfo.h"
59 #include "llvm/MC/MCTargetOptionsCommandFlags.h"
60
61 // lld headers.
62 #include "lld/Common/Driver.h"
63
64 // HIP headers.
65 #include "hip/hip_version.h"
66
67 #include <mutex>
68
69 using namespace mlir;
70 using namespace llvm;
71
72 using Blob = SmallVector<char, 0>;
73
74 static cl::opt<std::string> tripleName("triple", cl::desc("target triple"),
75 cl::value_desc("triple string"),
76 cl::init("amdgcn-amd-amdhsa"));
77
78 static cl::opt<std::string> targetChip("target", cl::desc("target chip"),
79 cl::value_desc("AMDGPU ISA version"),
80 cl::init(""));
81
82 static cl::opt<std::string> features("feature", cl::desc("target features"),
83 cl::value_desc("AMDGPU target features"),
84 cl::init(""));
85
86 static constexpr const char kRunnerProgram[] = "mlir-rocm-runner";
87 static constexpr const char kRocmAgentEnumerator[] = "rocm_agent_enumerator";
88 static constexpr const char kDefaultTargetChip[] = "gfx900";
89
assembleIsa(const std::string isa,StringRef name,Blob & result)90 static LogicalResult assembleIsa(const std::string isa, StringRef name,
91 Blob &result) {
92 raw_svector_ostream os(result);
93
94 std::string error;
95 Triple theTriple(Triple::normalize(tripleName));
96 const Target *theTarget =
97 TargetRegistry::lookupTarget(theTriple.normalize(), error);
98 if (!theTarget) {
99 WithColor::error(errs(), name) << error;
100 return failure();
101 }
102
103 SourceMgr srcMgr;
104 srcMgr.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(isa), SMLoc());
105
106 const MCTargetOptions mcOptions;
107 std::unique_ptr<MCRegisterInfo> mri(theTarget->createMCRegInfo(tripleName));
108 std::unique_ptr<MCAsmInfo> mai(
109 theTarget->createMCAsmInfo(*mri, tripleName, mcOptions));
110 mai->setRelaxELFRelocations(true);
111
112 MCObjectFileInfo mofi;
113 MCContext ctx(mai.get(), mri.get(), &mofi, &srcMgr, &mcOptions);
114 mofi.InitMCObjectFileInfo(theTriple, false, ctx, false);
115
116 SmallString<128> cwd;
117 if (!sys::fs::current_path(cwd))
118 ctx.setCompilationDir(cwd);
119
120 std::unique_ptr<MCStreamer> mcStreamer;
121 std::unique_ptr<MCInstrInfo> mcii(theTarget->createMCInstrInfo());
122 std::unique_ptr<MCSubtargetInfo> sti(
123 theTarget->createMCSubtargetInfo(tripleName, targetChip, features));
124
125 MCCodeEmitter *ce = theTarget->createMCCodeEmitter(*mcii, *mri, ctx);
126 MCAsmBackend *mab = theTarget->createMCAsmBackend(*sti, *mri, mcOptions);
127 mcStreamer.reset(theTarget->createMCObjectStreamer(
128 theTriple, ctx, std::unique_ptr<MCAsmBackend>(mab),
129 mab->createObjectWriter(os), std::unique_ptr<MCCodeEmitter>(ce), *sti,
130 mcOptions.MCRelaxAll, mcOptions.MCIncrementalLinkerCompatible,
131 /*DWARFMustBeAtTheEnd*/ false));
132 mcStreamer->setUseAssemblerInfoForParsing(true);
133
134 std::unique_ptr<MCAsmParser> parser(
135 createMCAsmParser(srcMgr, ctx, *mcStreamer, *mai));
136 std::unique_ptr<MCTargetAsmParser> tap(
137 theTarget->createMCAsmParser(*sti, *parser, *mcii, mcOptions));
138
139 if (!tap) {
140 WithColor::error(errs(), name) << "assembler initialization error.\n";
141 return failure();
142 }
143
144 parser->setTargetParser(*tap);
145 parser->Run(false);
146
147 return success();
148 }
149
150 static std::mutex mutex;
createHsaco(const Blob & isaBlob,StringRef name,Blob & hsacoBlob)151 static LogicalResult createHsaco(const Blob &isaBlob, StringRef name,
152 Blob &hsacoBlob) {
153 // Save the ISA binary to a temp file.
154 int tempIsaBinaryFd = -1;
155 SmallString<128> tempIsaBinaryFilename;
156 std::error_code ec = sys::fs::createTemporaryFile(
157 "kernel", "o", tempIsaBinaryFd, tempIsaBinaryFilename);
158 if (ec) {
159 WithColor::error(errs(), name)
160 << "temporary file for ISA binary creation error.\n";
161 return failure();
162 }
163 FileRemover cleanupIsaBinary(tempIsaBinaryFilename);
164 raw_fd_ostream tempIsaBinaryOs(tempIsaBinaryFd, true);
165 tempIsaBinaryOs << isaBlob;
166 tempIsaBinaryOs.close();
167
168 // Create a temp file for HSA code object.
169 int tempHsacoFD = -1;
170 SmallString<128> tempHsacoFilename;
171 ec = sys::fs::createTemporaryFile("kernel", "hsaco", tempHsacoFD,
172 tempHsacoFilename);
173 if (ec) {
174 WithColor::error(errs(), name)
175 << "temporary file for HSA code object creation error.\n";
176 return failure();
177 }
178 FileRemover cleanupHsaco(tempHsacoFilename);
179
180 const std::lock_guard<std::mutex> lock(mutex);
181 // Invoke lld. Expect a true return value from lld.
182 bool ret = lld::elf::link({"ld.lld", "-shared", tempIsaBinaryFilename.c_str(),
183 "-o", tempHsacoFilename.c_str()},
184 /*canEarlyExit=*/false, llvm::outs(), llvm::errs());
185 if (!ret) {
186 WithColor::error(errs(), name) << "lld invocation error.\n";
187 return failure();
188 }
189
190 // Load the HSA code object.
191 auto hsacoFile = mlir::openInputFile(tempHsacoFilename);
192 if (!hsacoFile) {
193 WithColor::error(errs(), name)
194 << "read HSA code object from temp file error.\n";
195 return failure();
196 }
197 hsacoBlob.assign(hsacoFile->getBuffer().begin(),
198 hsacoFile->getBuffer().end());
199
200 return success();
201 }
202
203 static std::unique_ptr<llvm::Module>
compileModuleToROCDLIR(Operation * m,llvm::LLVMContext & llvmContext,StringRef name)204 compileModuleToROCDLIR(Operation *m, llvm::LLVMContext &llvmContext,
205 StringRef name) {
206 auto llvmModule = translateModuleToROCDLIR(m, llvmContext, name);
207 // TODO: Link with ROCm-Device-Libs in case needed (ex: the Module
208 // depends on math functions).
209 return llvmModule;
210 }
211
compileISAToHsaco(const std::string isa,Location loc,StringRef name)212 static OwnedBlob compileISAToHsaco(const std::string isa, Location loc,
213 StringRef name) {
214 // ISA -> ISA in binary form via MC.
215 // Use lld to create HSA code object.
216 Blob isaBlob;
217 Blob hsacoBlob;
218
219 if (succeeded(assembleIsa(isa, name, isaBlob)) &&
220 succeeded(createHsaco(isaBlob, name, hsacoBlob)))
221 return std::make_unique<std::vector<char>>(hsacoBlob.begin(),
222 hsacoBlob.end());
223
224 WithColor::error(errs(), name) << "producing HSA code object error.\n";
225 return {};
226 }
227
configTargetChip()228 static void configTargetChip() {
229 // Set targetChip to default value first.
230 targetChip = kDefaultTargetChip;
231
232 // Locate rocm_agent_enumerator.
233 llvm::ErrorOr<std::string> rocmAgentEnumerator = llvm::sys::findProgramByName(
234 kRocmAgentEnumerator, {__ROCM_PATH__ "/bin"});
235 std::error_code ec;
236 if ((ec = rocmAgentEnumerator.getError())) {
237 WithColor::warning(errs(), kRunnerProgram)
238 << kRocmAgentEnumerator << " couldn't be located under "
239 << __ROCM_PATH__ << ", set target as " << kDefaultTargetChip << "\n";
240 return;
241 }
242
243 // Prepare temp file to hold the outputs.
244 int tempFd = -1;
245 SmallString<128> tempFilename;
246 ec = sys::fs::createTemporaryFile("rocm_agent", "txt", tempFd, tempFilename);
247 if (ec) {
248 WithColor::warning(errs(), kRunnerProgram)
249 << "temporary file for " << kRocmAgentEnumerator
250 << " creation error, set target as " << kDefaultTargetChip << "\n";
251 return;
252 }
253 FileRemover cleanup(tempFilename);
254
255 // Invoke rocm_agent_enumerator.
256 std::string errorMessage;
257 SmallVector<StringRef, 2> args{"-t", "GPU"};
258 Optional<StringRef> redirects[3] = {{""}, tempFilename.str(), {""}};
259 int result =
260 llvm::sys::ExecuteAndWait(rocmAgentEnumerator.get(), args, llvm::None,
261 redirects, 0, 0, &errorMessage);
262 if (result) {
263 WithColor::warning(errs(), kRunnerProgram)
264 << kRocmAgentEnumerator << " invocation error: " << errorMessage
265 << ", set target as " << kDefaultTargetChip << "\n";
266 return;
267 }
268
269 // Load and parse the result.
270 auto gfxIsaList = mlir::openInputFile(tempFilename);
271 if (!gfxIsaList) {
272 WithColor::error(errs(), kRunnerProgram)
273 << "read ROCm agent list temp file error, set target as "
274 << kDefaultTargetChip << "\n";
275 return;
276 }
277 for (line_iterator lines(*gfxIsaList); !lines.is_at_end(); ++lines) {
278 // Skip the line with content "gfx000".
279 if (*lines == "gfx000")
280 continue;
281 // Use the first ISA version found.
282 targetChip = lines->str();
283 break;
284 }
285 }
286
configTargetFeatures()287 static void configTargetFeatures() {
288 if (features.size() > 0)
289 features += ",";
290 // After ROCm 3.5, adopt HSA code object V3.
291 if (HIP_VERSION_MAJOR >= 3 && HIP_VERSION_MINOR >= 5)
292 features += "+code-object-v3";
293 else
294 features += "-code-object-v3";
295 }
296
runMLIRPasses(ModuleOp m)297 static LogicalResult runMLIRPasses(ModuleOp m) {
298 PassManager pm(m.getContext());
299 applyPassManagerCLOptions(pm);
300
301 // Configure target chip ISA version if it has not been specified.
302 if (!targetChip.size())
303 configTargetChip();
304
305 // Configure target features per ROCm / HIP version.
306 configTargetFeatures();
307
308 const char gpuBinaryAnnotation[] = "rocdl.hsaco";
309 pm.addPass(createLowerToCFGPass());
310 pm.addPass(createGpuKernelOutliningPass());
311 auto &kernelPm = pm.nest<gpu::GPUModuleOp>();
312 kernelPm.addPass(createStripDebugInfoPass());
313 kernelPm.addPass(createLowerGpuOpsToROCDLOpsPass());
314 kernelPm.addPass(createConvertGPUKernelToBlobPass(
315 compileModuleToROCDLIR, compileISAToHsaco, tripleName, targetChip,
316 features, gpuBinaryAnnotation));
317 pm.addPass(createGpuToLLVMConversionPass(gpuBinaryAnnotation));
318
319 return pm.run(m);
320 }
321
main(int argc,char ** argv)322 int main(int argc, char **argv) {
323 registerPassManagerCLOptions();
324 llvm::InitLLVM y(argc, argv);
325 llvm::InitializeAllTargetInfos();
326 llvm::InitializeAllTargetMCs();
327 llvm::InitializeAllAsmParsers();
328
329 // Initialize LLVM AMDGPU backend.
330 LLVMInitializeAMDGPUTarget();
331 LLVMInitializeAMDGPUTargetInfo();
332 LLVMInitializeAMDGPUTargetMC();
333 LLVMInitializeAMDGPUAsmPrinter();
334
335 mlir::initializeLLVMPasses();
336
337 mlir::JitRunnerConfig jitRunnerConfig;
338 jitRunnerConfig.mlirTransformer = runMLIRPasses;
339
340 return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
341 }
342