1 //===- ConvertToNVVMIR.cpp - MLIR to LLVM IR conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR LLVM + NVVM dialects and
10 // LLVM IR with NVVM intrinsics and metadata.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Target/NVVMIR.h"
15
16 #include "mlir/Dialect/GPU/GPUDialect.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
21 #include "mlir/Translation.h"
22
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/IR/IntrinsicsNVPTX.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/Support/ToolOutputFile.h"
27
28 using namespace mlir;
29
createIntrinsicCall(llvm::IRBuilder<> & builder,llvm::Intrinsic::ID intrinsic,ArrayRef<llvm::Value * > args={})30 static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder,
31 llvm::Intrinsic::ID intrinsic,
32 ArrayRef<llvm::Value *> args = {}) {
33 llvm::Module *module = builder.GetInsertBlock()->getModule();
34 llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic);
35 return builder.CreateCall(fn, args);
36 }
37
getShflBflyIntrinsicId(llvm::Type * resultType,bool withPredicate)38 static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
39 bool withPredicate) {
40 if (withPredicate) {
41 resultType = cast<llvm::StructType>(resultType)->getElementType(0);
42 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
43 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
44 }
45 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
46 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
47 }
48
namespace {
/// Translation to LLVM IR that extends the generic LLVM::ModuleTranslation
/// with conversions for NVVM dialect operations.
class ModuleTranslation : public LLVM::ModuleTranslation {
public:
  using LLVM::ModuleTranslation::ModuleTranslation;

protected:
  /// Converts the single operation `opInst`, emitting LLVM IR through
  /// `builder`.
  ///
  /// The NVVM-specific conversions come from NVVMConversions.inc, which is
  /// spliced directly into this function body (presumably TableGen-generated).
  /// An operation that falls through the included code is handed to the
  /// generic LLVM dialect translation in the base class.
  ///
  /// NOTE(review): the included code presumably refers to `opInst` and
  /// `builder` by name and returns early for the ops it handles. Do not rename
  /// these parameters or reorder this body without checking the .inc file.
  LogicalResult convertOperation(Operation &opInst,
                                 llvm::IRBuilder<> &builder) override {

#include "mlir/Dialect/LLVMIR/NVVMConversions.inc"

    // Not an NVVM op handled above: fall back to the base translation.
    return LLVM::ModuleTranslation::convertOperation(opInst, builder);
  }

  /// Allow access to the constructor (e.g. from the base class's
  /// translateModule<> factory).
  friend LLVM::ModuleTranslation;
};
} // namespace
67
68 std::unique_ptr<llvm::Module>
translateModuleToNVVMIR(Operation * m,llvm::LLVMContext & llvmContext,StringRef name)69 mlir::translateModuleToNVVMIR(Operation *m, llvm::LLVMContext &llvmContext,
70 StringRef name) {
71 auto llvmModule = LLVM::ModuleTranslation::translateModule<ModuleTranslation>(
72 m, llvmContext, name);
73 if (!llvmModule)
74 return llvmModule;
75
76 // Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
77 // function as a kernel.
78 for (auto func :
79 ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) {
80 if (!gpu::GPUDialect::isKernel(func))
81 continue;
82
83 auto *llvmFunc = llvmModule->getFunction(func.getName());
84
85 llvm::Metadata *llvmMetadata[] = {
86 llvm::ValueAsMetadata::get(llvmFunc),
87 llvm::MDString::get(llvmModule->getContext(), "kernel"),
88 llvm::ValueAsMetadata::get(llvm::ConstantInt::get(
89 llvm::Type::getInt32Ty(llvmModule->getContext()), 1))};
90 llvm::MDNode *llvmMetadataNode =
91 llvm::MDNode::get(llvmModule->getContext(), llvmMetadata);
92 llvmModule->getOrInsertNamedMetadata("nvvm.annotations")
93 ->addOperand(llvmMetadataNode);
94 }
95
96 return llvmModule;
97 }
98
99 namespace mlir {
registerToNVVMIRTranslation()100 void registerToNVVMIRTranslation() {
101 TranslateFromMLIRRegistration registration(
102 "mlir-to-nvvmir",
103 [](ModuleOp module, raw_ostream &output) {
104 llvm::LLVMContext llvmContext;
105 auto llvmModule = mlir::translateModuleToNVVMIR(module, llvmContext);
106 if (!llvmModule)
107 return failure();
108
109 llvmModule->print(output, nullptr);
110 return success();
111 },
112 [](DialectRegistry ®istry) {
113 registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect>();
114 });
115 }
116 } // namespace mlir
117