• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the classes used by the implementation details of Op types.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_IR_OPIMPLEMENTATION_H
14 #define MLIR_IR_OPIMPLEMENTATION_H
15 
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/DialectInterface.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "llvm/ADT/Twine.h"
20 #include "llvm/Support/SMLoc.h"
21 #include "llvm/Support/raw_ostream.h"
22 
23 namespace mlir {
24 
25 class Builder;
26 
27 //===----------------------------------------------------------------------===//
28 // OpAsmPrinter
29 //===----------------------------------------------------------------------===//
30 
31 /// This is a pure-virtual base class that exposes the asmprinter hooks
32 /// necessary to implement a custom print() method.
33 class OpAsmPrinter {
34 public:
OpAsmPrinter()35   OpAsmPrinter() {}
36   virtual ~OpAsmPrinter();
37   virtual raw_ostream &getStream() const = 0;
38 
39   /// Print implementations for various things an operation contains.
40   virtual void printOperand(Value value) = 0;
41   virtual void printOperand(Value value, raw_ostream &os) = 0;
42 
43   /// Print a comma separated list of operands.
44   template <typename ContainerType>
printOperands(const ContainerType & container)45   void printOperands(const ContainerType &container) {
46     printOperands(container.begin(), container.end());
47   }
48 
49   /// Print a comma separated list of operands.
50   template <typename IteratorType>
printOperands(IteratorType it,IteratorType end)51   void printOperands(IteratorType it, IteratorType end) {
52     if (it == end)
53       return;
54     printOperand(*it);
55     for (++it; it != end; ++it) {
56       getStream() << ", ";
57       printOperand(*it);
58     }
59   }
60   virtual void printType(Type type) = 0;
61   virtual void printAttribute(Attribute attr) = 0;
62 
63   /// Print the given attribute without its type. The corresponding parser must
64   /// provide a valid type for the attribute.
65   virtual void printAttributeWithoutType(Attribute attr) = 0;
66 
67   /// Print the given successor.
68   virtual void printSuccessor(Block *successor) = 0;
69 
70   /// Print the successor and its operands.
71   virtual void printSuccessorAndUseList(Block *successor,
72                                         ValueRange succOperands) = 0;
73 
74   /// If the specified operation has attributes, print out an attribute
75   /// dictionary with their values.  elidedAttrs allows the client to ignore
76   /// specific well known attributes, commonly used if the attribute value is
77   /// printed some other way (like as a fixed operand).
78   virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
79                                      ArrayRef<StringRef> elidedAttrs = {}) = 0;
80 
81   /// If the specified operation has attributes, print out an attribute
82   /// dictionary prefixed with 'attributes'.
83   virtual void
84   printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
85                                    ArrayRef<StringRef> elidedAttrs = {}) = 0;
86 
87   /// Print the entire operation with the default generic assembly form.
88   virtual void printGenericOp(Operation *op) = 0;
89 
90   /// Prints a region.
91   virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
92                            bool printBlockTerminators = true) = 0;
93 
94   /// Renumber the arguments for the specified region to the same names as the
95   /// SSA values in namesToUse.  This may only be used for IsolatedFromAbove
96   /// operations.  If any entry in namesToUse is null, the corresponding
97   /// argument name is left alone.
98   virtual void shadowRegionArgs(Region &region, ValueRange namesToUse) = 0;
99 
100   /// Prints an affine map of SSA ids, where SSA id names are used in place
101   /// of dims/symbols.
102   /// Operand values must come from single-result sources, and be valid
103   /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
104   virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
105                                       ValueRange operands) = 0;
106 
107   /// Print an optional arrow followed by a type list.
108   template <typename TypeRange>
printOptionalArrowTypeList(TypeRange && types)109   void printOptionalArrowTypeList(TypeRange &&types) {
110     if (types.begin() != types.end())
111       printArrowTypeList(types);
112   }
113   template <typename TypeRange>
printArrowTypeList(TypeRange && types)114   void printArrowTypeList(TypeRange &&types) {
115     auto &os = getStream() << " -> ";
116 
117     bool wrapped = !llvm::hasSingleElement(types) ||
118                    (*types.begin()).template isa<FunctionType>();
119     if (wrapped)
120       os << '(';
121     llvm::interleaveComma(types, *this);
122     if (wrapped)
123       os << ')';
124   }
125 
126   /// Print the complete type of an operation in functional form.
printFunctionalType(Operation * op)127   void printFunctionalType(Operation *op) {
128     printFunctionalType(op->getOperandTypes(), op->getResultTypes());
129   }
130   /// Print the two given type ranges in a functional form.
131   template <typename InputRangeT, typename ResultRangeT>
printFunctionalType(InputRangeT && inputs,ResultRangeT && results)132   void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
133     auto &os = getStream();
134     os << "(";
135     llvm::interleaveComma(inputs, *this);
136     os << ")";
137     printArrowTypeList(results);
138   }
139 
140   /// Print the given string as a symbol reference, i.e. a form representable by
141   /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
142   /// with '@'. The reference is surrounded with ""'s and escaped if it has any
143   /// special or non-printable characters in it.
144   virtual void printSymbolName(StringRef symbolRef) = 0;
145 
146 private:
147   OpAsmPrinter(const OpAsmPrinter &) = delete;
148   void operator=(const OpAsmPrinter &) = delete;
149 };
150 
151 // Make the implementations convenient to use.
152 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
153   p.printOperand(value);
154   return p;
155 }
156 
157 template <typename T,
158           typename std::enable_if<std::is_convertible<T &, ValueRange>::value &&
159                                       !std::is_convertible<T &, Value &>::value,
160                                   T>::type * = nullptr>
161 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
162   p.printOperands(values);
163   return p;
164 }
165 
166 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) {
167   p.printType(type);
168   return p;
169 }
170 
171 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) {
172   p.printAttribute(attr);
173   return p;
174 }
175 
176 // Support printing anything that isn't convertible to one of the above types,
177 // even if it isn't exactly one of them.  For example, we want to print
178 // FunctionType with the Type version above, not have it match this.
179 template <typename T, typename std::enable_if<
180                           !std::is_convertible<T &, Value &>::value &&
181                               !std::is_convertible<T &, Type &>::value &&
182                               !std::is_convertible<T &, Attribute &>::value &&
183                               !std::is_convertible<T &, ValueRange>::value &&
184                               !llvm::is_one_of<T, bool>::value,
185                           T>::type * = nullptr>
186 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) {
187   p.getStream() << other;
188   return p;
189 }
190 
191 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, bool value) {
192   return p << (value ? StringRef("true") : "false");
193 }
194 
195 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
196   p.printSuccessor(value);
197   return p;
198 }
199 
200 template <typename ValueRangeT>
201 inline OpAsmPrinter &operator<<(OpAsmPrinter &p,
202                                 const ValueTypeRange<ValueRangeT> &types) {
203   llvm::interleaveComma(types, p);
204   return p;
205 }
206 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const TypeRange &types) {
207   llvm::interleaveComma(types, p);
208   return p;
209 }
210 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef<Type> types) {
211   llvm::interleaveComma(types, p);
212   return p;
213 }
214 
215 //===----------------------------------------------------------------------===//
216 // OpAsmParser
217 //===----------------------------------------------------------------------===//
218 
219 /// The OpAsmParser has methods for interacting with the asm parser: parsing
220 /// things from it, emitting errors etc.  It has an intentionally high-level API
221 /// that is designed to reduce/constrain syntax innovation in individual
222 /// operations.
223 ///
224 /// For example, consider an op like this:
225 ///
226 ///    %x = load %p[%1, %2] : memref<...>
227 ///
228 /// The "%x = load" tokens are already parsed and therefore invisible to the
229 /// custom op parser.  This can be supported by calling `parseOperandList` to
230 /// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
231 /// parse the indices, then calling `parseColonTypeList` to parse the result
232 /// type.
233 ///
234 class OpAsmParser {
235 public:
236   virtual ~OpAsmParser();
237 
238   /// Emit a diagnostic at the specified location and return failure.
239   virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
240                                        const Twine &message = {}) = 0;
241 
242   /// Return a builder which provides useful access to MLIRContext, global
243   /// objects like types and attributes.
244   virtual Builder &getBuilder() const = 0;
245 
246   /// Get the location of the next token and store it into the argument.  This
247   /// always succeeds.
248   virtual llvm::SMLoc getCurrentLocation() = 0;
getCurrentLocation(llvm::SMLoc * loc)249   ParseResult getCurrentLocation(llvm::SMLoc *loc) {
250     *loc = getCurrentLocation();
251     return success();
252   }
253 
254   /// Return the name of the specified result in the specified syntax, as well
255   /// as the sub-element in the name.  It returns an empty string and ~0U for
256   /// invalid result numbers.  For example, in this operation:
257   ///
258   ///  %x, %y:2, %z = foo.op
259   ///
260   ///    getResultName(0) == {"x", 0 }
261   ///    getResultName(1) == {"y", 0 }
262   ///    getResultName(2) == {"y", 1 }
263   ///    getResultName(3) == {"z", 0 }
264   ///    getResultName(4) == {"", ~0U }
265   virtual std::pair<StringRef, unsigned>
266   getResultName(unsigned resultNo) const = 0;
267 
268   /// Return the number of declared SSA results.  This returns 4 for the foo.op
269   /// example in the comment for `getResultName`.
270   virtual size_t getNumResults() const = 0;
271 
272   /// Return the location of the original name token.
273   virtual llvm::SMLoc getNameLoc() const = 0;
274 
275   // These methods emit an error and return failure or success. This allows
276   // these to be chained together into a linear sequence of || expressions in
277   // many cases.
278 
279   /// Parse an operation in its generic form.
280   /// The parsed operation is parsed in the current context and inserted in the
281   /// provided block and insertion point. The results produced by this operation
282   /// aren't mapped to any named value in the parser. Returns nullptr on
283   /// failure.
284   virtual Operation *parseGenericOperation(Block *insertBlock,
285                                            Block::iterator insertPt) = 0;
286 
287   //===--------------------------------------------------------------------===//
288   // Token Parsing
289   //===--------------------------------------------------------------------===//
290 
291   /// Parse a '->' token.
292   virtual ParseResult parseArrow() = 0;
293 
294   /// Parse a '->' token if present
295   virtual ParseResult parseOptionalArrow() = 0;
296 
297   /// Parse a `{` token.
298   virtual ParseResult parseLBrace() = 0;
299 
300   /// Parse a `{` token if present.
301   virtual ParseResult parseOptionalLBrace() = 0;
302 
303   /// Parse a `}` token.
304   virtual ParseResult parseRBrace() = 0;
305 
306   /// Parse a `}` token if present.
307   virtual ParseResult parseOptionalRBrace() = 0;
308 
309   /// Parse a `:` token.
310   virtual ParseResult parseColon() = 0;
311 
312   /// Parse a `:` token if present.
313   virtual ParseResult parseOptionalColon() = 0;
314 
315   /// Parse a `,` token.
316   virtual ParseResult parseComma() = 0;
317 
318   /// Parse a `,` token if present.
319   virtual ParseResult parseOptionalComma() = 0;
320 
321   /// Parse a `=` token.
322   virtual ParseResult parseEqual() = 0;
323 
324   /// Parse a `=` token if present.
325   virtual ParseResult parseOptionalEqual() = 0;
326 
327   /// Parse a '<' token.
328   virtual ParseResult parseLess() = 0;
329 
330   /// Parse a '<' token if present.
331   virtual ParseResult parseOptionalLess() = 0;
332 
333   /// Parse a '>' token.
334   virtual ParseResult parseGreater() = 0;
335 
336   /// Parse a '>' token if present.
337   virtual ParseResult parseOptionalGreater() = 0;
338 
339   /// Parse a '?' token.
340   virtual ParseResult parseQuestion() = 0;
341 
342   /// Parse a '?' token if present.
343   virtual ParseResult parseOptionalQuestion() = 0;
344 
345   /// Parse a '+' token.
346   virtual ParseResult parsePlus() = 0;
347 
348   /// Parse a '+' token if present.
349   virtual ParseResult parseOptionalPlus() = 0;
350 
351   /// Parse a '*' token.
352   virtual ParseResult parseStar() = 0;
353 
354   /// Parse a '*' token if present.
355   virtual ParseResult parseOptionalStar() = 0;
356 
357   /// Parse a given keyword.
358   ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") {
359     auto loc = getCurrentLocation();
360     if (parseOptionalKeyword(keyword))
361       return emitError(loc, "expected '") << keyword << "'" << msg;
362     return success();
363   }
364 
365   /// Parse a keyword into 'keyword'.
parseKeyword(StringRef * keyword)366   ParseResult parseKeyword(StringRef *keyword) {
367     auto loc = getCurrentLocation();
368     if (parseOptionalKeyword(keyword))
369       return emitError(loc, "expected valid keyword");
370     return success();
371   }
372 
373   /// Parse the given keyword if present.
374   virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
375 
376   /// Parse a keyword, if present, into 'keyword'.
377   virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
378 
379   /// Parse a keyword, if present, and if one of the 'allowedValues',
380   /// into 'keyword'
381   virtual ParseResult
382   parseOptionalKeyword(StringRef *keyword,
383                        ArrayRef<StringRef> allowedValues) = 0;
384 
385   /// Parse a `(` token.
386   virtual ParseResult parseLParen() = 0;
387 
388   /// Parse a `(` token if present.
389   virtual ParseResult parseOptionalLParen() = 0;
390 
391   /// Parse a `)` token.
392   virtual ParseResult parseRParen() = 0;
393 
394   /// Parse a `)` token if present.
395   virtual ParseResult parseOptionalRParen() = 0;
396 
397   /// Parse a `[` token.
398   virtual ParseResult parseLSquare() = 0;
399 
400   /// Parse a `[` token if present.
401   virtual ParseResult parseOptionalLSquare() = 0;
402 
403   /// Parse a `]` token.
404   virtual ParseResult parseRSquare() = 0;
405 
406   /// Parse a `]` token if present.
407   virtual ParseResult parseOptionalRSquare() = 0;
408 
409   /// Parse a `...` token if present;
410   virtual ParseResult parseOptionalEllipsis() = 0;
411 
412   //===--------------------------------------------------------------------===//
413   // Attribute Parsing
414   //===--------------------------------------------------------------------===//
415 
416   /// Parse an arbitrary attribute of a given type and return it in result.
417   virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
418 
419   /// Parse an attribute of a specific kind and type.
420   template <typename AttrType>
421   ParseResult parseAttribute(AttrType &result, Type type = {}) {
422     llvm::SMLoc loc = getCurrentLocation();
423 
424     // Parse any kind of attribute.
425     Attribute attr;
426     if (parseAttribute(attr, type))
427       return failure();
428 
429     // Check for the right kind of attribute.
430     if (!(result = attr.dyn_cast<AttrType>()))
431       return emitError(loc, "invalid kind of attribute specified");
432 
433     return success();
434   }
435 
436   /// Parse an arbitrary attribute and return it in result.  This also adds the
437   /// attribute to the specified attribute list with the specified name.
parseAttribute(Attribute & result,StringRef attrName,NamedAttrList & attrs)438   ParseResult parseAttribute(Attribute &result, StringRef attrName,
439                              NamedAttrList &attrs) {
440     return parseAttribute(result, Type(), attrName, attrs);
441   }
442 
443   /// Parse an attribute of a specific kind and type.
444   template <typename AttrType>
parseAttribute(AttrType & result,StringRef attrName,NamedAttrList & attrs)445   ParseResult parseAttribute(AttrType &result, StringRef attrName,
446                              NamedAttrList &attrs) {
447     return parseAttribute(result, Type(), attrName, attrs);
448   }
449 
450   /// Parse an optional attribute.
451   virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
452                                                      Type type,
453                                                      StringRef attrName,
454                                                      NamedAttrList &attrs) = 0;
455   template <typename AttrT>
parseOptionalAttribute(AttrT & result,StringRef attrName,NamedAttrList & attrs)456   OptionalParseResult parseOptionalAttribute(AttrT &result, StringRef attrName,
457                                              NamedAttrList &attrs) {
458     return parseOptionalAttribute(result, Type(), attrName, attrs);
459   }
460 
461   /// Specialized variants of `parseOptionalAttribute` that remove potential
462   /// ambiguities in syntax.
463   virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
464                                                      Type type,
465                                                      StringRef attrName,
466                                                      NamedAttrList &attrs) = 0;
467   virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
468                                                      Type type,
469                                                      StringRef attrName,
470                                                      NamedAttrList &attrs) = 0;
471 
472   /// Parse an arbitrary attribute of a given type and return it in result. This
473   /// also adds the attribute to the specified attribute list with the specified
474   /// name.
475   template <typename AttrType>
parseAttribute(AttrType & result,Type type,StringRef attrName,NamedAttrList & attrs)476   ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
477                              NamedAttrList &attrs) {
478     llvm::SMLoc loc = getCurrentLocation();
479 
480     // Parse any kind of attribute.
481     Attribute attr;
482     if (parseAttribute(attr, type))
483       return failure();
484 
485     // Check for the right kind of attribute.
486     result = attr.dyn_cast<AttrType>();
487     if (!result)
488       return emitError(loc, "invalid kind of attribute specified");
489 
490     attrs.append(attrName, result);
491     return success();
492   }
493 
494   /// Parse a named dictionary into 'result' if it is present.
495   virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
496 
497   /// Parse a named dictionary into 'result' if the `attributes` keyword is
498   /// present.
499   virtual ParseResult
500   parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0;
501 
502   /// Parse an affine map instance into 'map'.
503   virtual ParseResult parseAffineMap(AffineMap &map) = 0;
504 
505   /// Parse an integer set instance into 'set'.
506   virtual ParseResult printIntegerSet(IntegerSet &set) = 0;
507 
508   //===--------------------------------------------------------------------===//
509   // Identifier Parsing
510   //===--------------------------------------------------------------------===//
511 
512   /// Parse an @-identifier and store it (without the '@' symbol) in a string
513   /// attribute named 'attrName'.
parseSymbolName(StringAttr & result,StringRef attrName,NamedAttrList & attrs)514   ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
515                               NamedAttrList &attrs) {
516     if (failed(parseOptionalSymbolName(result, attrName, attrs)))
517       return emitError(getCurrentLocation())
518              << "expected valid '@'-identifier for symbol name";
519     return success();
520   }
521 
522   /// Parse an optional @-identifier and store it (without the '@' symbol) in a
523   /// string attribute named 'attrName'.
524   virtual ParseResult parseOptionalSymbolName(StringAttr &result,
525                                               StringRef attrName,
526                                               NamedAttrList &attrs) = 0;
527 
528   //===--------------------------------------------------------------------===//
529   // Operand Parsing
530   //===--------------------------------------------------------------------===//
531 
532   /// This is the representation of an operand reference.
533   struct OperandType {
534     llvm::SMLoc location; // Location of the token.
535     StringRef name;       // Value name, e.g. %42 or %abc
536     unsigned number;      // Number, e.g. 12 for an operand like %xyz#12
537   };
538 
539   /// Parse a single operand.
540   virtual ParseResult parseOperand(OperandType &result) = 0;
541 
542   /// Parse a single operand if present.
543   virtual OptionalParseResult parseOptionalOperand(OperandType &result) = 0;
544 
545   /// These are the supported delimiters around operand lists and region
546   /// argument lists, used by parseOperandList and parseRegionArgumentList.
547   enum class Delimiter {
548     /// Zero or more operands with no delimiters.
549     None,
550     /// Parens surrounding zero or more operands.
551     Paren,
552     /// Square brackets surrounding zero or more operands.
553     Square,
554     /// Parens supporting zero or more operands, or nothing.
555     OptionalParen,
556     /// Square brackets supporting zero or more ops, or nothing.
557     OptionalSquare,
558   };
559 
560   /// Parse zero or more SSA comma-separated operand references with a specified
561   /// surrounding delimiter, and an optional required operand count.
562   virtual ParseResult
563   parseOperandList(SmallVectorImpl<OperandType> &result,
564                    int requiredOperandCount = -1,
565                    Delimiter delimiter = Delimiter::None) = 0;
parseOperandList(SmallVectorImpl<OperandType> & result,Delimiter delimiter)566   ParseResult parseOperandList(SmallVectorImpl<OperandType> &result,
567                                Delimiter delimiter) {
568     return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter);
569   }
570 
571   /// Parse zero or more trailing SSA comma-separated trailing operand
572   /// references with a specified surrounding delimiter, and an optional
573   /// required operand count. A leading comma is expected before the operands.
574   virtual ParseResult
575   parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
576                            int requiredOperandCount = -1,
577                            Delimiter delimiter = Delimiter::None) = 0;
parseTrailingOperandList(SmallVectorImpl<OperandType> & result,Delimiter delimiter)578   ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
579                                        Delimiter delimiter) {
580     return parseTrailingOperandList(result, /*requiredOperandCount=*/-1,
581                                     delimiter);
582   }
583 
584   /// Resolve an operand to an SSA value, emitting an error on failure.
585   virtual ParseResult resolveOperand(const OperandType &operand, Type type,
586                                      SmallVectorImpl<Value> &result) = 0;
587 
588   /// Resolve a list of operands to SSA values, emitting an error on failure, or
589   /// appending the results to the list on success. This method should be used
590   /// when all operands have the same type.
resolveOperands(ArrayRef<OperandType> operands,Type type,SmallVectorImpl<Value> & result)591   ParseResult resolveOperands(ArrayRef<OperandType> operands, Type type,
592                               SmallVectorImpl<Value> &result) {
593     for (auto elt : operands)
594       if (resolveOperand(elt, type, result))
595         return failure();
596     return success();
597   }
598 
599   /// Resolve a list of operands and a list of operand types to SSA values,
600   /// emitting an error and returning failure, or appending the results
601   /// to the list on success.
resolveOperands(ArrayRef<OperandType> operands,ArrayRef<Type> types,llvm::SMLoc loc,SmallVectorImpl<Value> & result)602   ParseResult resolveOperands(ArrayRef<OperandType> operands,
603                               ArrayRef<Type> types, llvm::SMLoc loc,
604                               SmallVectorImpl<Value> &result) {
605     if (operands.size() != types.size())
606       return emitError(loc)
607              << operands.size() << " operands present, but expected "
608              << types.size();
609 
610     for (unsigned i = 0, e = operands.size(); i != e; ++i)
611       if (resolveOperand(operands[i], types[i], result))
612         return failure();
613     return success();
614   }
615   template <typename Operands>
resolveOperands(Operands && operands,Type type,llvm::SMLoc loc,SmallVectorImpl<Value> & result)616   ParseResult resolveOperands(Operands &&operands, Type type, llvm::SMLoc loc,
617                               SmallVectorImpl<Value> &result) {
618     return resolveOperands(std::forward<Operands>(operands),
619                            ArrayRef<Type>(type), loc, result);
620   }
621   template <typename Operands, typename Types>
622   std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
resolveOperands(Operands && operands,Types && types,llvm::SMLoc loc,SmallVectorImpl<Value> & result)623   resolveOperands(Operands &&operands, Types &&types, llvm::SMLoc loc,
624                   SmallVectorImpl<Value> &result) {
625     size_t operandSize = std::distance(operands.begin(), operands.end());
626     size_t typeSize = std::distance(types.begin(), types.end());
627     if (operandSize != typeSize)
628       return emitError(loc)
629              << operandSize << " operands present, but expected " << typeSize;
630 
631     for (auto it : llvm::zip(operands, types))
632       if (resolveOperand(std::get<0>(it), std::get<1>(it), result))
633         return failure();
634     return success();
635   }
636 
637   /// Parses an affine map attribute where dims and symbols are SSA operands.
638   /// Operand values must come from single-result sources, and be valid
639   /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
640   virtual ParseResult
641   parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map,
642                          StringRef attrName, NamedAttrList &attrs,
643                          Delimiter delimiter = Delimiter::Square) = 0;
644 
645   //===--------------------------------------------------------------------===//
646   // Region Parsing
647   //===--------------------------------------------------------------------===//
648 
649   /// Parses a region. Any parsed blocks are appended to 'region' and must be
650   /// moved to the op regions after the op is created. The first block of the
651   /// region takes 'arguments' of types 'argTypes'. If 'enableNameShadowing' is
652   /// set to true, the argument names are allowed to shadow the names of other
653   /// existing SSA values defined above the region scope. 'enableNameShadowing'
654   /// can only be set to true for regions attached to operations that are
655   /// 'IsolatedFromAbove.
656   virtual ParseResult parseRegion(Region &region,
657                                   ArrayRef<OperandType> arguments = {},
658                                   ArrayRef<Type> argTypes = {},
659                                   bool enableNameShadowing = false) = 0;
660 
661   /// Parses a region if present.
662   virtual OptionalParseResult
663   parseOptionalRegion(Region &region, ArrayRef<OperandType> arguments = {},
664                       ArrayRef<Type> argTypes = {},
665                       bool enableNameShadowing = false) = 0;
666 
667   /// Parses a region if present. If the region is present, a new region is
668   /// allocated and placed in `region`. If no region is present or on failure,
669   /// `region` remains untouched.
670   virtual OptionalParseResult parseOptionalRegion(
671       std::unique_ptr<Region> &region, ArrayRef<OperandType> arguments = {},
672       ArrayRef<Type> argTypes = {}, bool enableNameShadowing = false) = 0;
673 
674   /// Parse a region argument, this argument is resolved when calling
675   /// 'parseRegion'.
676   virtual ParseResult parseRegionArgument(OperandType &argument) = 0;
677 
678   /// Parse zero or more region arguments with a specified surrounding
679   /// delimiter, and an optional required argument count. Region arguments
680   /// define new values; so this also checks if values with the same names have
681   /// not been defined yet.
682   virtual ParseResult
683   parseRegionArgumentList(SmallVectorImpl<OperandType> &result,
684                           int requiredOperandCount = -1,
685                           Delimiter delimiter = Delimiter::None) = 0;
686   virtual ParseResult
parseRegionArgumentList(SmallVectorImpl<OperandType> & result,Delimiter delimiter)687   parseRegionArgumentList(SmallVectorImpl<OperandType> &result,
688                           Delimiter delimiter) {
689     return parseRegionArgumentList(result, /*requiredOperandCount=*/-1,
690                                    delimiter);
691   }
692 
693   /// Parse a region argument if present.
694   virtual ParseResult parseOptionalRegionArgument(OperandType &argument) = 0;
695 
696   //===--------------------------------------------------------------------===//
697   // Successor Parsing
698   //===--------------------------------------------------------------------===//
699 
700   /// Parse a single operation successor.
701   virtual ParseResult parseSuccessor(Block *&dest) = 0;
702 
703   /// Parse an optional operation successor.
704   virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0;
705 
706   /// Parse a single operation successor and its operand list.
707   virtual ParseResult
708   parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
709 
710   //===--------------------------------------------------------------------===//
711   // Type Parsing
712   //===--------------------------------------------------------------------===//
713 
714   /// Parse a type.
715   virtual ParseResult parseType(Type &result) = 0;
716 
717   /// Parse an optional type.
718   virtual OptionalParseResult parseOptionalType(Type &result) = 0;
719 
720   /// Parse a type of a specific type.
721   template <typename TypeT>
parseType(TypeT & result)722   ParseResult parseType(TypeT &result) {
723     llvm::SMLoc loc = getCurrentLocation();
724 
725     // Parse any kind of type.
726     Type type;
727     if (parseType(type))
728       return failure();
729 
730     // Check for the right kind of attribute.
731     result = type.dyn_cast<TypeT>();
732     if (!result)
733       return emitError(loc, "invalid kind of type specified");
734 
735     return success();
736   }
737 
738   /// Parse a type list.
parseTypeList(SmallVectorImpl<Type> & result)739   ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
740     do {
741       Type type;
742       if (parseType(type))
743         return failure();
744       result.push_back(type);
745     } while (succeeded(parseOptionalComma()));
746     return success();
747   }
748 
749   /// Parse an arrow followed by a type list.
750   virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;
751 
752   /// Parse an optional arrow followed by a type list.
753   virtual ParseResult
754   parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
755 
756   /// Parse a colon followed by a type.
757   virtual ParseResult parseColonType(Type &result) = 0;
758 
759   /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
760   template <typename TypeType>
parseColonType(TypeType & result)761   ParseResult parseColonType(TypeType &result) {
762     llvm::SMLoc loc = getCurrentLocation();
763 
764     // Parse any kind of type.
765     Type type;
766     if (parseColonType(type))
767       return failure();
768 
769     // Check for the right kind of attribute.
770     result = type.dyn_cast<TypeType>();
771     if (!result)
772       return emitError(loc, "invalid kind of type specified");
773 
774     return success();
775   }
776 
777   /// Parse a colon followed by a type list, which must have at least one type.
778   virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
779 
780   /// Parse an optional colon followed by a type list, which if present must
781   /// have at least one type.
782   virtual ParseResult
783   parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
784 
785   /// Parse a list of assignments of the form
786   ///   (%x1 = %y1, %x2 = %y2, ...)
parseAssignmentList(SmallVectorImpl<OperandType> & lhs,SmallVectorImpl<OperandType> & rhs)787   ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
788                                   SmallVectorImpl<OperandType> &rhs) {
789     OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
790     if (!result.hasValue())
791       return emitError(getCurrentLocation(), "expected '('");
792     return result.getValue();
793   }
794 
795   virtual OptionalParseResult
796   parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs,
797                               SmallVectorImpl<OperandType> &rhs) = 0;
798 
799   /// Parse a keyword followed by a type.
parseKeywordType(const char * keyword,Type & result)800   ParseResult parseKeywordType(const char *keyword, Type &result) {
801     return failure(parseKeyword(keyword) || parseType(result));
802   }
803 
804   /// Add the specified type to the end of the specified type list and return
805   /// success.  This is a helper designed to allow parse methods to be simple
806   /// and chain through || operators.
addTypeToList(Type type,SmallVectorImpl<Type> & result)807   ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
808     result.push_back(type);
809     return success();
810   }
811 
812   /// Add the specified types to the end of the specified type list and return
813   /// success.  This is a helper designed to allow parse methods to be simple
814   /// and chain through || operators.
addTypesToList(ArrayRef<Type> types,SmallVectorImpl<Type> & result)815   ParseResult addTypesToList(ArrayRef<Type> types,
816                              SmallVectorImpl<Type> &result) {
817     result.append(types.begin(), types.end());
818     return success();
819   }
820 
821 private:
822   /// Parse either an operand list or a region argument list depending on
823   /// whether isOperandList is true.
824   ParseResult parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result,
825                                           bool isOperandList,
826                                           int requiredOperandCount,
827                                           Delimiter delimiter);
828 };
829 
830 //===--------------------------------------------------------------------===//
831 // Dialect OpAsm interface.
832 //===--------------------------------------------------------------------===//
833 
834 /// A functor used to set the name of the start of a result group of an
835 /// operation. See 'getAsmResultNames' below for more details.
836 using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
837 
838 class OpAsmDialectInterface
839     : public DialectInterface::Base<OpAsmDialectInterface> {
840 public:
OpAsmDialectInterface(Dialect * dialect)841   OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
842 
843   /// Hooks for getting an alias identifier alias for a given symbol, that is
844   /// not necessarily a part of this dialect. The identifier is used in place of
845   /// the symbol when printing textual IR. These aliases must not contain `.` or
846   /// end with a numeric digit([0-9]+). Returns success if an alias was
847   /// provided, failure otherwise.
getAlias(Attribute attr,raw_ostream & os)848   virtual LogicalResult getAlias(Attribute attr, raw_ostream &os) const {
849     return failure();
850   }
getAlias(Type type,raw_ostream & os)851   virtual LogicalResult getAlias(Type type, raw_ostream &os) const {
852     return failure();
853   }
854 
855   /// Get a special name to use when printing the given operation. See
856   /// OpAsmInterface.td#getAsmResultNames for usage details and documentation.
getAsmResultNames(Operation * op,OpAsmSetValueNameFn setNameFn)857   virtual void getAsmResultNames(Operation *op,
858                                  OpAsmSetValueNameFn setNameFn) const {}
859 
860   /// Get a special name to use when printing the entry block arguments of the
861   /// region contained by an operation in this dialect.
getAsmBlockArgumentNames(Block * block,OpAsmSetValueNameFn setNameFn)862   virtual void getAsmBlockArgumentNames(Block *block,
863                                         OpAsmSetValueNameFn setNameFn) const {}
864 };
865 } // end namespace mlir
866 
867 //===--------------------------------------------------------------------===//
868 // Operation OpAsm interface.
869 //===--------------------------------------------------------------------===//
870 
871 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
872 #include "mlir/IR/OpAsmInterface.h.inc"
873 
874 #endif
875