1 //===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This classes used by the implementation details of Op types. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_IR_OPIMPLEMENTATION_H 14 #define MLIR_IR_OPIMPLEMENTATION_H 15 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/DialectInterface.h" 18 #include "mlir/IR/OpDefinition.h" 19 #include "llvm/ADT/Twine.h" 20 #include "llvm/Support/SMLoc.h" 21 #include "llvm/Support/raw_ostream.h" 22 23 namespace mlir { 24 25 class Builder; 26 27 //===----------------------------------------------------------------------===// 28 // OpAsmPrinter 29 //===----------------------------------------------------------------------===// 30 31 /// This is a pure-virtual base class that exposes the asmprinter hooks 32 /// necessary to implement a custom print() method. 33 class OpAsmPrinter { 34 public: OpAsmPrinter()35 OpAsmPrinter() {} 36 virtual ~OpAsmPrinter(); 37 virtual raw_ostream &getStream() const = 0; 38 39 /// Print implementations for various things an operation contains. 40 virtual void printOperand(Value value) = 0; 41 virtual void printOperand(Value value, raw_ostream &os) = 0; 42 43 /// Print a comma separated list of operands. 44 template <typename ContainerType> printOperands(const ContainerType & container)45 void printOperands(const ContainerType &container) { 46 printOperands(container.begin(), container.end()); 47 } 48 49 /// Print a comma separated list of operands. 50 template <typename IteratorType> printOperands(IteratorType it,IteratorType end)51 void printOperands(IteratorType it, IteratorType end) { 52 if (it == end) 53 return; 54 printOperand(*it); 55 for (++it; it != end; ++it) { 56 getStream() << ", "; 57 printOperand(*it); 58 } 59 } 60 virtual void printType(Type type) = 0; 61 virtual void printAttribute(Attribute attr) = 0; 62 63 /// Print the given attribute without its type. The corresponding parser must 64 /// provide a valid type for the attribute. 65 virtual void printAttributeWithoutType(Attribute attr) = 0; 66 67 /// Print the given successor. 68 virtual void printSuccessor(Block *successor) = 0; 69 70 /// Print the successor and its operands. 71 virtual void printSuccessorAndUseList(Block *successor, 72 ValueRange succOperands) = 0; 73 74 /// If the specified operation has attributes, print out an attribute 75 /// dictionary with their values. elidedAttrs allows the client to ignore 76 /// specific well known attributes, commonly used if the attribute value is 77 /// printed some other way (like as a fixed operand). 78 virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 79 ArrayRef<StringRef> elidedAttrs = {}) = 0; 80 81 /// If the specified operation has attributes, print out an attribute 82 /// dictionary prefixed with 'attributes'. 83 virtual void 84 printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs, 85 ArrayRef<StringRef> elidedAttrs = {}) = 0; 86 87 /// Print the entire operation with the default generic assembly form. 88 virtual void printGenericOp(Operation *op) = 0; 89 90 /// Prints a region. 91 virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true, 92 bool printBlockTerminators = true) = 0; 93 94 /// Renumber the arguments for the specified region to the same names as the 95 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove 96 /// operations. If any entry in namesToUse is null, the corresponding 97 /// argument name is left alone. 98 virtual void shadowRegionArgs(Region ®ion, ValueRange namesToUse) = 0; 99 100 /// Prints an affine map of SSA ids, where SSA id names are used in place 101 /// of dims/symbols. 102 /// Operand values must come from single-result sources, and be valid 103 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. 104 virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, 105 ValueRange operands) = 0; 106 107 /// Print an optional arrow followed by a type list. 108 template <typename TypeRange> printOptionalArrowTypeList(TypeRange && types)109 void printOptionalArrowTypeList(TypeRange &&types) { 110 if (types.begin() != types.end()) 111 printArrowTypeList(types); 112 } 113 template <typename TypeRange> printArrowTypeList(TypeRange && types)114 void printArrowTypeList(TypeRange &&types) { 115 auto &os = getStream() << " -> "; 116 117 bool wrapped = !llvm::hasSingleElement(types) || 118 (*types.begin()).template isa<FunctionType>(); 119 if (wrapped) 120 os << '('; 121 llvm::interleaveComma(types, *this); 122 if (wrapped) 123 os << ')'; 124 } 125 126 /// Print the complete type of an operation in functional form. printFunctionalType(Operation * op)127 void printFunctionalType(Operation *op) { 128 printFunctionalType(op->getOperandTypes(), op->getResultTypes()); 129 } 130 /// Print the two given type ranges in a functional form. 131 template <typename InputRangeT, typename ResultRangeT> printFunctionalType(InputRangeT && inputs,ResultRangeT && results)132 void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) { 133 auto &os = getStream(); 134 os << "("; 135 llvm::interleaveComma(inputs, *this); 136 os << ")"; 137 printArrowTypeList(results); 138 } 139 140 /// Print the given string as a symbol reference, i.e. a form representable by 141 /// a SymbolRefAttr. A symbol reference is represented as a string prefixed 142 /// with '@'. The reference is surrounded with ""'s and escaped if it has any 143 /// special or non-printable characters in it. 144 virtual void printSymbolName(StringRef symbolRef) = 0; 145 146 private: 147 OpAsmPrinter(const OpAsmPrinter &) = delete; 148 void operator=(const OpAsmPrinter &) = delete; 149 }; 150 151 // Make the implementations convenient to use. 152 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) { 153 p.printOperand(value); 154 return p; 155 } 156 157 template <typename T, 158 typename std::enable_if<std::is_convertible<T &, ValueRange>::value && 159 !std::is_convertible<T &, Value &>::value, 160 T>::type * = nullptr> 161 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) { 162 p.printOperands(values); 163 return p; 164 } 165 166 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) { 167 p.printType(type); 168 return p; 169 } 170 171 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) { 172 p.printAttribute(attr); 173 return p; 174 } 175 176 // Support printing anything that isn't convertible to one of the above types, 177 // even if it isn't exactly one of them. For example, we want to print 178 // FunctionType with the Type version above, not have it match this. 179 template <typename T, typename std::enable_if< 180 !std::is_convertible<T &, Value &>::value && 181 !std::is_convertible<T &, Type &>::value && 182 !std::is_convertible<T &, Attribute &>::value && 183 !std::is_convertible<T &, ValueRange>::value && 184 !llvm::is_one_of<T, bool>::value, 185 T>::type * = nullptr> 186 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) { 187 p.getStream() << other; 188 return p; 189 } 190 191 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, bool value) { 192 return p << (value ? StringRef("true") : "false"); 193 } 194 195 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) { 196 p.printSuccessor(value); 197 return p; 198 } 199 200 template <typename ValueRangeT> 201 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, 202 const ValueTypeRange<ValueRangeT> &types) { 203 llvm::interleaveComma(types, p); 204 return p; 205 } 206 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const TypeRange &types) { 207 llvm::interleaveComma(types, p); 208 return p; 209 } 210 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef<Type> types) { 211 llvm::interleaveComma(types, p); 212 return p; 213 } 214 215 //===----------------------------------------------------------------------===// 216 // OpAsmParser 217 //===----------------------------------------------------------------------===// 218 219 /// The OpAsmParser has methods for interacting with the asm parser: parsing 220 /// things from it, emitting errors etc. It has an intentionally high-level API 221 /// that is designed to reduce/constrain syntax innovation in individual 222 /// operations. 223 /// 224 /// For example, consider an op like this: 225 /// 226 /// %x = load %p[%1, %2] : memref<...> 227 /// 228 /// The "%x = load" tokens are already parsed and therefore invisible to the 229 /// custom op parser. This can be supported by calling `parseOperandList` to 230 /// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to 231 /// parse the indices, then calling `parseColonTypeList` to parse the result 232 /// type. 233 /// 234 class OpAsmParser { 235 public: 236 virtual ~OpAsmParser(); 237 238 /// Emit a diagnostic at the specified location and return failure. 239 virtual InFlightDiagnostic emitError(llvm::SMLoc loc, 240 const Twine &message = {}) = 0; 241 242 /// Return a builder which provides useful access to MLIRContext, global 243 /// objects like types and attributes. 244 virtual Builder &getBuilder() const = 0; 245 246 /// Get the location of the next token and store it into the argument. This 247 /// always succeeds. 248 virtual llvm::SMLoc getCurrentLocation() = 0; getCurrentLocation(llvm::SMLoc * loc)249 ParseResult getCurrentLocation(llvm::SMLoc *loc) { 250 *loc = getCurrentLocation(); 251 return success(); 252 } 253 254 /// Return the name of the specified result in the specified syntax, as well 255 /// as the sub-element in the name. It returns an empty string and ~0U for 256 /// invalid result numbers. For example, in this operation: 257 /// 258 /// %x, %y:2, %z = foo.op 259 /// 260 /// getResultName(0) == {"x", 0 } 261 /// getResultName(1) == {"y", 0 } 262 /// getResultName(2) == {"y", 1 } 263 /// getResultName(3) == {"z", 0 } 264 /// getResultName(4) == {"", ~0U } 265 virtual std::pair<StringRef, unsigned> 266 getResultName(unsigned resultNo) const = 0; 267 268 /// Return the number of declared SSA results. This returns 4 for the foo.op 269 /// example in the comment for `getResultName`. 270 virtual size_t getNumResults() const = 0; 271 272 /// Return the location of the original name token. 273 virtual llvm::SMLoc getNameLoc() const = 0; 274 275 // These methods emit an error and return failure or success. This allows 276 // these to be chained together into a linear sequence of || expressions in 277 // many cases. 278 279 /// Parse an operation in its generic form. 280 /// The parsed operation is parsed in the current context and inserted in the 281 /// provided block and insertion point. The results produced by this operation 282 /// aren't mapped to any named value in the parser. Returns nullptr on 283 /// failure. 284 virtual Operation *parseGenericOperation(Block *insertBlock, 285 Block::iterator insertPt) = 0; 286 287 //===--------------------------------------------------------------------===// 288 // Token Parsing 289 //===--------------------------------------------------------------------===// 290 291 /// Parse a '->' token. 292 virtual ParseResult parseArrow() = 0; 293 294 /// Parse a '->' token if present 295 virtual ParseResult parseOptionalArrow() = 0; 296 297 /// Parse a `{` token. 298 virtual ParseResult parseLBrace() = 0; 299 300 /// Parse a `{` token if present. 301 virtual ParseResult parseOptionalLBrace() = 0; 302 303 /// Parse a `}` token. 304 virtual ParseResult parseRBrace() = 0; 305 306 /// Parse a `}` token if present. 307 virtual ParseResult parseOptionalRBrace() = 0; 308 309 /// Parse a `:` token. 310 virtual ParseResult parseColon() = 0; 311 312 /// Parse a `:` token if present. 313 virtual ParseResult parseOptionalColon() = 0; 314 315 /// Parse a `,` token. 316 virtual ParseResult parseComma() = 0; 317 318 /// Parse a `,` token if present. 319 virtual ParseResult parseOptionalComma() = 0; 320 321 /// Parse a `=` token. 322 virtual ParseResult parseEqual() = 0; 323 324 /// Parse a `=` token if present. 325 virtual ParseResult parseOptionalEqual() = 0; 326 327 /// Parse a '<' token. 328 virtual ParseResult parseLess() = 0; 329 330 /// Parse a '<' token if present. 331 virtual ParseResult parseOptionalLess() = 0; 332 333 /// Parse a '>' token. 334 virtual ParseResult parseGreater() = 0; 335 336 /// Parse a '>' token if present. 337 virtual ParseResult parseOptionalGreater() = 0; 338 339 /// Parse a '?' token. 340 virtual ParseResult parseQuestion() = 0; 341 342 /// Parse a '?' token if present. 343 virtual ParseResult parseOptionalQuestion() = 0; 344 345 /// Parse a '+' token. 346 virtual ParseResult parsePlus() = 0; 347 348 /// Parse a '+' token if present. 349 virtual ParseResult parseOptionalPlus() = 0; 350 351 /// Parse a '*' token. 352 virtual ParseResult parseStar() = 0; 353 354 /// Parse a '*' token if present. 355 virtual ParseResult parseOptionalStar() = 0; 356 357 /// Parse a given keyword. 358 ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") { 359 auto loc = getCurrentLocation(); 360 if (parseOptionalKeyword(keyword)) 361 return emitError(loc, "expected '") << keyword << "'" << msg; 362 return success(); 363 } 364 365 /// Parse a keyword into 'keyword'. parseKeyword(StringRef * keyword)366 ParseResult parseKeyword(StringRef *keyword) { 367 auto loc = getCurrentLocation(); 368 if (parseOptionalKeyword(keyword)) 369 return emitError(loc, "expected valid keyword"); 370 return success(); 371 } 372 373 /// Parse the given keyword if present. 374 virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; 375 376 /// Parse a keyword, if present, into 'keyword'. 377 virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; 378 379 /// Parse a keyword, if present, and if one of the 'allowedValues', 380 /// into 'keyword' 381 virtual ParseResult 382 parseOptionalKeyword(StringRef *keyword, 383 ArrayRef<StringRef> allowedValues) = 0; 384 385 /// Parse a `(` token. 386 virtual ParseResult parseLParen() = 0; 387 388 /// Parse a `(` token if present. 389 virtual ParseResult parseOptionalLParen() = 0; 390 391 /// Parse a `)` token. 392 virtual ParseResult parseRParen() = 0; 393 394 /// Parse a `)` token if present. 395 virtual ParseResult parseOptionalRParen() = 0; 396 397 /// Parse a `[` token. 398 virtual ParseResult parseLSquare() = 0; 399 400 /// Parse a `[` token if present. 401 virtual ParseResult parseOptionalLSquare() = 0; 402 403 /// Parse a `]` token. 404 virtual ParseResult parseRSquare() = 0; 405 406 /// Parse a `]` token if present. 407 virtual ParseResult parseOptionalRSquare() = 0; 408 409 /// Parse a `...` token if present; 410 virtual ParseResult parseOptionalEllipsis() = 0; 411 412 //===--------------------------------------------------------------------===// 413 // Attribute Parsing 414 //===--------------------------------------------------------------------===// 415 416 /// Parse an arbitrary attribute of a given type and return it in result. 417 virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; 418 419 /// Parse an attribute of a specific kind and type. 420 template <typename AttrType> 421 ParseResult parseAttribute(AttrType &result, Type type = {}) { 422 llvm::SMLoc loc = getCurrentLocation(); 423 424 // Parse any kind of attribute. 425 Attribute attr; 426 if (parseAttribute(attr, type)) 427 return failure(); 428 429 // Check for the right kind of attribute. 430 if (!(result = attr.dyn_cast<AttrType>())) 431 return emitError(loc, "invalid kind of attribute specified"); 432 433 return success(); 434 } 435 436 /// Parse an arbitrary attribute and return it in result. This also adds the 437 /// attribute to the specified attribute list with the specified name. parseAttribute(Attribute & result,StringRef attrName,NamedAttrList & attrs)438 ParseResult parseAttribute(Attribute &result, StringRef attrName, 439 NamedAttrList &attrs) { 440 return parseAttribute(result, Type(), attrName, attrs); 441 } 442 443 /// Parse an attribute of a specific kind and type. 444 template <typename AttrType> parseAttribute(AttrType & result,StringRef attrName,NamedAttrList & attrs)445 ParseResult parseAttribute(AttrType &result, StringRef attrName, 446 NamedAttrList &attrs) { 447 return parseAttribute(result, Type(), attrName, attrs); 448 } 449 450 /// Parse an optional attribute. 451 virtual OptionalParseResult parseOptionalAttribute(Attribute &result, 452 Type type, 453 StringRef attrName, 454 NamedAttrList &attrs) = 0; 455 template <typename AttrT> parseOptionalAttribute(AttrT & result,StringRef attrName,NamedAttrList & attrs)456 OptionalParseResult parseOptionalAttribute(AttrT &result, StringRef attrName, 457 NamedAttrList &attrs) { 458 return parseOptionalAttribute(result, Type(), attrName, attrs); 459 } 460 461 /// Specialized variants of `parseOptionalAttribute` that remove potential 462 /// ambiguities in syntax. 463 virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result, 464 Type type, 465 StringRef attrName, 466 NamedAttrList &attrs) = 0; 467 virtual OptionalParseResult parseOptionalAttribute(StringAttr &result, 468 Type type, 469 StringRef attrName, 470 NamedAttrList &attrs) = 0; 471 472 /// Parse an arbitrary attribute of a given type and return it in result. This 473 /// also adds the attribute to the specified attribute list with the specified 474 /// name. 475 template <typename AttrType> parseAttribute(AttrType & result,Type type,StringRef attrName,NamedAttrList & attrs)476 ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, 477 NamedAttrList &attrs) { 478 llvm::SMLoc loc = getCurrentLocation(); 479 480 // Parse any kind of attribute. 481 Attribute attr; 482 if (parseAttribute(attr, type)) 483 return failure(); 484 485 // Check for the right kind of attribute. 486 result = attr.dyn_cast<AttrType>(); 487 if (!result) 488 return emitError(loc, "invalid kind of attribute specified"); 489 490 attrs.append(attrName, result); 491 return success(); 492 } 493 494 /// Parse a named dictionary into 'result' if it is present. 495 virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0; 496 497 /// Parse a named dictionary into 'result' if the `attributes` keyword is 498 /// present. 499 virtual ParseResult 500 parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0; 501 502 /// Parse an affine map instance into 'map'. 503 virtual ParseResult parseAffineMap(AffineMap &map) = 0; 504 505 /// Parse an integer set instance into 'set'. 506 virtual ParseResult printIntegerSet(IntegerSet &set) = 0; 507 508 //===--------------------------------------------------------------------===// 509 // Identifier Parsing 510 //===--------------------------------------------------------------------===// 511 512 /// Parse an @-identifier and store it (without the '@' symbol) in a string 513 /// attribute named 'attrName'. parseSymbolName(StringAttr & result,StringRef attrName,NamedAttrList & attrs)514 ParseResult parseSymbolName(StringAttr &result, StringRef attrName, 515 NamedAttrList &attrs) { 516 if (failed(parseOptionalSymbolName(result, attrName, attrs))) 517 return emitError(getCurrentLocation()) 518 << "expected valid '@'-identifier for symbol name"; 519 return success(); 520 } 521 522 /// Parse an optional @-identifier and store it (without the '@' symbol) in a 523 /// string attribute named 'attrName'. 524 virtual ParseResult parseOptionalSymbolName(StringAttr &result, 525 StringRef attrName, 526 NamedAttrList &attrs) = 0; 527 528 //===--------------------------------------------------------------------===// 529 // Operand Parsing 530 //===--------------------------------------------------------------------===// 531 532 /// This is the representation of an operand reference. 533 struct OperandType { 534 llvm::SMLoc location; // Location of the token. 535 StringRef name; // Value name, e.g. %42 or %abc 536 unsigned number; // Number, e.g. 12 for an operand like %xyz#12 537 }; 538 539 /// Parse a single operand. 540 virtual ParseResult parseOperand(OperandType &result) = 0; 541 542 /// Parse a single operand if present. 543 virtual OptionalParseResult parseOptionalOperand(OperandType &result) = 0; 544 545 /// These are the supported delimiters around operand lists and region 546 /// argument lists, used by parseOperandList and parseRegionArgumentList. 547 enum class Delimiter { 548 /// Zero or more operands with no delimiters. 549 None, 550 /// Parens surrounding zero or more operands. 551 Paren, 552 /// Square brackets surrounding zero or more operands. 553 Square, 554 /// Parens supporting zero or more operands, or nothing. 555 OptionalParen, 556 /// Square brackets supporting zero or more ops, or nothing. 557 OptionalSquare, 558 }; 559 560 /// Parse zero or more SSA comma-separated operand references with a specified 561 /// surrounding delimiter, and an optional required operand count. 562 virtual ParseResult 563 parseOperandList(SmallVectorImpl<OperandType> &result, 564 int requiredOperandCount = -1, 565 Delimiter delimiter = Delimiter::None) = 0; parseOperandList(SmallVectorImpl<OperandType> & result,Delimiter delimiter)566 ParseResult parseOperandList(SmallVectorImpl<OperandType> &result, 567 Delimiter delimiter) { 568 return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter); 569 } 570 571 /// Parse zero or more trailing SSA comma-separated trailing operand 572 /// references with a specified surrounding delimiter, and an optional 573 /// required operand count. A leading comma is expected before the operands. 574 virtual ParseResult 575 parseTrailingOperandList(SmallVectorImpl<OperandType> &result, 576 int requiredOperandCount = -1, 577 Delimiter delimiter = Delimiter::None) = 0; parseTrailingOperandList(SmallVectorImpl<OperandType> & result,Delimiter delimiter)578 ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result, 579 Delimiter delimiter) { 580 return parseTrailingOperandList(result, /*requiredOperandCount=*/-1, 581 delimiter); 582 } 583 584 /// Resolve an operand to an SSA value, emitting an error on failure. 585 virtual ParseResult resolveOperand(const OperandType &operand, Type type, 586 SmallVectorImpl<Value> &result) = 0; 587 588 /// Resolve a list of operands to SSA values, emitting an error on failure, or 589 /// appending the results to the list on success. This method should be used 590 /// when all operands have the same type. resolveOperands(ArrayRef<OperandType> operands,Type type,SmallVectorImpl<Value> & result)591 ParseResult resolveOperands(ArrayRef<OperandType> operands, Type type, 592 SmallVectorImpl<Value> &result) { 593 for (auto elt : operands) 594 if (resolveOperand(elt, type, result)) 595 return failure(); 596 return success(); 597 } 598 599 /// Resolve a list of operands and a list of operand types to SSA values, 600 /// emitting an error and returning failure, or appending the results 601 /// to the list on success. resolveOperands(ArrayRef<OperandType> operands,ArrayRef<Type> types,llvm::SMLoc loc,SmallVectorImpl<Value> & result)602 ParseResult resolveOperands(ArrayRef<OperandType> operands, 603 ArrayRef<Type> types, llvm::SMLoc loc, 604 SmallVectorImpl<Value> &result) { 605 if (operands.size() != types.size()) 606 return emitError(loc) 607 << operands.size() << " operands present, but expected " 608 << types.size(); 609 610 for (unsigned i = 0, e = operands.size(); i != e; ++i) 611 if (resolveOperand(operands[i], types[i], result)) 612 return failure(); 613 return success(); 614 } 615 template <typename Operands> resolveOperands(Operands && operands,Type type,llvm::SMLoc loc,SmallVectorImpl<Value> & result)616 ParseResult resolveOperands(Operands &&operands, Type type, llvm::SMLoc loc, 617 SmallVectorImpl<Value> &result) { 618 return resolveOperands(std::forward<Operands>(operands), 619 ArrayRef<Type>(type), loc, result); 620 } 621 template <typename Operands, typename Types> 622 std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult> resolveOperands(Operands && operands,Types && types,llvm::SMLoc loc,SmallVectorImpl<Value> & result)623 resolveOperands(Operands &&operands, Types &&types, llvm::SMLoc loc, 624 SmallVectorImpl<Value> &result) { 625 size_t operandSize = std::distance(operands.begin(), operands.end()); 626 size_t typeSize = std::distance(types.begin(), types.end()); 627 if (operandSize != typeSize) 628 return emitError(loc) 629 << operandSize << " operands present, but expected " << typeSize; 630 631 for (auto it : llvm::zip(operands, types)) 632 if (resolveOperand(std::get<0>(it), std::get<1>(it), result)) 633 return failure(); 634 return success(); 635 } 636 637 /// Parses an affine map attribute where dims and symbols are SSA operands. 638 /// Operand values must come from single-result sources, and be valid 639 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. 640 virtual ParseResult 641 parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map, 642 StringRef attrName, NamedAttrList &attrs, 643 Delimiter delimiter = Delimiter::Square) = 0; 644 645 //===--------------------------------------------------------------------===// 646 // Region Parsing 647 //===--------------------------------------------------------------------===// 648 649 /// Parses a region. Any parsed blocks are appended to 'region' and must be 650 /// moved to the op regions after the op is created. The first block of the 651 /// region takes 'arguments' of types 'argTypes'. If 'enableNameShadowing' is 652 /// set to true, the argument names are allowed to shadow the names of other 653 /// existing SSA values defined above the region scope. 'enableNameShadowing' 654 /// can only be set to true for regions attached to operations that are 655 /// 'IsolatedFromAbove. 656 virtual ParseResult parseRegion(Region ®ion, 657 ArrayRef<OperandType> arguments = {}, 658 ArrayRef<Type> argTypes = {}, 659 bool enableNameShadowing = false) = 0; 660 661 /// Parses a region if present. 662 virtual OptionalParseResult 663 parseOptionalRegion(Region ®ion, ArrayRef<OperandType> arguments = {}, 664 ArrayRef<Type> argTypes = {}, 665 bool enableNameShadowing = false) = 0; 666 667 /// Parses a region if present. If the region is present, a new region is 668 /// allocated and placed in `region`. If no region is present or on failure, 669 /// `region` remains untouched. 670 virtual OptionalParseResult parseOptionalRegion( 671 std::unique_ptr<Region> ®ion, ArrayRef<OperandType> arguments = {}, 672 ArrayRef<Type> argTypes = {}, bool enableNameShadowing = false) = 0; 673 674 /// Parse a region argument, this argument is resolved when calling 675 /// 'parseRegion'. 676 virtual ParseResult parseRegionArgument(OperandType &argument) = 0; 677 678 /// Parse zero or more region arguments with a specified surrounding 679 /// delimiter, and an optional required argument count. Region arguments 680 /// define new values; so this also checks if values with the same names have 681 /// not been defined yet. 682 virtual ParseResult 683 parseRegionArgumentList(SmallVectorImpl<OperandType> &result, 684 int requiredOperandCount = -1, 685 Delimiter delimiter = Delimiter::None) = 0; 686 virtual ParseResult parseRegionArgumentList(SmallVectorImpl<OperandType> & result,Delimiter delimiter)687 parseRegionArgumentList(SmallVectorImpl<OperandType> &result, 688 Delimiter delimiter) { 689 return parseRegionArgumentList(result, /*requiredOperandCount=*/-1, 690 delimiter); 691 } 692 693 /// Parse a region argument if present. 694 virtual ParseResult parseOptionalRegionArgument(OperandType &argument) = 0; 695 696 //===--------------------------------------------------------------------===// 697 // Successor Parsing 698 //===--------------------------------------------------------------------===// 699 700 /// Parse a single operation successor. 701 virtual ParseResult parseSuccessor(Block *&dest) = 0; 702 703 /// Parse an optional operation successor. 704 virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0; 705 706 /// Parse a single operation successor and its operand list. 707 virtual ParseResult 708 parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0; 709 710 //===--------------------------------------------------------------------===// 711 // Type Parsing 712 //===--------------------------------------------------------------------===// 713 714 /// Parse a type. 715 virtual ParseResult parseType(Type &result) = 0; 716 717 /// Parse an optional type. 718 virtual OptionalParseResult parseOptionalType(Type &result) = 0; 719 720 /// Parse a type of a specific type. 721 template <typename TypeT> parseType(TypeT & result)722 ParseResult parseType(TypeT &result) { 723 llvm::SMLoc loc = getCurrentLocation(); 724 725 // Parse any kind of type. 726 Type type; 727 if (parseType(type)) 728 return failure(); 729 730 // Check for the right kind of attribute. 731 result = type.dyn_cast<TypeT>(); 732 if (!result) 733 return emitError(loc, "invalid kind of type specified"); 734 735 return success(); 736 } 737 738 /// Parse a type list. parseTypeList(SmallVectorImpl<Type> & result)739 ParseResult parseTypeList(SmallVectorImpl<Type> &result) { 740 do { 741 Type type; 742 if (parseType(type)) 743 return failure(); 744 result.push_back(type); 745 } while (succeeded(parseOptionalComma())); 746 return success(); 747 } 748 749 /// Parse an arrow followed by a type list. 750 virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0; 751 752 /// Parse an optional arrow followed by a type list. 753 virtual ParseResult 754 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0; 755 756 /// Parse a colon followed by a type. 757 virtual ParseResult parseColonType(Type &result) = 0; 758 759 /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType. 760 template <typename TypeType> parseColonType(TypeType & result)761 ParseResult parseColonType(TypeType &result) { 762 llvm::SMLoc loc = getCurrentLocation(); 763 764 // Parse any kind of type. 765 Type type; 766 if (parseColonType(type)) 767 return failure(); 768 769 // Check for the right kind of attribute. 770 result = type.dyn_cast<TypeType>(); 771 if (!result) 772 return emitError(loc, "invalid kind of type specified"); 773 774 return success(); 775 } 776 777 /// Parse a colon followed by a type list, which must have at least one type. 778 virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0; 779 780 /// Parse an optional colon followed by a type list, which if present must 781 /// have at least one type. 782 virtual ParseResult 783 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0; 784 785 /// Parse a list of assignments of the form 786 /// (%x1 = %y1, %x2 = %y2, ...) parseAssignmentList(SmallVectorImpl<OperandType> & lhs,SmallVectorImpl<OperandType> & rhs)787 ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs, 788 SmallVectorImpl<OperandType> &rhs) { 789 OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs); 790 if (!result.hasValue()) 791 return emitError(getCurrentLocation(), "expected '('"); 792 return result.getValue(); 793 } 794 795 virtual OptionalParseResult 796 parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs, 797 SmallVectorImpl<OperandType> &rhs) = 0; 798 799 /// Parse a keyword followed by a type. parseKeywordType(const char * keyword,Type & result)800 ParseResult parseKeywordType(const char *keyword, Type &result) { 801 return failure(parseKeyword(keyword) || parseType(result)); 802 } 803 804 /// Add the specified type to the end of the specified type list and return 805 /// success. This is a helper designed to allow parse methods to be simple 806 /// and chain through || operators. addTypeToList(Type type,SmallVectorImpl<Type> & result)807 ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) { 808 result.push_back(type); 809 return success(); 810 } 811 812 /// Add the specified types to the end of the specified type list and return 813 /// success. This is a helper designed to allow parse methods to be simple 814 /// and chain through || operators. addTypesToList(ArrayRef<Type> types,SmallVectorImpl<Type> & result)815 ParseResult addTypesToList(ArrayRef<Type> types, 816 SmallVectorImpl<Type> &result) { 817 result.append(types.begin(), types.end()); 818 return success(); 819 } 820 821 private: 822 /// Parse either an operand list or a region argument list depending on 823 /// whether isOperandList is true. 824 ParseResult parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result, 825 bool isOperandList, 826 int requiredOperandCount, 827 Delimiter delimiter); 828 }; 829 830 //===--------------------------------------------------------------------===// 831 // Dialect OpAsm interface. 832 //===--------------------------------------------------------------------===// 833 834 /// A functor used to set the name of the start of a result group of an 835 /// operation. See 'getAsmResultNames' below for more details. 836 using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>; 837 838 class OpAsmDialectInterface 839 : public DialectInterface::Base<OpAsmDialectInterface> { 840 public: OpAsmDialectInterface(Dialect * dialect)841 OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} 842 843 /// Hooks for getting an alias identifier alias for a given symbol, that is 844 /// not necessarily a part of this dialect. The identifier is used in place of 845 /// the symbol when printing textual IR. These aliases must not contain `.` or 846 /// end with a numeric digit([0-9]+). Returns success if an alias was 847 /// provided, failure otherwise. getAlias(Attribute attr,raw_ostream & os)848 virtual LogicalResult getAlias(Attribute attr, raw_ostream &os) const { 849 return failure(); 850 } getAlias(Type type,raw_ostream & os)851 virtual LogicalResult getAlias(Type type, raw_ostream &os) const { 852 return failure(); 853 } 854 855 /// Get a special name to use when printing the given operation. See 856 /// OpAsmInterface.td#getAsmResultNames for usage details and documentation. getAsmResultNames(Operation * op,OpAsmSetValueNameFn setNameFn)857 virtual void getAsmResultNames(Operation *op, 858 OpAsmSetValueNameFn setNameFn) const {} 859 860 /// Get a special name to use when printing the entry block arguments of the 861 /// region contained by an operation in this dialect. getAsmBlockArgumentNames(Block * block,OpAsmSetValueNameFn setNameFn)862 virtual void getAsmBlockArgumentNames(Block *block, 863 OpAsmSetValueNameFn setNameFn) const {} 864 }; 865 } // end namespace mlir 866 867 //===--------------------------------------------------------------------===// 868 // Operation OpAsm interface. 869 //===--------------------------------------------------------------------===// 870 871 /// The OpAsmOpInterface, see OpAsmInterface.td for more details. 872 #include "mlir/IR/OpAsmInterface.h.inc" 873 874 #endif 875