• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- ConvertToNVVMIR.cpp - MLIR to LLVM IR conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR LLVM + NVVM dialects and
10 // LLVM IR with NVVM intrinsics and metadata.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Target/NVVMIR.h"
15 
16 #include "mlir/Dialect/GPU/GPUDialect.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
21 #include "mlir/Translation.h"
22 
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/IR/IntrinsicsNVPTX.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/Support/ToolOutputFile.h"
27 
28 using namespace mlir;
29 
createIntrinsicCall(llvm::IRBuilder<> & builder,llvm::Intrinsic::ID intrinsic,ArrayRef<llvm::Value * > args={})30 static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder,
31                                         llvm::Intrinsic::ID intrinsic,
32                                         ArrayRef<llvm::Value *> args = {}) {
33   llvm::Module *module = builder.GetInsertBlock()->getModule();
34   llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic);
35   return builder.CreateCall(fn, args);
36 }
37 
getShflBflyIntrinsicId(llvm::Type * resultType,bool withPredicate)38 static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
39                                                   bool withPredicate) {
40   if (withPredicate) {
41     resultType = cast<llvm::StructType>(resultType)->getElementType(0);
42     return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
43                                    : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
44   }
45   return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
46                                  : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
47 }
48 
49 namespace {
50 class ModuleTranslation : public LLVM::ModuleTranslation {
51 public:
52   using LLVM::ModuleTranslation::ModuleTranslation;
53 
54 protected:
convertOperation(Operation & opInst,llvm::IRBuilder<> & builder)55   LogicalResult convertOperation(Operation &opInst,
56                                  llvm::IRBuilder<> &builder) override {
57 
58 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
59 
60     return LLVM::ModuleTranslation::convertOperation(opInst, builder);
61   }
62 
63   /// Allow access to the constructor.
64   friend LLVM::ModuleTranslation;
65 };
66 } // namespace
67 
68 std::unique_ptr<llvm::Module>
translateModuleToNVVMIR(Operation * m,llvm::LLVMContext & llvmContext,StringRef name)69 mlir::translateModuleToNVVMIR(Operation *m, llvm::LLVMContext &llvmContext,
70                               StringRef name) {
71   auto llvmModule = LLVM::ModuleTranslation::translateModule<ModuleTranslation>(
72       m, llvmContext, name);
73   if (!llvmModule)
74     return llvmModule;
75 
76   // Insert the nvvm.annotations kernel so that the NVVM backend recognizes the
77   // function as a kernel.
78   for (auto func :
79        ModuleTranslation::getModuleBody(m).getOps<LLVM::LLVMFuncOp>()) {
80     if (!gpu::GPUDialect::isKernel(func))
81       continue;
82 
83     auto *llvmFunc = llvmModule->getFunction(func.getName());
84 
85     llvm::Metadata *llvmMetadata[] = {
86         llvm::ValueAsMetadata::get(llvmFunc),
87         llvm::MDString::get(llvmModule->getContext(), "kernel"),
88         llvm::ValueAsMetadata::get(llvm::ConstantInt::get(
89             llvm::Type::getInt32Ty(llvmModule->getContext()), 1))};
90     llvm::MDNode *llvmMetadataNode =
91         llvm::MDNode::get(llvmModule->getContext(), llvmMetadata);
92     llvmModule->getOrInsertNamedMetadata("nvvm.annotations")
93         ->addOperand(llvmMetadataNode);
94   }
95 
96   return llvmModule;
97 }
98 
99 namespace mlir {
registerToNVVMIRTranslation()100 void registerToNVVMIRTranslation() {
101   TranslateFromMLIRRegistration registration(
102       "mlir-to-nvvmir",
103       [](ModuleOp module, raw_ostream &output) {
104         llvm::LLVMContext llvmContext;
105         auto llvmModule = mlir::translateModuleToNVVMIR(module, llvmContext);
106         if (!llvmModule)
107           return failure();
108 
109         llvmModule->print(output, nullptr);
110         return success();
111       },
112       [](DialectRegistry &registry) {
113         registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect>();
114       });
115 }
116 } // namespace mlir
117