1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the LLVM IR dialect in
10 // MLIR, and the LLVM IR dialect. It also registers the dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/FunctionImplementation.h"
20 #include "mlir/IR/MLIRContext.h"
21
22 #include "llvm/ADT/StringSwitch.h"
23 #include "llvm/AsmParser/Parser.h"
24 #include "llvm/Bitcode/BitcodeReader.h"
25 #include "llvm/Bitcode/BitcodeWriter.h"
26 #include "llvm/IR/Attributes.h"
27 #include "llvm/IR/Function.h"
28 #include "llvm/IR/Type.h"
29 #include "llvm/Support/Mutex.h"
30 #include "llvm/Support/SourceMgr.h"
31
32 using namespace mlir;
33 using namespace mlir::LLVM;
34
// Name of the unit attribute under which the load/store builders and parsers
// in this file record that a memory access is volatile.
static constexpr char kVolatileAttrName[] = "volatile_";
// Name of the unit attribute under which the load/store builders in this file
// record that a memory access is non-temporal.
static constexpr char kNonTemporalAttrName[] = "nontemporal";
37
38 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
39
40 //===----------------------------------------------------------------------===//
41 // Printing/parsing for LLVM::CmpOp.
42 //===----------------------------------------------------------------------===//
printICmpOp(OpAsmPrinter & p,ICmpOp & op)43 static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) {
44 p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate())
45 << "\" " << op.getOperand(0) << ", " << op.getOperand(1);
46 p.printOptionalAttrDict(op.getAttrs(), {"predicate"});
47 p << " : " << op.lhs().getType();
48 }
49
printFCmpOp(OpAsmPrinter & p,FCmpOp & op)50 static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) {
51 p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate())
52 << "\" " << op.getOperand(0) << ", " << op.getOperand(1);
53 p.printOptionalAttrDict(op.getAttrs(), {"predicate"});
54 p << " : " << op.lhs().getType();
55 }
56
57 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
58 // attribute-dict? `:` type
59 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
60 // attribute-dict? `:` type
61 template <typename CmpPredicateType>
parseCmpOp(OpAsmParser & parser,OperationState & result)62 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
63 Builder &builder = parser.getBuilder();
64
65 StringAttr predicateAttr;
66 OpAsmParser::OperandType lhs, rhs;
67 Type type;
68 llvm::SMLoc predicateLoc, trailingTypeLoc;
69 if (parser.getCurrentLocation(&predicateLoc) ||
70 parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
71 parser.parseOperand(lhs) || parser.parseComma() ||
72 parser.parseOperand(rhs) ||
73 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
74 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
75 parser.resolveOperand(lhs, type, result.operands) ||
76 parser.resolveOperand(rhs, type, result.operands))
77 return failure();
78
79 // Replace the string attribute `predicate` with an integer attribute.
80 int64_t predicateValue = 0;
81 if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
82 Optional<ICmpPredicate> predicate =
83 symbolizeICmpPredicate(predicateAttr.getValue());
84 if (!predicate)
85 return parser.emitError(predicateLoc)
86 << "'" << predicateAttr.getValue()
87 << "' is an incorrect value of the 'predicate' attribute";
88 predicateValue = static_cast<int64_t>(predicate.getValue());
89 } else {
90 Optional<FCmpPredicate> predicate =
91 symbolizeFCmpPredicate(predicateAttr.getValue());
92 if (!predicate)
93 return parser.emitError(predicateLoc)
94 << "'" << predicateAttr.getValue()
95 << "' is an incorrect value of the 'predicate' attribute";
96 predicateValue = static_cast<int64_t>(predicate.getValue());
97 }
98
99 result.attributes.set("predicate",
100 parser.getBuilder().getI64IntegerAttr(predicateValue));
101
102 // The result type is either i1 or a vector type <? x i1> if the inputs are
103 // vectors.
104 auto resultType = LLVMType::getInt1Ty(builder.getContext());
105 auto argType = type.dyn_cast<LLVM::LLVMType>();
106 if (!argType)
107 return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
108 if (argType.isVectorTy())
109 resultType =
110 LLVMType::getVectorTy(resultType, argType.getVectorNumElements());
111
112 result.addTypes({resultType});
113 return success();
114 }
115
116 //===----------------------------------------------------------------------===//
117 // Printing/parsing for LLVM::AllocaOp.
118 //===----------------------------------------------------------------------===//
119
printAllocaOp(OpAsmPrinter & p,AllocaOp & op)120 static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
121 auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
122
123 auto funcTy = FunctionType::get({op.arraySize().getType()}, {op.getType()},
124 op.getContext());
125
126 p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy;
127 if (op.alignment().hasValue() && *op.alignment() != 0)
128 p.printOptionalAttrDict(op.getAttrs());
129 else
130 p.printOptionalAttrDict(op.getAttrs(), {"alignment"});
131 p << " : " << funcTy;
132 }
133
134 // <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
135 // `:` type `,` type
parseAllocaOp(OpAsmParser & parser,OperationState & result)136 static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
137 OpAsmParser::OperandType arraySize;
138 Type type, elemType;
139 llvm::SMLoc trailingTypeLoc;
140 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
141 parser.parseType(elemType) ||
142 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
143 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
144 return failure();
145
146 Optional<NamedAttribute> alignmentAttr =
147 result.attributes.getNamed("alignment");
148 if (alignmentAttr.hasValue()) {
149 auto alignmentInt = alignmentAttr.getValue().second.dyn_cast<IntegerAttr>();
150 if (!alignmentInt)
151 return parser.emitError(parser.getNameLoc(),
152 "expected integer alignment");
153 if (alignmentInt.getValue().isNullValue())
154 result.attributes.erase("alignment");
155 }
156
157 // Extract the result type from the trailing function type.
158 auto funcType = type.dyn_cast<FunctionType>();
159 if (!funcType || funcType.getNumInputs() != 1 ||
160 funcType.getNumResults() != 1)
161 return parser.emitError(
162 trailingTypeLoc,
163 "expected trailing function type with one argument and one result");
164
165 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
166 return failure();
167
168 result.addTypes({funcType.getResult(0)});
169 return success();
170 }
171
172 //===----------------------------------------------------------------------===//
173 // LLVM::BrOp
174 //===----------------------------------------------------------------------===//
175
176 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)177 BrOp::getMutableSuccessorOperands(unsigned index) {
178 assert(index == 0 && "invalid successor index");
179 return destOperandsMutable();
180 }
181
182 //===----------------------------------------------------------------------===//
183 // LLVM::CondBrOp
184 //===----------------------------------------------------------------------===//
185
186 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)187 CondBrOp::getMutableSuccessorOperands(unsigned index) {
188 assert(index < getNumSuccessors() && "invalid successor index");
189 return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable();
190 }
191
192 //===----------------------------------------------------------------------===//
// Builder, printer and parser for LLVM::LoadOp.
194 //===----------------------------------------------------------------------===//
195
build(OpBuilder & builder,OperationState & result,Type t,Value addr,unsigned alignment,bool isVolatile,bool isNonTemporal)196 void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
197 Value addr, unsigned alignment, bool isVolatile,
198 bool isNonTemporal) {
199 result.addOperands(addr);
200 result.addTypes(t);
201 if (isVolatile)
202 result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
203 if (isNonTemporal)
204 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
205 if (alignment != 0)
206 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
207 }
208
printLoadOp(OpAsmPrinter & p,LoadOp & op)209 static void printLoadOp(OpAsmPrinter &p, LoadOp &op) {
210 p << op.getOperationName() << ' ';
211 if (op.volatile_())
212 p << "volatile ";
213 p << op.addr();
214 p.printOptionalAttrDict(op.getAttrs(), {kVolatileAttrName});
215 p << " : " << op.addr().getType();
216 }
217
218 // Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return
219 // the resulting type wrapped in MLIR, or nullptr on error.
getLoadStoreElementType(OpAsmParser & parser,Type type,llvm::SMLoc trailingTypeLoc)220 static Type getLoadStoreElementType(OpAsmParser &parser, Type type,
221 llvm::SMLoc trailingTypeLoc) {
222 auto llvmTy = type.dyn_cast<LLVM::LLVMType>();
223 if (!llvmTy)
224 return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"),
225 nullptr;
226 if (!llvmTy.isPointerTy())
227 return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"),
228 nullptr;
229 return llvmTy.getPointerElementTy();
230 }
231
232 // <operation> ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type
parseLoadOp(OpAsmParser & parser,OperationState & result)233 static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
234 OpAsmParser::OperandType addr;
235 Type type;
236 llvm::SMLoc trailingTypeLoc;
237
238 if (succeeded(parser.parseOptionalKeyword("volatile")))
239 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
240
241 if (parser.parseOperand(addr) ||
242 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
243 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
244 parser.resolveOperand(addr, type, result.operands))
245 return failure();
246
247 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
248
249 result.addTypes(elemTy);
250 return success();
251 }
252
253 //===----------------------------------------------------------------------===//
254 // Builder, printer and parser for LLVM::StoreOp.
255 //===----------------------------------------------------------------------===//
256
build(OpBuilder & builder,OperationState & result,Value value,Value addr,unsigned alignment,bool isVolatile,bool isNonTemporal)257 void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
258 Value addr, unsigned alignment, bool isVolatile,
259 bool isNonTemporal) {
260 result.addOperands({value, addr});
261 result.addTypes({});
262 if (isVolatile)
263 result.addAttribute(kVolatileAttrName, builder.getUnitAttr());
264 if (isNonTemporal)
265 result.addAttribute(kNonTemporalAttrName, builder.getUnitAttr());
266 if (alignment != 0)
267 result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
268 }
269
printStoreOp(OpAsmPrinter & p,StoreOp & op)270 static void printStoreOp(OpAsmPrinter &p, StoreOp &op) {
271 p << op.getOperationName() << ' ';
272 if (op.volatile_())
273 p << "volatile ";
274 p << op.value() << ", " << op.addr();
275 p.printOptionalAttrDict(op.getAttrs(), {kVolatileAttrName});
276 p << " : " << op.addr().getType();
277 }
278
279 // <operation> ::= `llvm.store` `volatile` ssa-use `,` ssa-use
280 // attribute-dict? `:` type
parseStoreOp(OpAsmParser & parser,OperationState & result)281 static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
282 OpAsmParser::OperandType addr, value;
283 Type type;
284 llvm::SMLoc trailingTypeLoc;
285
286 if (succeeded(parser.parseOptionalKeyword("volatile")))
287 result.addAttribute(kVolatileAttrName, parser.getBuilder().getUnitAttr());
288
289 if (parser.parseOperand(value) || parser.parseComma() ||
290 parser.parseOperand(addr) ||
291 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
292 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
293 return failure();
294
295 Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
296 if (!elemTy)
297 return failure();
298
299 if (parser.resolveOperand(value, elemTy, result.operands) ||
300 parser.resolveOperand(addr, type, result.operands))
301 return failure();
302
303 return success();
304 }
305
306 ///===---------------------------------------------------------------------===//
307 /// LLVM::InvokeOp
308 ///===---------------------------------------------------------------------===//
309
310 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)311 InvokeOp::getMutableSuccessorOperands(unsigned index) {
312 assert(index < getNumSuccessors() && "invalid successor index");
313 return index == 0 ? normalDestOperandsMutable() : unwindDestOperandsMutable();
314 }
315
verify(InvokeOp op)316 static LogicalResult verify(InvokeOp op) {
317 if (op.getNumResults() > 1)
318 return op.emitOpError("must have 0 or 1 result");
319
320 Block *unwindDest = op.unwindDest();
321 if (unwindDest->empty())
322 return op.emitError(
323 "must have at least one operation in unwind destination");
324
325 // In unwind destination, first operation must be LandingpadOp
326 if (!isa<LandingpadOp>(unwindDest->front()))
327 return op.emitError("first operation in unwind destination should be a "
328 "llvm.landingpad operation");
329
330 return success();
331 }
332
printInvokeOp(OpAsmPrinter & p,InvokeOp op)333 static void printInvokeOp(OpAsmPrinter &p, InvokeOp op) {
334 auto callee = op.callee();
335 bool isDirect = callee.hasValue();
336
337 p << op.getOperationName() << ' ';
338
339 // Either function name or pointer
340 if (isDirect)
341 p.printSymbolName(callee.getValue());
342 else
343 p << op.getOperand(0);
344
345 p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
346 p << " to ";
347 p.printSuccessorAndUseList(op.normalDest(), op.normalDestOperands());
348 p << " unwind ";
349 p.printSuccessorAndUseList(op.unwindDest(), op.unwindDestOperands());
350
351 p.printOptionalAttrDict(op.getAttrs(),
352 {InvokeOp::getOperandSegmentSizeAttr(), "callee"});
353 p << " : ";
354 p.printFunctionalType(
355 llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1),
356 op.getResultTypes());
357 }
358
359 /// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)`
360 /// `to` bb-id (`[` ssa-use-and-type-list `]`)?
361 /// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
362 /// attribute-dict? `:` function-type
parseInvokeOp(OpAsmParser & parser,OperationState & result)363 static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
364 SmallVector<OpAsmParser::OperandType, 8> operands;
365 FunctionType funcType;
366 SymbolRefAttr funcAttr;
367 llvm::SMLoc trailingTypeLoc;
368 Block *normalDest, *unwindDest;
369 SmallVector<Value, 4> normalOperands, unwindOperands;
370 Builder &builder = parser.getBuilder();
371
372 // Parse an operand list that will, in practice, contain 0 or 1 operand. In
373 // case of an indirect call, there will be 1 operand before `(`. In case of a
374 // direct call, there will be no operands and the parser will stop at the
375 // function identifier without complaining.
376 if (parser.parseOperandList(operands))
377 return failure();
378 bool isDirect = operands.empty();
379
380 // Optionally parse a function identifier.
381 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
382 return failure();
383
384 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
385 parser.parseKeyword("to") ||
386 parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
387 parser.parseKeyword("unwind") ||
388 parser.parseSuccessorAndUseList(unwindDest, unwindOperands) ||
389 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
390 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType))
391 return failure();
392
393 if (isDirect) {
394 // Make sure types match.
395 if (parser.resolveOperands(operands, funcType.getInputs(),
396 parser.getNameLoc(), result.operands))
397 return failure();
398 result.addTypes(funcType.getResults());
399 } else {
400 // Construct the LLVM IR Dialect function type that the first operand
401 // should match.
402 if (funcType.getNumResults() > 1)
403 return parser.emitError(trailingTypeLoc,
404 "expected function with 0 or 1 result");
405
406 LLVM::LLVMType llvmResultType;
407 if (funcType.getNumResults() == 0) {
408 llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext());
409 } else {
410 llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
411 if (!llvmResultType)
412 return parser.emitError(trailingTypeLoc,
413 "expected result to have LLVM type");
414 }
415
416 SmallVector<LLVM::LLVMType, 8> argTypes;
417 argTypes.reserve(funcType.getNumInputs());
418 for (Type ty : funcType.getInputs()) {
419 if (auto argType = ty.dyn_cast<LLVM::LLVMType>())
420 argTypes.push_back(argType);
421 else
422 return parser.emitError(trailingTypeLoc,
423 "expected LLVM types as inputs");
424 }
425
426 auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
427 /*isVarArg=*/false);
428 auto wrappedFuncType = llvmFuncType.getPointerTo();
429
430 auto funcArguments = llvm::makeArrayRef(operands).drop_front();
431
432 // Make sure that the first operand (indirect callee) matches the wrapped
433 // LLVM IR function type, and that the types of the other call operands
434 // match the types of the function arguments.
435 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
436 parser.resolveOperands(funcArguments, funcType.getInputs(),
437 parser.getNameLoc(), result.operands))
438 return failure();
439
440 result.addTypes(llvmResultType);
441 }
442 result.addSuccessors({normalDest, unwindDest});
443 result.addOperands(normalOperands);
444 result.addOperands(unwindOperands);
445
446 result.addAttribute(
447 InvokeOp::getOperandSegmentSizeAttr(),
448 builder.getI32VectorAttr({static_cast<int32_t>(operands.size()),
449 static_cast<int32_t>(normalOperands.size()),
450 static_cast<int32_t>(unwindOperands.size())}));
451 return success();
452 }
453
454 ///===----------------------------------------------------------------------===//
455 /// Verifying/Printing/Parsing for LLVM::LandingpadOp.
456 ///===----------------------------------------------------------------------===//
457
verify(LandingpadOp op)458 static LogicalResult verify(LandingpadOp op) {
459 Value value;
460 if (LLVMFuncOp func = op->getParentOfType<LLVMFuncOp>()) {
461 if (!func.personality().hasValue())
462 return op.emitError(
463 "llvm.landingpad needs to be in a function with a personality");
464 }
465
466 if (!op.cleanup() && op.getOperands().empty())
467 return op.emitError("landingpad instruction expects at least one clause or "
468 "cleanup attribute");
469
470 for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) {
471 value = op.getOperand(idx);
472 bool isFilter = value.getType().cast<LLVMType>().isArrayTy();
473 if (isFilter) {
474 // FIXME: Verify filter clauses when arrays are appropriately handled
475 } else {
476 // catch - global addresses only.
477 // Bitcast ops should have global addresses as their args.
478 if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
479 if (auto addrOp = bcOp.arg().getDefiningOp<AddressOfOp>())
480 continue;
481 return op.emitError("constant clauses expected")
482 .attachNote(bcOp.getLoc())
483 << "global addresses expected as operand to "
484 "bitcast used in clauses for landingpad";
485 }
486 // NullOp and AddressOfOp allowed
487 if (value.getDefiningOp<NullOp>())
488 continue;
489 if (value.getDefiningOp<AddressOfOp>())
490 continue;
491 return op.emitError("clause #")
492 << idx << " is not a known constant - null, addressof, bitcast";
493 }
494 }
495 return success();
496 }
497
printLandingpadOp(OpAsmPrinter & p,LandingpadOp & op)498 static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) {
499 p << op.getOperationName() << (op.cleanup() ? " cleanup " : " ");
500
501 // Clauses
502 for (auto value : op.getOperands()) {
503 // Similar to llvm - if clause is an array type then it is filter
504 // clause else catch clause
505 bool isArrayTy = value.getType().cast<LLVMType>().isArrayTy();
506 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
507 << value.getType() << ") ";
508 }
509
510 p.printOptionalAttrDict(op.getAttrs(), {"cleanup"});
511
512 p << ": " << op.getType();
513 }
514
515 /// <operation> ::= `llvm.landingpad` `cleanup`?
516 /// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
parseLandingpadOp(OpAsmParser & parser,OperationState & result)517 static ParseResult parseLandingpadOp(OpAsmParser &parser,
518 OperationState &result) {
519 // Check for cleanup
520 if (succeeded(parser.parseOptionalKeyword("cleanup")))
521 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
522
523 // Parse clauses with types
524 while (succeeded(parser.parseOptionalLParen()) &&
525 (succeeded(parser.parseOptionalKeyword("filter")) ||
526 succeeded(parser.parseOptionalKeyword("catch")))) {
527 OpAsmParser::OperandType operand;
528 Type ty;
529 if (parser.parseOperand(operand) || parser.parseColon() ||
530 parser.parseType(ty) ||
531 parser.resolveOperand(operand, ty, result.operands) ||
532 parser.parseRParen())
533 return failure();
534 }
535
536 Type type;
537 if (parser.parseColon() || parser.parseType(type))
538 return failure();
539
540 result.addTypes(type);
541 return success();
542 }
543
544 //===----------------------------------------------------------------------===//
545 // Verifying/Printing/parsing for LLVM::CallOp.
546 //===----------------------------------------------------------------------===//
547
verify(CallOp & op)548 static LogicalResult verify(CallOp &op) {
549 if (op.getNumResults() > 1)
550 return op.emitOpError("must have 0 or 1 result");
551
552 // Type for the callee, we'll get it differently depending if it is a direct
553 // or indirect call.
554 LLVMType fnType;
555
556 bool isIndirect = false;
557
558 // If this is an indirect call, the callee attribute is missing.
559 Optional<StringRef> calleeName = op.callee();
560 if (!calleeName) {
561 isIndirect = true;
562 if (!op.getNumOperands())
563 return op.emitOpError(
564 "must have either a `callee` attribute or at least an operand");
565 fnType = op.getOperand(0).getType().dyn_cast<LLVMType>();
566 if (!fnType)
567 return op.emitOpError("indirect call to a non-llvm type: ")
568 << op.getOperand(0).getType();
569 auto ptrType = fnType.dyn_cast<LLVMPointerType>();
570 if (!ptrType)
571 return op.emitOpError("indirect call expects a pointer as callee: ")
572 << fnType;
573 fnType = ptrType.getElementType();
574 } else {
575 Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName);
576 if (!callee)
577 return op.emitOpError()
578 << "'" << *calleeName
579 << "' does not reference a symbol in the current scope";
580 auto fn = dyn_cast<LLVMFuncOp>(callee);
581 if (!fn)
582 return op.emitOpError() << "'" << *calleeName
583 << "' does not reference a valid LLVM function";
584
585 fnType = fn.getType();
586 }
587 if (!fnType.isFunctionTy())
588 return op.emitOpError("callee does not have a functional type: ") << fnType;
589
590 // Verify that the operand and result types match the callee.
591
592 if (!fnType.isFunctionVarArg() &&
593 fnType.getFunctionNumParams() != (op.getNumOperands() - isIndirect))
594 return op.emitOpError()
595 << "incorrect number of operands ("
596 << (op.getNumOperands() - isIndirect)
597 << ") for callee (expecting: " << fnType.getFunctionNumParams()
598 << ")";
599
600 if (fnType.getFunctionNumParams() > (op.getNumOperands() - isIndirect))
601 return op.emitOpError() << "incorrect number of operands ("
602 << (op.getNumOperands() - isIndirect)
603 << ") for varargs callee (expecting at least: "
604 << fnType.getFunctionNumParams() << ")";
605
606 for (unsigned i = 0, e = fnType.getFunctionNumParams(); i != e; ++i)
607 if (op.getOperand(i + isIndirect).getType() !=
608 fnType.getFunctionParamType(i))
609 return op.emitOpError() << "operand type mismatch for operand " << i
610 << ": " << op.getOperand(i + isIndirect).getType()
611 << " != " << fnType.getFunctionParamType(i);
612
613 if (op.getNumResults() &&
614 op.getResult(0).getType() != fnType.getFunctionResultType())
615 return op.emitOpError()
616 << "result type mismatch: " << op.getResult(0).getType()
617 << " != " << fnType.getFunctionResultType();
618
619 return success();
620 }
621
printCallOp(OpAsmPrinter & p,CallOp & op)622 static void printCallOp(OpAsmPrinter &p, CallOp &op) {
623 auto callee = op.callee();
624 bool isDirect = callee.hasValue();
625
626 // Print the direct callee if present as a function attribute, or an indirect
627 // callee (first operand) otherwise.
628 p << op.getOperationName() << ' ';
629 if (isDirect)
630 p.printSymbolName(callee.getValue());
631 else
632 p << op.getOperand(0);
633
634 auto args = op.getOperands().drop_front(isDirect ? 0 : 1);
635 p << '(' << args << ')';
636 p.printOptionalAttrDict(op.getAttrs(), {"callee"});
637
638 // Reconstruct the function MLIR function type from operand and result types.
639 p << " : "
640 << FunctionType::get(args.getTypes(), op.getResultTypes(), op.getContext());
641 }
642
643 // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
644 // attribute-dict? `:` function-type
parseCallOp(OpAsmParser & parser,OperationState & result)645 static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
646 SmallVector<OpAsmParser::OperandType, 8> operands;
647 Type type;
648 SymbolRefAttr funcAttr;
649 llvm::SMLoc trailingTypeLoc;
650
651 // Parse an operand list that will, in practice, contain 0 or 1 operand. In
652 // case of an indirect call, there will be 1 operand before `(`. In case of a
653 // direct call, there will be no operands and the parser will stop at the
654 // function identifier without complaining.
655 if (parser.parseOperandList(operands))
656 return failure();
657 bool isDirect = operands.empty();
658
659 // Optionally parse a function identifier.
660 if (isDirect)
661 if (parser.parseAttribute(funcAttr, "callee", result.attributes))
662 return failure();
663
664 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
665 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
666 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
667 return failure();
668
669 auto funcType = type.dyn_cast<FunctionType>();
670 if (!funcType)
671 return parser.emitError(trailingTypeLoc, "expected function type");
672 if (isDirect) {
673 // Make sure types match.
674 if (parser.resolveOperands(operands, funcType.getInputs(),
675 parser.getNameLoc(), result.operands))
676 return failure();
677 result.addTypes(funcType.getResults());
678 } else {
679 // Construct the LLVM IR Dialect function type that the first operand
680 // should match.
681 if (funcType.getNumResults() > 1)
682 return parser.emitError(trailingTypeLoc,
683 "expected function with 0 or 1 result");
684
685 Builder &builder = parser.getBuilder();
686 LLVM::LLVMType llvmResultType;
687 if (funcType.getNumResults() == 0) {
688 llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext());
689 } else {
690 llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
691 if (!llvmResultType)
692 return parser.emitError(trailingTypeLoc,
693 "expected result to have LLVM type");
694 }
695
696 SmallVector<LLVM::LLVMType, 8> argTypes;
697 argTypes.reserve(funcType.getNumInputs());
698 for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
699 auto argType = funcType.getInput(i).dyn_cast<LLVM::LLVMType>();
700 if (!argType)
701 return parser.emitError(trailingTypeLoc,
702 "expected LLVM types as inputs");
703 argTypes.push_back(argType);
704 }
705 auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
706 /*isVarArg=*/false);
707 auto wrappedFuncType = llvmFuncType.getPointerTo();
708
709 auto funcArguments =
710 ArrayRef<OpAsmParser::OperandType>(operands).drop_front();
711
712 // Make sure that the first operand (indirect callee) matches the wrapped
713 // LLVM IR function type, and that the types of the other call operands
714 // match the types of the function arguments.
715 if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
716 parser.resolveOperands(funcArguments, funcType.getInputs(),
717 parser.getNameLoc(), result.operands))
718 return failure();
719
720 result.addTypes(llvmResultType);
721 }
722
723 return success();
724 }
725
726 //===----------------------------------------------------------------------===//
727 // Printing/parsing for LLVM::ExtractElementOp.
728 //===----------------------------------------------------------------------===//
729 // Expects vector to be of wrapped LLVM vector type and position to be of
730 // wrapped LLVM i32 type.
build(OpBuilder & b,OperationState & result,Value vector,Value position,ArrayRef<NamedAttribute> attrs)731 void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result,
732 Value vector, Value position,
733 ArrayRef<NamedAttribute> attrs) {
734 auto wrappedVectorType = vector.getType().cast<LLVM::LLVMType>();
735 auto llvmType = wrappedVectorType.getVectorElementType();
736 build(b, result, llvmType, vector, position);
737 result.addAttributes(attrs);
738 }
739
printExtractElementOp(OpAsmPrinter & p,ExtractElementOp & op)740 static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) {
741 p << op.getOperationName() << ' ' << op.vector() << "[" << op.position()
742 << " : " << op.position().getType() << "]";
743 p.printOptionalAttrDict(op.getAttrs());
744 p << " : " << op.vector().getType();
745 }
746
747 // <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
748 // attribute-dict? `:` type
parseExtractElementOp(OpAsmParser & parser,OperationState & result)749 static ParseResult parseExtractElementOp(OpAsmParser &parser,
750 OperationState &result) {
751 llvm::SMLoc loc;
752 OpAsmParser::OperandType vector, position;
753 Type type, positionType;
754 if (parser.getCurrentLocation(&loc) || parser.parseOperand(vector) ||
755 parser.parseLSquare() || parser.parseOperand(position) ||
756 parser.parseColonType(positionType) || parser.parseRSquare() ||
757 parser.parseOptionalAttrDict(result.attributes) ||
758 parser.parseColonType(type) ||
759 parser.resolveOperand(vector, type, result.operands) ||
760 parser.resolveOperand(position, positionType, result.operands))
761 return failure();
762 auto wrappedVectorType = type.dyn_cast<LLVM::LLVMType>();
763 if (!wrappedVectorType || !wrappedVectorType.isVectorTy())
764 return parser.emitError(
765 loc, "expected LLVM IR dialect vector type for operand #1");
766 result.addTypes(wrappedVectorType.getVectorElementType());
767 return success();
768 }
769
770 //===----------------------------------------------------------------------===//
771 // Printing/parsing for LLVM::ExtractValueOp.
772 //===----------------------------------------------------------------------===//
773
printExtractValueOp(OpAsmPrinter & p,ExtractValueOp & op)774 static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) {
775 p << op.getOperationName() << ' ' << op.container() << op.position();
776 p.printOptionalAttrDict(op.getAttrs(), {"position"});
777 p << " : " << op.container().getType();
778 }
779
780 // Extract the type at `position` in the wrapped LLVM IR aggregate type
781 // `containerType`. Position is an integer array attribute where each value
782 // is a zero-based position of the element in the aggregate type. Return the
783 // resulting type wrapped in MLIR, or nullptr on error.
getInsertExtractValueElementType(OpAsmParser & parser,Type containerType,ArrayAttr positionAttr,llvm::SMLoc attributeLoc,llvm::SMLoc typeLoc)784 static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser,
785 Type containerType,
786 ArrayAttr positionAttr,
787 llvm::SMLoc attributeLoc,
788 llvm::SMLoc typeLoc) {
789 auto wrappedContainerType = containerType.dyn_cast<LLVM::LLVMType>();
790 if (!wrappedContainerType)
791 return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
792
793 // Infer the element type from the structure type: iteratively step inside the
794 // type by taking the element type, indexed by the position attribute for
795 // structures. Check the position index before accessing, it is supposed to
796 // be in bounds.
797 for (Attribute subAttr : positionAttr) {
798 auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
799 if (!positionElementAttr)
800 return parser.emitError(attributeLoc,
801 "expected an array of integer literals"),
802 nullptr;
803 int position = positionElementAttr.getInt();
804 if (wrappedContainerType.isArrayTy()) {
805 if (position < 0 || static_cast<unsigned>(position) >=
806 wrappedContainerType.getArrayNumElements())
807 return parser.emitError(attributeLoc, "position out of bounds"),
808 nullptr;
809 wrappedContainerType = wrappedContainerType.getArrayElementType();
810 } else if (wrappedContainerType.isStructTy()) {
811 if (position < 0 || static_cast<unsigned>(position) >=
812 wrappedContainerType.getStructNumElements())
813 return parser.emitError(attributeLoc, "position out of bounds"),
814 nullptr;
815 wrappedContainerType =
816 wrappedContainerType.getStructElementType(position);
817 } else {
818 return parser.emitError(typeLoc,
819 "expected wrapped LLVM IR structure/array type"),
820 nullptr;
821 }
822 }
823 return wrappedContainerType;
824 }
825
826 // <operation> ::= `llvm.extractvalue` ssa-use
827 // `[` integer-literal (`,` integer-literal)* `]`
828 // attribute-dict? `:` type
parseExtractValueOp(OpAsmParser & parser,OperationState & result)829 static ParseResult parseExtractValueOp(OpAsmParser &parser,
830 OperationState &result) {
831 OpAsmParser::OperandType container;
832 Type containerType;
833 ArrayAttr positionAttr;
834 llvm::SMLoc attributeLoc, trailingTypeLoc;
835
836 if (parser.parseOperand(container) ||
837 parser.getCurrentLocation(&attributeLoc) ||
838 parser.parseAttribute(positionAttr, "position", result.attributes) ||
839 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
840 parser.getCurrentLocation(&trailingTypeLoc) ||
841 parser.parseType(containerType) ||
842 parser.resolveOperand(container, containerType, result.operands))
843 return failure();
844
845 auto elementType = getInsertExtractValueElementType(
846 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
847 if (!elementType)
848 return failure();
849
850 result.addTypes(elementType);
851 return success();
852 }
853
854 //===----------------------------------------------------------------------===//
855 // Printing/parsing for LLVM::InsertElementOp.
856 //===----------------------------------------------------------------------===//
857
printInsertElementOp(OpAsmPrinter & p,InsertElementOp & op)858 static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) {
859 p << op.getOperationName() << ' ' << op.value() << ", " << op.vector() << "["
860 << op.position() << " : " << op.position().getType() << "]";
861 p.printOptionalAttrDict(op.getAttrs());
862 p << " : " << op.vector().getType();
863 }
864
865 // <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
866 // attribute-dict? `:` type
parseInsertElementOp(OpAsmParser & parser,OperationState & result)867 static ParseResult parseInsertElementOp(OpAsmParser &parser,
868 OperationState &result) {
869 llvm::SMLoc loc;
870 OpAsmParser::OperandType vector, value, position;
871 Type vectorType, positionType;
872 if (parser.getCurrentLocation(&loc) || parser.parseOperand(value) ||
873 parser.parseComma() || parser.parseOperand(vector) ||
874 parser.parseLSquare() || parser.parseOperand(position) ||
875 parser.parseColonType(positionType) || parser.parseRSquare() ||
876 parser.parseOptionalAttrDict(result.attributes) ||
877 parser.parseColonType(vectorType))
878 return failure();
879
880 auto wrappedVectorType = vectorType.dyn_cast<LLVM::LLVMType>();
881 if (!wrappedVectorType || !wrappedVectorType.isVectorTy())
882 return parser.emitError(
883 loc, "expected LLVM IR dialect vector type for operand #1");
884 auto valueType = wrappedVectorType.getVectorElementType();
885 if (!valueType)
886 return failure();
887
888 if (parser.resolveOperand(vector, vectorType, result.operands) ||
889 parser.resolveOperand(value, valueType, result.operands) ||
890 parser.resolveOperand(position, positionType, result.operands))
891 return failure();
892
893 result.addTypes(vectorType);
894 return success();
895 }
896
897 //===----------------------------------------------------------------------===//
898 // Printing/parsing for LLVM::InsertValueOp.
899 //===----------------------------------------------------------------------===//
900
printInsertValueOp(OpAsmPrinter & p,InsertValueOp & op)901 static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) {
902 p << op.getOperationName() << ' ' << op.value() << ", " << op.container()
903 << op.position();
904 p.printOptionalAttrDict(op.getAttrs(), {"position"});
905 p << " : " << op.container().getType();
906 }
907
908 // <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use
909 // `[` integer-literal (`,` integer-literal)* `]`
910 // attribute-dict? `:` type
parseInsertValueOp(OpAsmParser & parser,OperationState & result)911 static ParseResult parseInsertValueOp(OpAsmParser &parser,
912 OperationState &result) {
913 OpAsmParser::OperandType container, value;
914 Type containerType;
915 ArrayAttr positionAttr;
916 llvm::SMLoc attributeLoc, trailingTypeLoc;
917
918 if (parser.parseOperand(value) || parser.parseComma() ||
919 parser.parseOperand(container) ||
920 parser.getCurrentLocation(&attributeLoc) ||
921 parser.parseAttribute(positionAttr, "position", result.attributes) ||
922 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
923 parser.getCurrentLocation(&trailingTypeLoc) ||
924 parser.parseType(containerType))
925 return failure();
926
927 auto valueType = getInsertExtractValueElementType(
928 parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
929 if (!valueType)
930 return failure();
931
932 if (parser.resolveOperand(container, containerType, result.operands) ||
933 parser.resolveOperand(value, valueType, result.operands))
934 return failure();
935
936 result.addTypes(containerType);
937 return success();
938 }
939
940 //===----------------------------------------------------------------------===//
941 // Printing/parsing for LLVM::ReturnOp.
942 //===----------------------------------------------------------------------===//
943
printReturnOp(OpAsmPrinter & p,ReturnOp & op)944 static void printReturnOp(OpAsmPrinter &p, ReturnOp &op) {
945 p << op.getOperationName();
946 p.printOptionalAttrDict(op.getAttrs());
947 assert(op.getNumOperands() <= 1);
948
949 if (op.getNumOperands() == 0)
950 return;
951
952 p << ' ' << op.getOperand(0) << " : " << op.getOperand(0).getType();
953 }
954
955 // <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:`
956 // type-list-no-parens
parseReturnOp(OpAsmParser & parser,OperationState & result)957 static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
958 SmallVector<OpAsmParser::OperandType, 1> operands;
959 Type type;
960
961 if (parser.parseOperandList(operands) ||
962 parser.parseOptionalAttrDict(result.attributes))
963 return failure();
964 if (operands.empty())
965 return success();
966
967 if (parser.parseColonType(type) ||
968 parser.resolveOperand(operands[0], type, result.operands))
969 return failure();
970 return success();
971 }
972
973 //===----------------------------------------------------------------------===//
974 // Verifier for LLVM::AddressOfOp.
975 //===----------------------------------------------------------------------===//
976
977 template <typename OpTy>
lookupSymbolInModule(Operation * parent,StringRef name)978 static OpTy lookupSymbolInModule(Operation *parent, StringRef name) {
979 Operation *module = parent;
980 while (module && !satisfiesLLVMModule(module))
981 module = module->getParentOp();
982 assert(module && "unexpected operation outside of a module");
983 return dyn_cast_or_null<OpTy>(
984 mlir::SymbolTable::lookupSymbolIn(module, name));
985 }
986
getGlobal()987 GlobalOp AddressOfOp::getGlobal() {
988 return lookupSymbolInModule<LLVM::GlobalOp>((*this)->getParentOp(),
989 global_name());
990 }
991
getFunction()992 LLVMFuncOp AddressOfOp::getFunction() {
993 return lookupSymbolInModule<LLVM::LLVMFuncOp>((*this)->getParentOp(),
994 global_name());
995 }
996
verify(AddressOfOp op)997 static LogicalResult verify(AddressOfOp op) {
998 auto global = op.getGlobal();
999 auto function = op.getFunction();
1000 if (!global && !function)
1001 return op.emitOpError(
1002 "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
1003
1004 if (global && global.getType().getPointerTo(global.addr_space()) !=
1005 op.getResult().getType())
1006 return op.emitOpError(
1007 "the type must be a pointer to the type of the referenced global");
1008
1009 if (function && function.getType().getPointerTo() != op.getResult().getType())
1010 return op.emitOpError(
1011 "the type must be a pointer to the type of the referenced function");
1012
1013 return success();
1014 }
1015
1016 //===----------------------------------------------------------------------===//
1017 // Builder, printer and verifier for LLVM::GlobalOp.
1018 //===----------------------------------------------------------------------===//
1019
1020 /// Returns the name used for the linkage attribute. This *must* correspond to
1021 /// the name of the attribute in ODS.
getLinkageAttrName()1022 static StringRef getLinkageAttrName() { return "linkage"; }
1023
build(OpBuilder & builder,OperationState & result,LLVMType type,bool isConstant,Linkage linkage,StringRef name,Attribute value,unsigned addrSpace,ArrayRef<NamedAttribute> attrs)1024 void GlobalOp::build(OpBuilder &builder, OperationState &result, LLVMType type,
1025 bool isConstant, Linkage linkage, StringRef name,
1026 Attribute value, unsigned addrSpace,
1027 ArrayRef<NamedAttribute> attrs) {
1028 result.addAttribute(SymbolTable::getSymbolAttrName(),
1029 builder.getStringAttr(name));
1030 result.addAttribute("type", TypeAttr::get(type));
1031 if (isConstant)
1032 result.addAttribute("constant", builder.getUnitAttr());
1033 if (value)
1034 result.addAttribute("value", value);
1035 result.addAttribute(getLinkageAttrName(),
1036 builder.getI64IntegerAttr(static_cast<int64_t>(linkage)));
1037 if (addrSpace != 0)
1038 result.addAttribute("addr_space", builder.getI32IntegerAttr(addrSpace));
1039 result.attributes.append(attrs.begin(), attrs.end());
1040 result.addRegion();
1041 }
1042
printGlobalOp(OpAsmPrinter & p,GlobalOp op)1043 static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
1044 p << op.getOperationName() << ' ' << stringifyLinkage(op.linkage()) << ' ';
1045 if (op.constant())
1046 p << "constant ";
1047 p.printSymbolName(op.sym_name());
1048 p << '(';
1049 if (auto value = op.getValueOrNull())
1050 p.printAttribute(value);
1051 p << ')';
1052 p.printOptionalAttrDict(op.getAttrs(),
1053 {SymbolTable::getSymbolAttrName(), "type", "constant",
1054 "value", getLinkageAttrName()});
1055
1056 // Print the trailing type unless it's a string global.
1057 if (op.getValueOrNull().dyn_cast_or_null<StringAttr>())
1058 return;
1059 p << " : " << op.type();
1060
1061 Region &initializer = op.getInitializerRegion();
1062 if (!initializer.empty())
1063 p.printRegion(initializer, /*printEntryBlockArgs=*/false);
1064 }
1065
1066 //===----------------------------------------------------------------------===//
1067 // Verifier for LLVM::DialectCastOp.
1068 //===----------------------------------------------------------------------===//
1069
verify(DialectCastOp op)1070 static LogicalResult verify(DialectCastOp op) {
1071 auto verifyMLIRCastType = [&op](Type type) -> LogicalResult {
1072 if (auto llvmType = type.dyn_cast<LLVM::LLVMType>()) {
1073 if (llvmType.isVectorTy())
1074 llvmType = llvmType.getVectorElementType();
1075 if (llvmType.isIntegerTy() || llvmType.isBFloatTy() ||
1076 llvmType.isHalfTy() || llvmType.isFloatTy() ||
1077 llvmType.isDoubleTy()) {
1078 return success();
1079 }
1080 return op.emitOpError("type must be non-index integer types, float "
1081 "types, or vector of mentioned types.");
1082 }
1083 if (auto vectorType = type.dyn_cast<VectorType>()) {
1084 if (vectorType.getShape().size() > 1)
1085 return op.emitOpError("only 1-d vector is allowed");
1086 type = vectorType.getElementType();
1087 }
1088 if (type.isSignlessIntOrFloat())
1089 return success();
1090 // Note that memrefs are not supported. We currently don't have a use case
1091 // for it, but even if we do, there are challenges:
1092 // * if we allow memrefs to cast from/to memref descriptors, then the
1093 // semantics of the cast op depends on the implementation detail of the
1094 // descriptor.
1095 // * if we allow memrefs to cast from/to bare pointers, some users might
1096 // alternatively want metadata that only present in the descriptor.
1097 //
1098 // TODO: re-evaluate the memref cast design when it's needed.
1099 return op.emitOpError("type must be non-index integer types, float types, "
1100 "or vector of mentioned types.");
1101 };
1102 return failure(failed(verifyMLIRCastType(op.in().getType())) ||
1103 failed(verifyMLIRCastType(op.getType())));
1104 }
1105
1106 // Parses one of the keywords provided in the list `keywords` and returns the
1107 // position of the parsed keyword in the list. If none of the keywords from the
1108 // list is parsed, returns -1.
parseOptionalKeywordAlternative(OpAsmParser & parser,ArrayRef<StringRef> keywords)1109 static int parseOptionalKeywordAlternative(OpAsmParser &parser,
1110 ArrayRef<StringRef> keywords) {
1111 for (auto en : llvm::enumerate(keywords)) {
1112 if (succeeded(parser.parseOptionalKeyword(en.value())))
1113 return en.index();
1114 }
1115 return -1;
1116 }
1117
1118 namespace {
1119 template <typename Ty> struct EnumTraits {};
1120
1121 #define REGISTER_ENUM_TYPE(Ty) \
1122 template <> struct EnumTraits<Ty> { \
1123 static StringRef stringify(Ty value) { return stringify##Ty(value); } \
1124 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \
1125 }
1126
1127 REGISTER_ENUM_TYPE(Linkage);
1128 } // end namespace
1129
1130 template <typename EnumTy>
parseOptionalLLVMKeyword(OpAsmParser & parser,OperationState & result,StringRef name)1131 static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser,
1132 OperationState &result,
1133 StringRef name) {
1134 SmallVector<StringRef, 10> names;
1135 for (unsigned i = 0, e = getMaxEnumValForLinkage(); i <= e; ++i)
1136 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
1137
1138 int index = parseOptionalKeywordAlternative(parser, names);
1139 if (index == -1)
1140 return failure();
1141 result.addAttribute(name, parser.getBuilder().getI64IntegerAttr(index));
1142 return success();
1143 }
1144
1145 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier
1146 // `(` attribute? `)` attribute-list? (`:` type)? region?
1147 //
1148 // The type can be omitted for string attributes, in which case it will be
1149 // inferred from the value of the string as [strlen(value) x i8].
parseGlobalOp(OpAsmParser & parser,OperationState & result)1150 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
1151 if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
1152 getLinkageAttrName())))
1153 result.addAttribute(getLinkageAttrName(),
1154 parser.getBuilder().getI64IntegerAttr(
1155 static_cast<int64_t>(LLVM::Linkage::External)));
1156
1157 if (succeeded(parser.parseOptionalKeyword("constant")))
1158 result.addAttribute("constant", parser.getBuilder().getUnitAttr());
1159
1160 StringAttr name;
1161 if (parser.parseSymbolName(name, SymbolTable::getSymbolAttrName(),
1162 result.attributes) ||
1163 parser.parseLParen())
1164 return failure();
1165
1166 Attribute value;
1167 if (parser.parseOptionalRParen()) {
1168 if (parser.parseAttribute(value, "value", result.attributes) ||
1169 parser.parseRParen())
1170 return failure();
1171 }
1172
1173 SmallVector<Type, 1> types;
1174 if (parser.parseOptionalAttrDict(result.attributes) ||
1175 parser.parseOptionalColonTypeList(types))
1176 return failure();
1177
1178 if (types.size() > 1)
1179 return parser.emitError(parser.getNameLoc(), "expected zero or one type");
1180
1181 Region &initRegion = *result.addRegion();
1182 if (types.empty()) {
1183 if (auto strAttr = value.dyn_cast_or_null<StringAttr>()) {
1184 MLIRContext *context = parser.getBuilder().getContext();
1185 auto arrayType = LLVM::LLVMType::getArrayTy(
1186 LLVM::LLVMType::getInt8Ty(context), strAttr.getValue().size());
1187 types.push_back(arrayType);
1188 } else {
1189 return parser.emitError(parser.getNameLoc(),
1190 "type can only be omitted for string globals");
1191 }
1192 } else {
1193 OptionalParseResult parseResult =
1194 parser.parseOptionalRegion(initRegion, /*arguments=*/{},
1195 /*argTypes=*/{});
1196 if (parseResult.hasValue() && failed(*parseResult))
1197 return failure();
1198 }
1199
1200 result.addAttribute("type", TypeAttr::get(types[0]));
1201 return success();
1202 }
1203
verify(GlobalOp op)1204 static LogicalResult verify(GlobalOp op) {
1205 if (!LLVMPointerType::isValidElementType(op.getType()))
1206 return op.emitOpError(
1207 "expects type to be a valid element type for an LLVM pointer");
1208 if (op->getParentOp() && !satisfiesLLVMModule(op->getParentOp()))
1209 return op.emitOpError("must appear at the module level");
1210
1211 if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
1212 auto type = op.getType();
1213 if (!type.isArrayTy() || !type.getArrayElementType().isIntegerTy(8) ||
1214 type.getArrayNumElements() != strAttr.getValue().size())
1215 return op.emitOpError(
1216 "requires an i8 array type of the length equal to that of the string "
1217 "attribute");
1218 }
1219
1220 if (Block *b = op.getInitializerBlock()) {
1221 ReturnOp ret = cast<ReturnOp>(b->getTerminator());
1222 if (ret.operand_type_begin() == ret.operand_type_end())
1223 return op.emitOpError("initializer region cannot return void");
1224 if (*ret.operand_type_begin() != op.getType())
1225 return op.emitOpError("initializer region type ")
1226 << *ret.operand_type_begin() << " does not match global type "
1227 << op.getType();
1228
1229 if (op.getValueOrNull())
1230 return op.emitOpError("cannot have both initializer value and region");
1231 }
1232 return success();
1233 }
1234
1235 //===----------------------------------------------------------------------===//
1236 // Printing/parsing for LLVM::ShuffleVectorOp.
1237 //===----------------------------------------------------------------------===//
1238 // Expects vector to be of wrapped LLVM vector type and position to be of
1239 // wrapped LLVM i32 type.
build(OpBuilder & b,OperationState & result,Value v1,Value v2,ArrayAttr mask,ArrayRef<NamedAttribute> attrs)1240 void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result,
1241 Value v1, Value v2, ArrayAttr mask,
1242 ArrayRef<NamedAttribute> attrs) {
1243 auto wrappedContainerType1 = v1.getType().cast<LLVM::LLVMType>();
1244 auto vType = LLVMType::getVectorTy(
1245 wrappedContainerType1.getVectorElementType(), mask.size());
1246 build(b, result, vType, v1, v2, mask);
1247 result.addAttributes(attrs);
1248 }
1249
printShuffleVectorOp(OpAsmPrinter & p,ShuffleVectorOp & op)1250 static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) {
1251 p << op.getOperationName() << ' ' << op.v1() << ", " << op.v2() << " "
1252 << op.mask();
1253 p.printOptionalAttrDict(op.getAttrs(), {"mask"});
1254 p << " : " << op.v1().getType() << ", " << op.v2().getType();
1255 }
1256
1257 // <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use
1258 // `[` integer-literal (`,` integer-literal)* `]`
1259 // attribute-dict? `:` type
parseShuffleVectorOp(OpAsmParser & parser,OperationState & result)1260 static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
1261 OperationState &result) {
1262 llvm::SMLoc loc;
1263 OpAsmParser::OperandType v1, v2;
1264 ArrayAttr maskAttr;
1265 Type typeV1, typeV2;
1266 if (parser.getCurrentLocation(&loc) || parser.parseOperand(v1) ||
1267 parser.parseComma() || parser.parseOperand(v2) ||
1268 parser.parseAttribute(maskAttr, "mask", result.attributes) ||
1269 parser.parseOptionalAttrDict(result.attributes) ||
1270 parser.parseColonType(typeV1) || parser.parseComma() ||
1271 parser.parseType(typeV2) ||
1272 parser.resolveOperand(v1, typeV1, result.operands) ||
1273 parser.resolveOperand(v2, typeV2, result.operands))
1274 return failure();
1275 auto wrappedContainerType1 = typeV1.dyn_cast<LLVM::LLVMType>();
1276 if (!wrappedContainerType1 || !wrappedContainerType1.isVectorTy())
1277 return parser.emitError(
1278 loc, "expected LLVM IR dialect vector type for operand #1");
1279 auto vType = LLVMType::getVectorTy(
1280 wrappedContainerType1.getVectorElementType(), maskAttr.size());
1281 result.addTypes(vType);
1282 return success();
1283 }
1284
1285 //===----------------------------------------------------------------------===//
1286 // Implementations for LLVM::LLVMFuncOp.
1287 //===----------------------------------------------------------------------===//
1288
1289 // Add the entry block to the function.
addEntryBlock()1290 Block *LLVMFuncOp::addEntryBlock() {
1291 assert(empty() && "function already has an entry block");
1292 assert(!isVarArg() && "unimplemented: non-external variadic functions");
1293
1294 auto *entry = new Block;
1295 push_back(entry);
1296
1297 LLVMType type = getType();
1298 for (unsigned i = 0, e = type.getFunctionNumParams(); i < e; ++i)
1299 entry->addArgument(type.getFunctionParamType(i));
1300 return entry;
1301 }
1302
build(OpBuilder & builder,OperationState & result,StringRef name,LLVMType type,LLVM::Linkage linkage,ArrayRef<NamedAttribute> attrs,ArrayRef<MutableDictionaryAttr> argAttrs)1303 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
1304 StringRef name, LLVMType type, LLVM::Linkage linkage,
1305 ArrayRef<NamedAttribute> attrs,
1306 ArrayRef<MutableDictionaryAttr> argAttrs) {
1307 result.addRegion();
1308 result.addAttribute(SymbolTable::getSymbolAttrName(),
1309 builder.getStringAttr(name));
1310 result.addAttribute("type", TypeAttr::get(type));
1311 result.addAttribute(getLinkageAttrName(),
1312 builder.getI64IntegerAttr(static_cast<int64_t>(linkage)));
1313 result.attributes.append(attrs.begin(), attrs.end());
1314 if (argAttrs.empty())
1315 return;
1316
1317 unsigned numInputs = type.getFunctionNumParams();
1318 assert(numInputs == argAttrs.size() &&
1319 "expected as many argument attribute lists as arguments");
1320 SmallString<8> argAttrName;
1321 for (unsigned i = 0; i < numInputs; ++i)
1322 if (auto argDict = argAttrs[i].getDictionary(builder.getContext()))
1323 result.addAttribute(getArgAttrName(i, argAttrName), argDict);
1324 }
1325
1326 // Builds an LLVM function type from the given lists of input and output types.
1327 // Returns a null type if any of the types provided are non-LLVM types, or if
1328 // there is more than one output type.
buildLLVMFunctionType(OpAsmParser & parser,llvm::SMLoc loc,ArrayRef<Type> inputs,ArrayRef<Type> outputs,impl::VariadicFlag variadicFlag)1329 static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
1330 ArrayRef<Type> inputs, ArrayRef<Type> outputs,
1331 impl::VariadicFlag variadicFlag) {
1332 Builder &b = parser.getBuilder();
1333 if (outputs.size() > 1) {
1334 parser.emitError(loc, "failed to construct function type: expected zero or "
1335 "one function result");
1336 return {};
1337 }
1338
1339 // Convert inputs to LLVM types, exit early on error.
1340 SmallVector<LLVMType, 4> llvmInputs;
1341 for (auto t : inputs) {
1342 auto llvmTy = t.dyn_cast<LLVMType>();
1343 if (!llvmTy) {
1344 parser.emitError(loc, "failed to construct function type: expected LLVM "
1345 "type for function arguments");
1346 return {};
1347 }
1348 llvmInputs.push_back(llvmTy);
1349 }
1350
1351 // No output is denoted as "void" in LLVM type system.
1352 LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(b.getContext())
1353 : outputs.front().dyn_cast<LLVMType>();
1354 if (!llvmOutput) {
1355 parser.emitError(loc, "failed to construct function type: expected LLVM "
1356 "type for function results");
1357 return {};
1358 }
1359 return LLVMType::getFunctionTy(llvmOutput, llvmInputs,
1360 variadicFlag.isVariadic());
1361 }
1362
1363 // Parses an LLVM function.
1364 //
1365 // operation ::= `llvm.func` linkage? function-signature function-attributes?
1366 // function-body
1367 //
parseLLVMFuncOp(OpAsmParser & parser,OperationState & result)1368 static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
1369 OperationState &result) {
1370 // Default to external linkage if no keyword is provided.
1371 if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
1372 getLinkageAttrName())))
1373 result.addAttribute(getLinkageAttrName(),
1374 parser.getBuilder().getI64IntegerAttr(
1375 static_cast<int64_t>(LLVM::Linkage::External)));
1376
1377 StringAttr nameAttr;
1378 SmallVector<OpAsmParser::OperandType, 8> entryArgs;
1379 SmallVector<NamedAttrList, 1> argAttrs;
1380 SmallVector<NamedAttrList, 1> resultAttrs;
1381 SmallVector<Type, 8> argTypes;
1382 SmallVector<Type, 4> resultTypes;
1383 bool isVariadic;
1384
1385 auto signatureLocation = parser.getCurrentLocation();
1386 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1387 result.attributes) ||
1388 impl::parseFunctionSignature(parser, /*allowVariadic=*/true, entryArgs,
1389 argTypes, argAttrs, isVariadic, resultTypes,
1390 resultAttrs))
1391 return failure();
1392
1393 auto type =
1394 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
1395 impl::VariadicFlag(isVariadic));
1396 if (!type)
1397 return failure();
1398 result.addAttribute(impl::getTypeAttrName(), TypeAttr::get(type));
1399
1400 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
1401 return failure();
1402 impl::addArgAndResultAttrs(parser.getBuilder(), result, argAttrs,
1403 resultAttrs);
1404
1405 auto *body = result.addRegion();
1406 OptionalParseResult parseResult = parser.parseOptionalRegion(
1407 *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
1408 return failure(parseResult.hasValue() && failed(*parseResult));
1409 }
1410
1411 // Print the LLVMFuncOp. Collects argument and result types and passes them to
1412 // helper functions. Drops "void" result since it cannot be parsed back. Skips
1413 // the external linkage since it is the default value.
printLLVMFuncOp(OpAsmPrinter & p,LLVMFuncOp op)1414 static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
1415 p << op.getOperationName() << ' ';
1416 if (op.linkage() != LLVM::Linkage::External)
1417 p << stringifyLinkage(op.linkage()) << ' ';
1418 p.printSymbolName(op.getName());
1419
1420 LLVMType fnType = op.getType();
1421 SmallVector<Type, 8> argTypes;
1422 SmallVector<Type, 1> resTypes;
1423 argTypes.reserve(fnType.getFunctionNumParams());
1424 for (unsigned i = 0, e = fnType.getFunctionNumParams(); i < e; ++i)
1425 argTypes.push_back(fnType.getFunctionParamType(i));
1426
1427 LLVMType returnType = fnType.getFunctionResultType();
1428 if (!returnType.isVoidTy())
1429 resTypes.push_back(returnType);
1430
1431 impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes);
1432 impl::printFunctionAttributes(p, op, argTypes.size(), resTypes.size(),
1433 {getLinkageAttrName()});
1434
1435 // Print the body if this is not an external function.
1436 Region &body = op.body();
1437 if (!body.empty())
1438 p.printRegion(body, /*printEntryBlockArgs=*/false,
1439 /*printBlockTerminators=*/true);
1440 }
1441
1442 // Hook for OpTrait::FunctionLike, called after verifying that the 'type'
1443 // attribute is present. This can check for preconditions of the
1444 // getNumArguments hook not failing.
verifyType()1445 LogicalResult LLVMFuncOp::verifyType() {
1446 auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMType>();
1447 if (!llvmType || !llvmType.isFunctionTy())
1448 return emitOpError("requires '" + getTypeAttrName() +
1449 "' attribute of wrapped LLVM function type");
1450
1451 return success();
1452 }
1453
1454 // Hook for OpTrait::FunctionLike, returns the number of function arguments.
1455 // Depends on the type attribute being correct as checked by verifyType
getNumFuncArguments()1456 unsigned LLVMFuncOp::getNumFuncArguments() {
1457 return getType().getFunctionNumParams();
1458 }
1459
1460 // Hook for OpTrait::FunctionLike, returns the number of function results.
1461 // Depends on the type attribute being correct as checked by verifyType
getNumFuncResults()1462 unsigned LLVMFuncOp::getNumFuncResults() {
1463 // We model LLVM functions that return void as having zero results,
1464 // and all others as having one result.
1465 // If we modeled a void return as one result, then it would be possible to
1466 // attach an MLIR result attribute to it, and it isn't clear what semantics we
1467 // would assign to that.
1468 if (getType().getFunctionResultType().isVoidTy())
1469 return 0;
1470 return 1;
1471 }
1472
1473 // Verifies LLVM- and implementation-specific properties of the LLVM func Op:
1474 // - functions don't have 'common' linkage
1475 // - external functions have 'external' or 'extern_weak' linkage;
1476 // - vararg is (currently) only supported for external functions;
1477 // - entry block arguments are of LLVM types and match the function signature.
verify(LLVMFuncOp op)1478 static LogicalResult verify(LLVMFuncOp op) {
1479 if (op.linkage() == LLVM::Linkage::Common)
1480 return op.emitOpError()
1481 << "functions cannot have '"
1482 << stringifyLinkage(LLVM::Linkage::Common) << "' linkage";
1483
1484 if (op.isExternal()) {
1485 if (op.linkage() != LLVM::Linkage::External &&
1486 op.linkage() != LLVM::Linkage::ExternWeak)
1487 return op.emitOpError()
1488 << "external functions must have '"
1489 << stringifyLinkage(LLVM::Linkage::External) << "' or '"
1490 << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage";
1491 return success();
1492 }
1493
1494 if (op.isVarArg())
1495 return op.emitOpError("only external functions can be variadic");
1496
1497 unsigned numArguments = op.getType().getFunctionNumParams();
1498 Block &entryBlock = op.front();
1499 for (unsigned i = 0; i < numArguments; ++i) {
1500 Type argType = entryBlock.getArgument(i).getType();
1501 auto argLLVMType = argType.dyn_cast<LLVMType>();
1502 if (!argLLVMType)
1503 return op.emitOpError("entry block argument #")
1504 << i << " is not of LLVM type";
1505 if (op.getType().getFunctionParamType(i) != argLLVMType)
1506 return op.emitOpError("the type of entry block argument #")
1507 << i << " does not match the function signature";
1508 }
1509
1510 return success();
1511 }
1512
1513 //===----------------------------------------------------------------------===//
1514 // Verification for LLVM::NullOp.
1515 //===----------------------------------------------------------------------===//
1516
1517 // Only LLVM pointer types are supported.
verify(LLVM::NullOp op)1518 static LogicalResult verify(LLVM::NullOp op) {
1519 auto llvmType = op.getType().dyn_cast<LLVM::LLVMType>();
1520 if (!llvmType || !llvmType.isPointerTy())
1521 return op.emitOpError("expected LLVM IR pointer type");
1522 return success();
1523 }
1524
1525 //===----------------------------------------------------------------------===//
1526 // Verification for LLVM::ConstantOp.
1527 //===----------------------------------------------------------------------===//
1528
verify(LLVM::ConstantOp op)1529 static LogicalResult verify(LLVM::ConstantOp op) {
1530 if (!(op.value().isa<IntegerAttr>() || op.value().isa<FloatAttr>() ||
1531 op.value().isa<ElementsAttr>() || op.value().isa<StringAttr>()))
1532 return op.emitOpError()
1533 << "only supports integer, float, string or elements attributes";
1534 return success();
1535 }
1536
1537 //===----------------------------------------------------------------------===//
1538 // Utility functions for parsing atomic ops
1539 //===----------------------------------------------------------------------===//
1540
1541 // Helper function to parse a keyword into the specified attribute named by
1542 // `attrName`. The keyword must match one of the string values defined by the
1543 // AtomicBinOp enum. The resulting I64 attribute is added to the `result`
1544 // state.
parseAtomicBinOp(OpAsmParser & parser,OperationState & result,StringRef attrName)1545 static ParseResult parseAtomicBinOp(OpAsmParser &parser, OperationState &result,
1546 StringRef attrName) {
1547 llvm::SMLoc loc;
1548 StringRef keyword;
1549 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&keyword))
1550 return failure();
1551
1552 // Replace the keyword `keyword` with an integer attribute.
1553 auto kind = symbolizeAtomicBinOp(keyword);
1554 if (!kind) {
1555 return parser.emitError(loc)
1556 << "'" << keyword << "' is an incorrect value of the '" << attrName
1557 << "' attribute";
1558 }
1559
1560 auto value = static_cast<int64_t>(kind.getValue());
1561 auto attr = parser.getBuilder().getI64IntegerAttr(value);
1562 result.addAttribute(attrName, attr);
1563
1564 return success();
1565 }
1566
1567 // Helper function to parse a keyword into the specified attribute named by
1568 // `attrName`. The keyword must match one of the string values defined by the
1569 // AtomicOrdering enum. The resulting I64 attribute is added to the `result`
1570 // state.
parseAtomicOrdering(OpAsmParser & parser,OperationState & result,StringRef attrName)1571 static ParseResult parseAtomicOrdering(OpAsmParser &parser,
1572 OperationState &result,
1573 StringRef attrName) {
1574 llvm::SMLoc loc;
1575 StringRef ordering;
1576 if (parser.getCurrentLocation(&loc) || parser.parseKeyword(&ordering))
1577 return failure();
1578
1579 // Replace the keyword `ordering` with an integer attribute.
1580 auto kind = symbolizeAtomicOrdering(ordering);
1581 if (!kind) {
1582 return parser.emitError(loc)
1583 << "'" << ordering << "' is an incorrect value of the '" << attrName
1584 << "' attribute";
1585 }
1586
1587 auto value = static_cast<int64_t>(kind.getValue());
1588 auto attr = parser.getBuilder().getI64IntegerAttr(value);
1589 result.addAttribute(attrName, attr);
1590
1591 return success();
1592 }
1593
1594 //===----------------------------------------------------------------------===//
1595 // Printer, parser and verifier for LLVM::AtomicRMWOp.
1596 //===----------------------------------------------------------------------===//
1597
printAtomicRMWOp(OpAsmPrinter & p,AtomicRMWOp & op)1598 static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) {
1599 p << op.getOperationName() << ' ' << stringifyAtomicBinOp(op.bin_op()) << ' '
1600 << op.ptr() << ", " << op.val() << ' '
1601 << stringifyAtomicOrdering(op.ordering()) << ' ';
1602 p.printOptionalAttrDict(op.getAttrs(), {"bin_op", "ordering"});
1603 p << " : " << op.res().getType();
1604 }
1605
1606 // <operation> ::= `llvm.atomicrmw` keyword ssa-use `,` ssa-use keyword
1607 // attribute-dict? `:` type
parseAtomicRMWOp(OpAsmParser & parser,OperationState & result)1608 static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
1609 OperationState &result) {
1610 LLVMType type;
1611 OpAsmParser::OperandType ptr, val;
1612 if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) ||
1613 parser.parseComma() || parser.parseOperand(val) ||
1614 parseAtomicOrdering(parser, result, "ordering") ||
1615 parser.parseOptionalAttrDict(result.attributes) ||
1616 parser.parseColonType(type) ||
1617 parser.resolveOperand(ptr, type.getPointerTo(), result.operands) ||
1618 parser.resolveOperand(val, type, result.operands))
1619 return failure();
1620
1621 result.addTypes(type);
1622 return success();
1623 }
1624
verify(AtomicRMWOp op)1625 static LogicalResult verify(AtomicRMWOp op) {
1626 auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
1627 auto valType = op.val().getType().cast<LLVM::LLVMType>();
1628 if (valType != ptrType.getPointerElementTy())
1629 return op.emitOpError("expected LLVM IR element type for operand #0 to "
1630 "match type for operand #1");
1631 auto resType = op.res().getType().cast<LLVM::LLVMType>();
1632 if (resType != valType)
1633 return op.emitOpError(
1634 "expected LLVM IR result type to match type for operand #1");
1635 if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) {
1636 if (!valType.isFloatingPointTy())
1637 return op.emitOpError("expected LLVM IR floating point type");
1638 } else if (op.bin_op() == AtomicBinOp::xchg) {
1639 if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
1640 !valType.isIntegerTy(32) && !valType.isIntegerTy(64) &&
1641 !valType.isBFloatTy() && !valType.isHalfTy() && !valType.isFloatTy() &&
1642 !valType.isDoubleTy())
1643 return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
1644 } else {
1645 if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
1646 !valType.isIntegerTy(32) && !valType.isIntegerTy(64))
1647 return op.emitOpError("expected LLVM IR integer type");
1648 }
1649 return success();
1650 }
1651
1652 //===----------------------------------------------------------------------===//
1653 // Printer, parser and verifier for LLVM::AtomicCmpXchgOp.
1654 //===----------------------------------------------------------------------===//
1655
printAtomicCmpXchgOp(OpAsmPrinter & p,AtomicCmpXchgOp & op)1656 static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) {
1657 p << op.getOperationName() << ' ' << op.ptr() << ", " << op.cmp() << ", "
1658 << op.val() << ' ' << stringifyAtomicOrdering(op.success_ordering()) << ' '
1659 << stringifyAtomicOrdering(op.failure_ordering());
1660 p.printOptionalAttrDict(op.getAttrs(),
1661 {"success_ordering", "failure_ordering"});
1662 p << " : " << op.val().getType();
1663 }
1664
1665 // <operation> ::= `llvm.cmpxchg` ssa-use `,` ssa-use `,` ssa-use
1666 // keyword keyword attribute-dict? `:` type
parseAtomicCmpXchgOp(OpAsmParser & parser,OperationState & result)1667 static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
1668 OperationState &result) {
1669 auto &builder = parser.getBuilder();
1670 LLVMType type;
1671 OpAsmParser::OperandType ptr, cmp, val;
1672 if (parser.parseOperand(ptr) || parser.parseComma() ||
1673 parser.parseOperand(cmp) || parser.parseComma() ||
1674 parser.parseOperand(val) ||
1675 parseAtomicOrdering(parser, result, "success_ordering") ||
1676 parseAtomicOrdering(parser, result, "failure_ordering") ||
1677 parser.parseOptionalAttrDict(result.attributes) ||
1678 parser.parseColonType(type) ||
1679 parser.resolveOperand(ptr, type.getPointerTo(), result.operands) ||
1680 parser.resolveOperand(cmp, type, result.operands) ||
1681 parser.resolveOperand(val, type, result.operands))
1682 return failure();
1683
1684 auto boolType = LLVMType::getInt1Ty(builder.getContext());
1685 auto resultType = LLVMType::getStructTy(type, boolType);
1686 result.addTypes(resultType);
1687
1688 return success();
1689 }
1690
verify(AtomicCmpXchgOp op)1691 static LogicalResult verify(AtomicCmpXchgOp op) {
1692 auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
1693 if (!ptrType.isPointerTy())
1694 return op.emitOpError("expected LLVM IR pointer type for operand #0");
1695 auto cmpType = op.cmp().getType().cast<LLVM::LLVMType>();
1696 auto valType = op.val().getType().cast<LLVM::LLVMType>();
1697 if (cmpType != ptrType.getPointerElementTy() || cmpType != valType)
1698 return op.emitOpError("expected LLVM IR element type for operand #0 to "
1699 "match type for all other operands");
1700 if (!valType.isPointerTy() && !valType.isIntegerTy(8) &&
1701 !valType.isIntegerTy(16) && !valType.isIntegerTy(32) &&
1702 !valType.isIntegerTy(64) && !valType.isBFloatTy() &&
1703 !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy())
1704 return op.emitOpError("unexpected LLVM IR type");
1705 if (op.success_ordering() < AtomicOrdering::monotonic ||
1706 op.failure_ordering() < AtomicOrdering::monotonic)
1707 return op.emitOpError("ordering must be at least 'monotonic'");
1708 if (op.failure_ordering() == AtomicOrdering::release ||
1709 op.failure_ordering() == AtomicOrdering::acq_rel)
1710 return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
1711 return success();
1712 }
1713
1714 //===----------------------------------------------------------------------===//
1715 // Printer, parser and verifier for LLVM::FenceOp.
1716 //===----------------------------------------------------------------------===//
1717
1718 // <operation> ::= `llvm.fence` (`syncscope(`strAttr`)`)? keyword
1719 // attribute-dict?
parseFenceOp(OpAsmParser & parser,OperationState & result)1720 static ParseResult parseFenceOp(OpAsmParser &parser, OperationState &result) {
1721 StringAttr sScope;
1722 StringRef syncscopeKeyword = "syncscope";
1723 if (!failed(parser.parseOptionalKeyword(syncscopeKeyword))) {
1724 if (parser.parseLParen() ||
1725 parser.parseAttribute(sScope, syncscopeKeyword, result.attributes) ||
1726 parser.parseRParen())
1727 return failure();
1728 } else {
1729 result.addAttribute(syncscopeKeyword,
1730 parser.getBuilder().getStringAttr(""));
1731 }
1732 if (parseAtomicOrdering(parser, result, "ordering") ||
1733 parser.parseOptionalAttrDict(result.attributes))
1734 return failure();
1735 return success();
1736 }
1737
printFenceOp(OpAsmPrinter & p,FenceOp & op)1738 static void printFenceOp(OpAsmPrinter &p, FenceOp &op) {
1739 StringRef syncscopeKeyword = "syncscope";
1740 p << op.getOperationName() << ' ';
1741 if (!op.getAttr(syncscopeKeyword).cast<StringAttr>().getValue().empty())
1742 p << "syncscope(" << op.getAttr(syncscopeKeyword) << ") ";
1743 p << stringifyAtomicOrdering(op.ordering());
1744 }
1745
verify(FenceOp & op)1746 static LogicalResult verify(FenceOp &op) {
1747 if (op.ordering() == AtomicOrdering::not_atomic ||
1748 op.ordering() == AtomicOrdering::unordered ||
1749 op.ordering() == AtomicOrdering::monotonic)
1750 return op.emitOpError("can be given only acquire, release, acq_rel, "
1751 "and seq_cst orderings");
1752 return success();
1753 }
1754
1755 //===----------------------------------------------------------------------===//
1756 // LLVMDialect initialization, type parsing, and registration.
1757 //===----------------------------------------------------------------------===//
1758
/// Populates the dialect: adds the LLVM dialect types listed below, adds the
/// operations listed by the generated LLVMOps.cpp.inc file, and allows
/// operations that are not registered with the dialect.
void LLVMDialect::initialize() {
  // The type list is kept one entry per line; clang-format is disabled around
  // it so the layout is not reflowed.
  // clang-format off
  addTypes<LLVMVoidType,
           LLVMHalfType,
           LLVMBFloatType,
           LLVMFloatType,
           LLVMDoubleType,
           LLVMFP128Type,
           LLVMX86FP80Type,
           LLVMPPCFP128Type,
           LLVMX86MMXType,
           LLVMTokenType,
           LLVMLabelType,
           LLVMMetadataType,
           LLVMFunctionType,
           LLVMIntegerType,
           LLVMPointerType,
           LLVMFixedVectorType,
           LLVMScalableVectorType,
           LLVMArrayType,
           LLVMStructType>();
  // clang-format on
  // With GET_OP_LIST defined, the included file supplies the template
  // arguments of addOperations (presumably the list of generated op classes).
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
      >();

  // Support unknown operations because not all LLVM operations are registered.
  allowUnknownOperations();
}
1789
1790 #define GET_OP_CLASSES
1791 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
1792
1793 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const1794 Type LLVMDialect::parseType(DialectAsmParser &parser) const {
1795 return detail::parseType(parser);
1796 }
1797
1798 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const1799 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
1800 return detail::printType(type.cast<LLVMType>(), os);
1801 }
1802
verifyDataLayoutString(StringRef descr,llvm::function_ref<void (const Twine &)> reportError)1803 LogicalResult LLVMDialect::verifyDataLayoutString(
1804 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
1805 llvm::Expected<llvm::DataLayout> maybeDataLayout =
1806 llvm::DataLayout::parse(descr);
1807 if (maybeDataLayout)
1808 return success();
1809
1810 std::string message;
1811 llvm::raw_string_ostream messageStream(message);
1812 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
1813 reportError("invalid data layout descriptor: " + messageStream.str());
1814 return failure();
1815 }
1816
1817 /// Verify LLVM dialect attributes.
verifyOperationAttribute(Operation * op,NamedAttribute attr)1818 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
1819 NamedAttribute attr) {
1820 // If the data layout attribute is present, it must use the LLVM data layout
1821 // syntax. Try parsing it and report errors in case of failure. Users of this
1822 // attribute may assume it is well-formed and can pass it to the (asserting)
1823 // llvm::DataLayout constructor.
1824 if (attr.first.strref() != LLVM::LLVMDialect::getDataLayoutAttrName())
1825 return success();
1826 if (auto stringAttr = attr.second.dyn_cast<StringAttr>())
1827 return verifyDataLayoutString(
1828 stringAttr.getValue(),
1829 [op](const Twine &message) { op->emitOpError() << message.str(); });
1830
1831 return op->emitOpError() << "expected '"
1832 << LLVM::LLVMDialect::getDataLayoutAttrName()
1833 << "' to be a string attribute";
1834 }
1835
1836 /// Verify LLVMIR function argument attributes.
verifyRegionArgAttribute(Operation * op,unsigned regionIdx,unsigned argIdx,NamedAttribute argAttr)1837 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
1838 unsigned regionIdx,
1839 unsigned argIdx,
1840 NamedAttribute argAttr) {
1841 // Check that llvm.noalias is a boolean attribute.
1842 if (argAttr.first == LLVMDialect::getNoAliasAttrName() &&
1843 !argAttr.second.isa<BoolAttr>())
1844 return op->emitError()
1845 << "llvm.noalias argument attribute of non boolean type";
1846 // Check that llvm.align is an integer attribute.
1847 if (argAttr.first == LLVMDialect::getAlignAttrName() &&
1848 !argAttr.second.isa<IntegerAttr>())
1849 return op->emitError()
1850 << "llvm.align argument attribute of non integer type";
1851 return success();
1852 }
1853
1854 //===----------------------------------------------------------------------===//
1855 // Utility functions.
1856 //===----------------------------------------------------------------------===//
1857
createGlobalString(Location loc,OpBuilder & builder,StringRef name,StringRef value,LLVM::Linkage linkage)1858 Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder,
1859 StringRef name, StringRef value,
1860 LLVM::Linkage linkage) {
1861 assert(builder.getInsertionBlock() &&
1862 builder.getInsertionBlock()->getParentOp() &&
1863 "expected builder to point to a block constrained in an op");
1864 auto module =
1865 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
1866 assert(module && "builder points to an op outside of a module");
1867
1868 // Create the global at the entry of the module.
1869 OpBuilder moduleBuilder(module.getBodyRegion());
1870 MLIRContext *ctx = builder.getContext();
1871 auto type =
1872 LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(ctx), value.size());
1873 auto global = moduleBuilder.create<LLVM::GlobalOp>(
1874 loc, type, /*isConstant=*/true, linkage, name,
1875 builder.getStringAttr(value));
1876
1877 // Get the pointer to the first character in the global string.
1878 Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
1879 Value cst0 = builder.create<LLVM::ConstantOp>(
1880 loc, LLVM::LLVMType::getInt64Ty(ctx),
1881 builder.getIntegerAttr(builder.getIndexType(), 0));
1882 return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMType::getInt8PtrTy(ctx),
1883 globalPtr, ValueRange{cst0, cst0});
1884 }
1885
satisfiesLLVMModule(Operation * op)1886 bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
1887 return op->hasTrait<OpTrait::SymbolTable>() &&
1888 op->hasTrait<OpTrait::IsIsolatedFromAbove>();
1889 }
1890