1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR, which builds on the LLVM IR dialect. It also registers the dialect.
11 //
12 // The NVVM dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/Operation.h"
23 #include "mlir/IR/OperationSupport.h"
24 #include "llvm/AsmParser/Parser.h"
25 #include "llvm/IR/Attributes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/Type.h"
28 #include "llvm/Support/SourceMgr.h"
29
30 using namespace mlir;
31 using namespace NVVM;
32
33 //===----------------------------------------------------------------------===//
34 // Printing/parsing for NVVM ops
35 //===----------------------------------------------------------------------===//
36
printNVVMIntrinsicOp(OpAsmPrinter & p,Operation * op)37 static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
38 p << op->getName() << " " << op->getOperands();
39 if (op->getNumResults() > 0)
40 p << " : " << op->getResultTypes();
41 }
42
43 // <operation> ::=
44 // `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask`
45 // ({return_value_and_is_valid})? : result_type
parseNVVMShflSyncBflyOp(OpAsmParser & parser,OperationState & result)46 static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
47 OperationState &result) {
48 SmallVector<OpAsmParser::OperandType, 8> ops;
49 Type resultType;
50 if (parser.parseOperandList(ops) ||
51 parser.parseOptionalAttrDict(result.attributes) ||
52 parser.parseColonType(resultType) ||
53 parser.addTypeToList(resultType, result.types))
54 return failure();
55
56 auto type = resultType.cast<LLVM::LLVMType>();
57 for (auto &attr : result.attributes) {
58 if (attr.first != "return_value_and_is_valid")
59 continue;
60 if (type.isStructTy() && type.getStructNumElements() > 0)
61 type = type.getStructElementType(0);
62 break;
63 }
64
65 auto int32Ty = LLVM::LLVMType::getInt32Ty(parser.getBuilder().getContext());
66 return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
67 parser.getNameLoc(), result.operands);
68 }
69
70 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
parseNVVMVoteBallotOp(OpAsmParser & parser,OperationState & result)71 static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
72 OperationState &result) {
73 MLIRContext *context = parser.getBuilder().getContext();
74 auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
75 auto int1Ty = LLVM::LLVMType::getInt1Ty(context);
76
77 SmallVector<OpAsmParser::OperandType, 8> ops;
78 Type type;
79 return failure(parser.parseOperandList(ops) ||
80 parser.parseOptionalAttrDict(result.attributes) ||
81 parser.parseColonType(type) ||
82 parser.addTypeToList(type, result.types) ||
83 parser.resolveOperands(ops, {int32Ty, int1Ty},
84 parser.getNameLoc(), result.operands));
85 }
86
verify(MmaOp op)87 static LogicalResult verify(MmaOp op) {
88 MLIRContext *context = op.getContext();
89 auto f16Ty = LLVM::LLVMType::getHalfTy(context);
90 auto f16x2Ty = LLVM::LLVMType::getVectorTy(f16Ty, 2);
91 auto f32Ty = LLVM::LLVMType::getFloatTy(context);
92 auto f16x2x4StructTy = LLVM::LLVMType::getStructTy(
93 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
94 auto f32x8StructTy = LLVM::LLVMType::getStructTy(
95 context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
96
97 SmallVector<Type, 12> operand_types(op.getOperandTypes().begin(),
98 op.getOperandTypes().end());
99 if (operand_types != SmallVector<Type, 8>(8, f16x2Ty) &&
100 operand_types != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
101 f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
102 f32Ty, f32Ty, f32Ty}) {
103 return op.emitOpError(
104 "expected operands to be 4 <halfx2>s followed by either "
105 "4 <halfx2>s or 8 floats");
106 }
107 if (op.getType() != f32x8StructTy && op.getType() != f16x2x4StructTy) {
108 return op.emitOpError("expected result type to be a struct of either 4 "
109 "<halfx2>s or 8 floats");
110 }
111
112 auto alayout = op->getAttrOfType<StringAttr>("alayout");
113 auto blayout = op->getAttrOfType<StringAttr>("blayout");
114
115 if (!(alayout && blayout) ||
116 !(alayout.getValue() == "row" || alayout.getValue() == "col") ||
117 !(blayout.getValue() == "row" || blayout.getValue() == "col")) {
118 return op.emitOpError(
119 "alayout and blayout attributes must be set to either "
120 "\"row\" or \"col\"");
121 }
122
123 if (operand_types == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
124 f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
125 f32Ty, f32Ty, f32Ty} &&
126 op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
127 blayout.getValue() == "col") {
128 return success();
129 }
130 return op.emitOpError("unimplemented mma.sync variant");
131 }
132
133 //===----------------------------------------------------------------------===//
134 // NVVMDialect initialization, type parsing, and registration.
135 //===----------------------------------------------------------------------===//
136
// TODO: This should be the llvm.nvvm dialect once this is supported.
/// Initializes the NVVM dialect: registers the operations listed in
/// NVVMOps.cpp.inc and allows operations that are not registered.
void NVVMDialect::initialize() {
  // The template argument list is supplied by including NVVMOps.cpp.inc with
  // GET_OP_LIST defined.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are registered.
  allowUnknownOperations();
}
147
148 #define GET_OP_CLASSES
149 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
150