• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "OpFormatGen.h"
10 #include "mlir/Support/LogicalResult.h"
11 #include "mlir/TableGen/Format.h"
12 #include "mlir/TableGen/GenInfo.h"
13 #include "mlir/TableGen/Interfaces.h"
14 #include "mlir/TableGen/OpClass.h"
15 #include "mlir/TableGen/OpTrait.h"
16 #include "mlir/TableGen/Operator.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Signals.h"
25 #include "llvm/TableGen/Error.h"
26 #include "llvm/TableGen/Record.h"
27 
28 #define DEBUG_TYPE "mlir-tblgen-opformatgen"
29 
30 using namespace mlir;
31 using namespace mlir::tblgen;
32 
33 static llvm::cl::opt<bool> formatErrorIsFatal(
34     "asmformat-error-is-fatal",
35     llvm::cl::desc("Emit a fatal error if format parsing fails"),
36     llvm::cl::init(true));
37 
38 //===----------------------------------------------------------------------===//
39 // Element
40 //===----------------------------------------------------------------------===//
41 
namespace {
/// Common base class of every element that can appear in a declarative
/// assembly format. Each concrete element records its `Kind` here so that
/// LLVM-style RTTI (`isa` / `dyn_cast`) can be implemented via `classof`.
class Element {
public:
  enum class Kind {
    // Directives.
    AttrDictDirective,
    CustomDirective,
    FunctionalTypeDirective,
    OperandsDirective,
    RegionsDirective,
    ResultsDirective,
    SuccessorsDirective,
    TypeDirective,
    TypeRefDirective,

    // A literal token.
    Literal,

    // Prints or omits a space; ignored by the parser.
    Space,

    // References to entities registered on the operation.
    AttributeVariable,
    OperandVariable,
    RegionVariable,
    ResultVariable,
    SuccessorVariable,

    // A group of elements emitted based upon an optional variable.
    Optional,
  };

  Element(Kind kind) : kind(kind) {}
  virtual ~Element() = default;

  /// Returns the kind tag this element was constructed with.
  Kind getKind() const { return kind; }

private:
  /// Discriminator consulted by the `classof` implementations.
  Kind kind;
};
} // namespace
85 
86 //===----------------------------------------------------------------------===//
87 // VariableElement
88 
89 namespace {
90 /// This class represents an instance of an variable element. A variable refers
91 /// to something registered on the operation itself, e.g. an argument, result,
92 /// etc.
93 template <typename VarT, Element::Kind kindVal>
94 class VariableElement : public Element {
95 public:
VariableElement(const VarT * var)96   VariableElement(const VarT *var) : Element(kindVal), var(var) {}
classof(const Element * element)97   static bool classof(const Element *element) {
98     return element->getKind() == kindVal;
99   }
getVar()100   const VarT *getVar() { return var; }
101 
102 protected:
103   const VarT *var;
104 };
105 
106 /// This class represents a variable that refers to an attribute argument.
107 struct AttributeVariable
108     : public VariableElement<NamedAttribute, Element::Kind::AttributeVariable> {
109   using VariableElement<NamedAttribute,
110                         Element::Kind::AttributeVariable>::VariableElement;
111 
112   /// Return the constant builder call for the type of this attribute, or None
113   /// if it doesn't have one.
getTypeBuilder__anon81548fc80211::AttributeVariable114   Optional<StringRef> getTypeBuilder() const {
115     Optional<Type> attrType = var->attr.getValueType();
116     return attrType ? attrType->getBuilderCall() : llvm::None;
117   }
118 
119   /// Return if this attribute refers to a UnitAttr.
isUnitAttr__anon81548fc80211::AttributeVariable120   bool isUnitAttr() const {
121     return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr";
122   }
123 };
124 
125 /// This class represents a variable that refers to an operand argument.
126 using OperandVariable =
127     VariableElement<NamedTypeConstraint, Element::Kind::OperandVariable>;
128 
129 /// This class represents a variable that refers to a region.
130 using RegionVariable =
131     VariableElement<NamedRegion, Element::Kind::RegionVariable>;
132 
133 /// This class represents a variable that refers to a result.
134 using ResultVariable =
135     VariableElement<NamedTypeConstraint, Element::Kind::ResultVariable>;
136 
137 /// This class represents a variable that refers to a successor.
138 using SuccessorVariable =
139     VariableElement<NamedSuccessor, Element::Kind::SuccessorVariable>;
140 } // end anonymous namespace
141 
142 //===----------------------------------------------------------------------===//
143 // DirectiveElement
144 
145 namespace {
146 /// This class implements single kind directives.
147 template <Element::Kind type>
148 class DirectiveElement : public Element {
149 public:
DirectiveElement()150   DirectiveElement() : Element(type){};
classof(const Element * ele)151   static bool classof(const Element *ele) { return ele->getKind() == type; }
152 };
153 /// This class represents the `operands` directive. This directive represents
154 /// all of the operands of an operation.
155 using OperandsDirective = DirectiveElement<Element::Kind::OperandsDirective>;
156 
157 /// This class represents the `regions` directive. This directive represents
158 /// all of the regions of an operation.
159 using RegionsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
160 
161 /// This class represents the `results` directive. This directive represents
162 /// all of the results of an operation.
163 using ResultsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
164 
165 /// This class represents the `successors` directive. This directive represents
166 /// all of the successors of an operation.
167 using SuccessorsDirective =
168     DirectiveElement<Element::Kind::SuccessorsDirective>;
169 
170 /// This class represents the `attr-dict` directive. This directive represents
171 /// the attribute dictionary of the operation.
172 class AttrDictDirective
173     : public DirectiveElement<Element::Kind::AttrDictDirective> {
174 public:
AttrDictDirective(bool withKeyword)175   explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {}
isWithKeyword() const176   bool isWithKeyword() const { return withKeyword; }
177 
178 private:
179   /// If the dictionary should be printed with the 'attributes' keyword.
180   bool withKeyword;
181 };
182 
183 /// This class represents a custom format directive that is implemented by the
184 /// user in C++.
185 class CustomDirective : public Element {
186 public:
CustomDirective(StringRef name,std::vector<std::unique_ptr<Element>> && arguments)187   CustomDirective(StringRef name,
188                   std::vector<std::unique_ptr<Element>> &&arguments)
189       : Element{Kind::CustomDirective}, name(name),
190         arguments(std::move(arguments)) {}
191 
classof(const Element * element)192   static bool classof(const Element *element) {
193     return element->getKind() == Kind::CustomDirective;
194   }
195 
196   /// Return the name of this optional element.
getName() const197   StringRef getName() const { return name; }
198 
199   /// Return the arguments to the custom directive.
getArguments() const200   auto getArguments() const { return llvm::make_pointee_range(arguments); }
201 
202 private:
203   /// The user provided name of the directive.
204   StringRef name;
205 
206   /// The arguments to the custom directive.
207   std::vector<std::unique_ptr<Element>> arguments;
208 };
209 
210 /// This class represents the `functional-type` directive. This directive takes
211 /// two arguments and formats them, respectively, as the inputs and results of a
212 /// FunctionType.
213 class FunctionalTypeDirective
214     : public DirectiveElement<Element::Kind::FunctionalTypeDirective> {
215 public:
FunctionalTypeDirective(std::unique_ptr<Element> inputs,std::unique_ptr<Element> results)216   FunctionalTypeDirective(std::unique_ptr<Element> inputs,
217                           std::unique_ptr<Element> results)
218       : inputs(std::move(inputs)), results(std::move(results)) {}
getInputs() const219   Element *getInputs() const { return inputs.get(); }
getResults() const220   Element *getResults() const { return results.get(); }
221 
222 private:
223   /// The input and result arguments.
224   std::unique_ptr<Element> inputs, results;
225 };
226 
227 /// This class represents the `type` directive.
228 class TypeDirective : public DirectiveElement<Element::Kind::TypeDirective> {
229 public:
TypeDirective(std::unique_ptr<Element> arg)230   TypeDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
getOperand() const231   Element *getOperand() const { return operand.get(); }
232 
233 private:
234   /// The operand that is used to format the directive.
235   std::unique_ptr<Element> operand;
236 };
237 
238 /// This class represents the `type_ref` directive.
239 class TypeRefDirective
240     : public DirectiveElement<Element::Kind::TypeRefDirective> {
241 public:
TypeRefDirective(std::unique_ptr<Element> arg)242   TypeRefDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
getOperand() const243   Element *getOperand() const { return operand.get(); }
244 
245 private:
246   /// The operand that is used to format the directive.
247   std::unique_ptr<Element> operand;
248 };
249 } // namespace
250 
251 //===----------------------------------------------------------------------===//
252 // LiteralElement
253 
254 namespace {
255 /// This class represents an instance of a literal element.
256 class LiteralElement : public Element {
257 public:
LiteralElement(StringRef literal)258   LiteralElement(StringRef literal)
259       : Element{Kind::Literal}, literal(literal) {}
classof(const Element * element)260   static bool classof(const Element *element) {
261     return element->getKind() == Kind::Literal;
262   }
263 
264   /// Return the literal for this element.
getLiteral() const265   StringRef getLiteral() const { return literal; }
266 
267   /// Returns true if the given string is a valid literal.
268   static bool isValidLiteral(StringRef value);
269 
270 private:
271   /// The spelling of the literal for this element.
272   StringRef literal;
273 };
274 } // end anonymous namespace
275 
isValidLiteral(StringRef value)276 bool LiteralElement::isValidLiteral(StringRef value) {
277   if (value.empty())
278     return false;
279   char front = value.front();
280 
281   // If there is only one character, this must either be punctuation or a
282   // single character bare identifier.
283   if (value.size() == 1)
284     return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front);
285 
286   // Check the punctuation that are larger than a single character.
287   if (value == "->")
288     return true;
289 
290   // Otherwise, this must be an identifier.
291   if (!isalpha(front) && front != '_')
292     return false;
293   return llvm::all_of(value.drop_front(), [](char c) {
294     return isalnum(c) || c == '_' || c == '$' || c == '.';
295   });
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // SpaceElement
300 
301 namespace {
302 /// This class represents an instance of a space element. It's a literal that
303 /// prints or omits printing a space. It is ignored by the parser.
304 class SpaceElement : public Element {
305 public:
SpaceElement(bool value)306   SpaceElement(bool value) : Element{Kind::Space}, value(value) {}
classof(const Element * element)307   static bool classof(const Element *element) {
308     return element->getKind() == Kind::Space;
309   }
310 
311   /// Returns true if this element should print as a space. Otherwise, the
312   /// element should omit printing a space between the surrounding elements.
getValue() const313   bool getValue() const { return value; }
314 
315 private:
316   bool value;
317 };
318 } // end anonymous namespace
319 
320 //===----------------------------------------------------------------------===//
321 // OptionalElement
322 
323 namespace {
324 /// This class represents a group of elements that are optionally emitted based
325 /// upon an optional variable of the operation.
326 class OptionalElement : public Element {
327 public:
OptionalElement(std::vector<std::unique_ptr<Element>> && elements,unsigned anchor,unsigned parseStart)328   OptionalElement(std::vector<std::unique_ptr<Element>> &&elements,
329                   unsigned anchor, unsigned parseStart)
330       : Element{Kind::Optional}, elements(std::move(elements)), anchor(anchor),
331         parseStart(parseStart) {}
classof(const Element * element)332   static bool classof(const Element *element) {
333     return element->getKind() == Kind::Optional;
334   }
335 
336   /// Return the nested elements of this grouping.
getElements() const337   auto getElements() const { return llvm::make_pointee_range(elements); }
338 
339   /// Return the anchor of this optional group.
getAnchor() const340   Element *getAnchor() const { return elements[anchor].get(); }
341 
342   /// Return the index of the first element that needs to be parsed.
getParseStart() const343   unsigned getParseStart() const { return parseStart; }
344 
345 private:
346   /// The child elements of this optional.
347   std::vector<std::unique_ptr<Element>> elements;
348   /// The index of the element that acts as the anchor for the optional group.
349   unsigned anchor;
350   /// The index of the first element that is parsed (is not a SpaceElement).
351   unsigned parseStart;
352 };
353 } // end anonymous namespace
354 
355 //===----------------------------------------------------------------------===//
356 // OperationFormat
357 //===----------------------------------------------------------------------===//
358 
359 namespace {
360 
361 using ConstArgument =
362     llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
363 
364 struct OperationFormat {
365   /// This class represents a specific resolver for an operand or result type.
366   class TypeResolution {
367   public:
368     TypeResolution() = default;
369 
370     /// Get the index into the buildable types for this type, or None.
getBuilderIdx() const371     Optional<int> getBuilderIdx() const { return builderIdx; }
setBuilderIdx(int idx)372     void setBuilderIdx(int idx) { builderIdx = idx; }
373 
374     /// Get the variable this type is resolved to, or nullptr.
getVariable() const375     const NamedTypeConstraint *getVariable() const {
376       return resolver.dyn_cast<const NamedTypeConstraint *>();
377     }
378     /// Get the attribute this type is resolved to, or nullptr.
getAttribute() const379     const NamedAttribute *getAttribute() const {
380       return resolver.dyn_cast<const NamedAttribute *>();
381     }
382     /// Get the transformer for the type of the variable, or None.
getVarTransformer() const383     Optional<StringRef> getVarTransformer() const {
384       return variableTransformer;
385     }
setResolver(ConstArgument arg,Optional<StringRef> transformer)386     void setResolver(ConstArgument arg, Optional<StringRef> transformer) {
387       resolver = arg;
388       variableTransformer = transformer;
389       assert(getVariable() || getAttribute());
390     }
391 
392   private:
393     /// If the type is resolved with a buildable type, this is the index into
394     /// 'buildableTypes' in the parent format.
395     Optional<int> builderIdx;
396     /// If the type is resolved based upon another operand or result, this is
397     /// the variable or the attribute that this type is resolved to.
398     ConstArgument resolver;
399     /// If the type is resolved based upon another operand or result, this is
400     /// a transformer to apply to the variable when resolving.
401     Optional<StringRef> variableTransformer;
402   };
403 
OperationFormat__anon81548fc80811::OperationFormat404   OperationFormat(const Operator &op)
405       : allOperands(false), allOperandTypes(false), allResultTypes(false) {
406     operandTypes.resize(op.getNumOperands(), TypeResolution());
407     resultTypes.resize(op.getNumResults(), TypeResolution());
408 
409     hasImplicitTermTrait =
410         llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
411           return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
412         });
413   }
414 
415   /// Generate the operation parser from this format.
416   void genParser(Operator &op, OpClass &opClass);
417   /// Generate the parser code for a specific format element.
418   void genElementParser(Element *element, OpMethodBody &body,
419                         FmtContext &attrTypeCtx);
420   /// Generate the c++ to resolve the types of operands and results during
421   /// parsing.
422   void genParserTypeResolution(Operator &op, OpMethodBody &body);
423   /// Generate the c++ to resolve regions during parsing.
424   void genParserRegionResolution(Operator &op, OpMethodBody &body);
425   /// Generate the c++ to resolve successors during parsing.
426   void genParserSuccessorResolution(Operator &op, OpMethodBody &body);
427   /// Generate the c++ to handling variadic segment size traits.
428   void genParserVariadicSegmentResolution(Operator &op, OpMethodBody &body);
429 
430   /// Generate the operation printer from this format.
431   void genPrinter(Operator &op, OpClass &opClass);
432 
433   /// Generate the printer code for a specific format element.
434   void genElementPrinter(Element *element, OpMethodBody &body, Operator &op,
435                          bool &shouldEmitSpace, bool &lastWasPunctuation);
436 
437   /// The various elements in this format.
438   std::vector<std::unique_ptr<Element>> elements;
439 
440   /// A flag indicating if all operand/result types were seen. If the format
441   /// contains these, it can not contain individual type resolvers.
442   bool allOperands, allOperandTypes, allResultTypes;
443 
444   /// A flag indicating if this operation has the SingleBlockImplicitTerminator
445   /// trait.
446   bool hasImplicitTermTrait;
447 
448   /// A map of buildable types to indices.
449   llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
450 
451   /// The index of the buildable type, if valid, for every operand and result.
452   std::vector<TypeResolution> operandTypes, resultTypes;
453 
454   /// The set of attributes explicitly used within the format.
455   SmallVector<const NamedAttribute *, 8> usedAttributes;
456 };
457 } // end anonymous namespace
458 
459 //===----------------------------------------------------------------------===//
460 // Parser Gen
461 
462 /// Returns true if we can format the given attribute as an EnumAttr in the
463 /// parser format.
canFormatEnumAttr(const NamedAttribute * attr)464 static bool canFormatEnumAttr(const NamedAttribute *attr) {
465   Attribute baseAttr = attr->attr.getBaseAttr();
466   const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&baseAttr);
467   if (!enumAttr)
468     return false;
469 
470   // The attribute must have a valid underlying type and a constant builder.
471   return !enumAttr->getUnderlyingType().empty() &&
472          !enumAttr->getConstBuilderTemplate().empty();
473 }
474 
475 /// Returns if we should format the given attribute as an SymbolNameAttr.
shouldFormatSymbolNameAttr(const NamedAttribute * attr)476 static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
477   return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
478 }
479 
// Templates for the C++ emitted into generated operation parsers. The `{N}`
// placeholders are substituted by the generator and are documented above each
// snippet. Every template is declared `const char *const`. A const-qualified
// namespace-scope variable has internal linkage, so these names cannot
// collide with same-named globals in other translation units, and the
// pointers cannot be reassigned.

/// The code snippet used to generate a parser call for an attribute.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
  if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes))
    return ::mlir::failure();
)";
/// Optional variant of `attrParserCode`; same placeholders.
const char *const optionalAttrParserCode = R"(
  {
    ::mlir::OptionalParseResult parseResult =
      parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes);
    if (parseResult.hasValue() && failed(*parseResult))
      return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a symbol name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
  if (parser.parseSymbolName({0}Attr, "{0}", result.attributes))
    return ::mlir::failure();
)";
/// Optional variant of `symbolNameAttrParserCode`; same placeholder.
const char *const optionalSymbolNameAttrParserCode = R"(
  // Parsing an optional symbol name doesn't fail, so no need to check the
  // result.
  (void)parser.parseOptionalSymbolName({0}Attr, "{0}", result.attributes);
)";

/// The code snippet used to generate a parser call for an enum attribute.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
const char *const enumAttrParserCode = R"(
  {
    ::mlir::StringAttr attrVal;
    ::mlir::NamedAttrList attrStorage;
    auto loc = parser.getCurrentLocation();
    if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
                              "{0}", attrStorage))
      return ::mlir::failure();

    auto attrOptional = {1}::{2}(attrVal.getValue());
    if (!attrOptional)
      return parser.emitError(loc, "invalid ")
             << "{0} attribute specification: " << attrVal;

    {0}Attr = {3};
    result.addAttribute("{0}", {0}Attr);
  }
)";
/// Optional variant of `enumAttrParserCode`; same placeholders.
const char *const optionalEnumAttrParserCode = R"(
  {
    ::mlir::StringAttr attrVal;
    ::mlir::NamedAttrList attrStorage;
    auto loc = parser.getCurrentLocation();

    ::mlir::OptionalParseResult parseResult =
      parser.parseOptionalAttribute(attrVal, parser.getBuilder().getNoneType(),
                                    "{0}", attrStorage);
    if (parseResult.hasValue()) {
      if (failed(*parseResult))
        return ::mlir::failure();

      auto attrOptional = {1}::{2}(attrVal.getValue());
      if (!attrOptional)
        return parser.emitError(loc, "invalid ")
               << "{0} attribute specification: " << attrVal;

      {0}Attr = {3};
      result.addAttribute("{0}", {0}Attr);
    }
  }
)";

/// The code snippet used to generate a parser call for an operand.
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList({0}Operands))
    return ::mlir::failure();
)";
const char *const optionalOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    ::mlir::OpAsmParser::OperandType operand;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalOperand(operand);
    if (parseResult.hasValue()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Operands.push_back(operand);
    }
  }
)";
const char *const operandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperand({0}RawOperands[0]))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for a type list.
///
/// {0}: The name for the type list.
const char *const variadicTypeParserCode = R"(
  if (parser.parseTypeList({0}Types))
    return ::mlir::failure();
)";
const char *const optionalTypeParserCode = R"(
  {
    ::mlir::Type optionalType;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalType(optionalType);
    if (parseResult.hasValue()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Types.push_back(optionalType);
    }
  }
)";
const char *const typeParserCode = R"(
  if (parser.parseType({0}RawTypes[0]))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for a functional type.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
const char *const functionalTypeParserCode = R"(
  ::mlir::FunctionType {0}__{1}_functionType;
  if (parser.parseType({0}__{1}_functionType))
    return ::mlir::failure();
  {0}Types = {0}__{1}_functionType.getInputs();
  {1}Types = {0}__{1}_functionType.getResults();
)";

/// The code snippet used to generate a parser call for a region list.
///
/// {0}: The name for the region list.
const char *const regionListParserCode = R"(
  {
    std::unique_ptr<::mlir::Region> region;
    auto firstRegionResult = parser.parseOptionalRegion(region);
    if (firstRegionResult.hasValue()) {
      if (failed(*firstRegionResult))
        return ::mlir::failure();
      {0}Regions.emplace_back(std::move(region));

      // Parse any trailing regions.
      while (succeeded(parser.parseOptionalComma())) {
        region = std::make_unique<::mlir::Region>();
        if (parser.parseRegion(*region))
          return ::mlir::failure();
        {0}Regions.emplace_back(std::move(region));
      }
    }
  }
)";

/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *const regionListEnsureTerminatorParserCode = R"(
  for (auto &region : {0}Regions)
    ensureTerminator(*region, parser.getBuilder(), result.location);
)";

/// The code snippet used to generate a parser call for an optional region.
///
/// {0}: The name of the region.
const char *const optionalRegionParserCode = R"(
  {
     auto parseResult = parser.parseOptionalRegion(*{0}Region);
     if (parseResult.hasValue() && failed(*parseResult))
       return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *const regionParserCode = R"(
  if (parser.parseRegion(*{0}Region))
    return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *const regionEnsureTerminatorParserCode = R"(
  ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";

/// The code snippet used to generate a parser call for a successor list.
///
/// {0}: The name for the successor list.
const char *const successorListParserCode = R"(
  {
    ::mlir::Block *succ;
    auto firstSucc = parser.parseOptionalSuccessor(succ);
    if (firstSucc.hasValue()) {
      if (failed(*firstSucc))
        return ::mlir::failure();
      {0}Successors.emplace_back(succ);

      // Parse any trailing successors.
      while (succeeded(parser.parseOptionalComma())) {
        if (parser.parseSuccessor(succ))
          return ::mlir::failure();
        {0}Successors.emplace_back(succ);
      }
    }
  }
)";

/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
const char *const successorParserCode = R"(
  if (parser.parseSuccessor({0}Successor))
    return ::mlir::failure();
)";
707 
namespace {
/// How many values a parsed argument may bind.
enum class ArgumentLengthKind {
  /// Any number of elements (0->N).
  Variadic,
  /// Zero or one element.
  Optional,
  /// Exactly one element.
  Single
};
} // end anonymous namespace
719 
720 /// Get the length kind for the given constraint.
721 static ArgumentLengthKind
getArgumentLengthKind(const NamedTypeConstraint * var)722 getArgumentLengthKind(const NamedTypeConstraint *var) {
723   if (var->isOptional())
724     return ArgumentLengthKind::Optional;
725   if (var->isVariadic())
726     return ArgumentLengthKind::Variadic;
727   return ArgumentLengthKind::Single;
728 }
729 
730 /// Get the name used for the type list for the given type directive operand.
731 /// 'lengthKind' to the corresponding kind for the given argument.
getTypeListName(Element * arg,ArgumentLengthKind & lengthKind)732 static StringRef getTypeListName(Element *arg, ArgumentLengthKind &lengthKind) {
733   if (auto *operand = dyn_cast<OperandVariable>(arg)) {
734     lengthKind = getArgumentLengthKind(operand->getVar());
735     return operand->getVar()->name;
736   }
737   if (auto *result = dyn_cast<ResultVariable>(arg)) {
738     lengthKind = getArgumentLengthKind(result->getVar());
739     return result->getVar()->name;
740   }
741   lengthKind = ArgumentLengthKind::Variadic;
742   if (isa<OperandsDirective>(arg))
743     return "allOperand";
744   if (isa<ResultsDirective>(arg))
745     return "allResult";
746   llvm_unreachable("unknown 'type' directive argument");
747 }
748 
749 /// Generate the parser for a literal value.
genLiteralParser(StringRef value,OpMethodBody & body)750 static void genLiteralParser(StringRef value, OpMethodBody &body) {
751   // Handle the case of a keyword/identifier.
752   if (value.front() == '_' || isalpha(value.front())) {
753     body << "Keyword(\"" << value << "\")";
754     return;
755   }
756   body << (StringRef)StringSwitch<StringRef>(value)
757               .Case("->", "Arrow()")
758               .Case(":", "Colon()")
759               .Case(",", "Comma()")
760               .Case("=", "Equal()")
761               .Case("<", "Less()")
762               .Case(">", "Greater()")
763               .Case("{", "LBrace()")
764               .Case("}", "RBrace()")
765               .Case("(", "LParen()")
766               .Case(")", "RParen()")
767               .Case("[", "LSquare()")
768               .Case("]", "RSquare()")
769               .Case("?", "Question()")
770               .Case("+", "Plus()")
771               .Case("*", "Star()");
772 }
773 
774 /// Generate the storage code required for parsing the given element.
genElementParserStorage(Element * element,OpMethodBody & body)775 static void genElementParserStorage(Element *element, OpMethodBody &body) {
776   if (auto *optional = dyn_cast<OptionalElement>(element)) {
777     auto elements = optional->getElements();
778 
779     // If the anchor is a unit attribute, it won't be parsed directly so elide
780     // it.
781     auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
782     Element *elidedAnchorElement = nullptr;
783     if (anchor && anchor != &*elements.begin() && anchor->isUnitAttr())
784       elidedAnchorElement = anchor;
785     for (auto &childElement : elements)
786       if (&childElement != elidedAnchorElement)
787         genElementParserStorage(&childElement, body);
788 
789   } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
790     for (auto &paramElement : custom->getArguments())
791       genElementParserStorage(&paramElement, body);
792 
793   } else if (isa<OperandsDirective>(element)) {
794     body << "  ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
795             "allOperands;\n";
796 
797   } else if (isa<RegionsDirective>(element)) {
798     body << "  ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
799             "fullRegions;\n";
800 
801   } else if (isa<SuccessorsDirective>(element)) {
802     body << "  ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
803 
804   } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
805     const NamedAttribute *var = attr->getVar();
806     body << llvm::formatv("  {0} {1}Attr;\n", var->attr.getStorageType(),
807                           var->name);
808 
809   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
810     StringRef name = operand->getVar()->name;
811     if (operand->getVar()->isVariableLength()) {
812       body << "  ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
813            << name << "Operands;\n";
814     } else {
815       body << "  ::mlir::OpAsmParser::OperandType " << name
816            << "RawOperands[1];\n"
817            << "  ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> " << name
818            << "Operands(" << name << "RawOperands);";
819     }
820     body << llvm::formatv("  ::llvm::SMLoc {0}OperandsLoc;\n"
821                           "  (void){0}OperandsLoc;\n",
822                           name);
823 
824   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
825     StringRef name = region->getVar()->name;
826     if (region->getVar()->isVariadic()) {
827       body << llvm::formatv(
828           "  ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
829           "{0}Regions;\n",
830           name);
831     } else {
832       body << llvm::formatv("  std::unique_ptr<::mlir::Region> {0}Region = "
833                             "std::make_unique<::mlir::Region>();\n",
834                             name);
835     }
836 
837   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
838     StringRef name = successor->getVar()->name;
839     if (successor->getVar()->isVariadic()) {
840       body << llvm::formatv("  ::llvm::SmallVector<::mlir::Block *, 2> "
841                             "{0}Successors;\n",
842                             name);
843     } else {
844       body << llvm::formatv("  ::mlir::Block *{0}Successor = nullptr;\n", name);
845     }
846 
847   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
848     ArgumentLengthKind lengthKind;
849     StringRef name = getTypeListName(dir->getOperand(), lengthKind);
850     if (lengthKind != ArgumentLengthKind::Single)
851       body << "  ::mlir::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
852     else
853       body << llvm::formatv("  ::mlir::Type {0}RawTypes[1];\n", name)
854            << llvm::formatv(
855                   "  ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n",
856                   name);
857   } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
858     ArgumentLengthKind lengthKind;
859     StringRef name = getTypeListName(dir->getOperand(), lengthKind);
860     // Refer to the previously encountered TypeDirective for name.
861     // Take a `const ::mlir::SmallVector<::mlir::Type, 1> &` in the declaration
862     // to properly track the types that will be parsed and pushed later on.
863     if (lengthKind != ArgumentLengthKind::Single)
864       body << "  const ::mlir::SmallVector<::mlir::Type, 1> &" << name
865            << "TypesRef(" << name << "Types);\n";
866     else
867       body << llvm::formatv(
868           "  ::llvm::ArrayRef<::mlir::Type> {0}RawTypesRef({0}RawTypes);\n",
869           name);
870   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
871     ArgumentLengthKind ignored;
872     body << "  ::llvm::ArrayRef<::mlir::Type> "
873          << getTypeListName(dir->getInputs(), ignored) << "Types;\n";
874     body << "  ::llvm::ArrayRef<::mlir::Type> "
875          << getTypeListName(dir->getResults(), ignored) << "Types;\n";
876   }
877 }
878 
879 /// Generate the parser for a parameter to a custom directive.
genCustomParameterParser(Element & param,OpMethodBody & body)880 static void genCustomParameterParser(Element &param, OpMethodBody &body) {
881   body << ", ";
882   if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
883     body << attr->getVar()->name << "Attr";
884   } else if (isa<AttrDictDirective>(&param)) {
885     body << "result.attributes";
886   } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
887     StringRef name = operand->getVar()->name;
888     ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
889     if (lengthKind == ArgumentLengthKind::Variadic)
890       body << llvm::formatv("{0}Operands", name);
891     else if (lengthKind == ArgumentLengthKind::Optional)
892       body << llvm::formatv("{0}Operand", name);
893     else
894       body << formatv("{0}RawOperands[0]", name);
895 
896   } else if (auto *region = dyn_cast<RegionVariable>(&param)) {
897     StringRef name = region->getVar()->name;
898     if (region->getVar()->isVariadic())
899       body << llvm::formatv("{0}Regions", name);
900     else
901       body << llvm::formatv("*{0}Region", name);
902 
903   } else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
904     StringRef name = successor->getVar()->name;
905     if (successor->getVar()->isVariadic())
906       body << llvm::formatv("{0}Successors", name);
907     else
908       body << llvm::formatv("{0}Successor", name);
909 
910   } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
911     ArgumentLengthKind lengthKind;
912     StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
913     if (lengthKind == ArgumentLengthKind::Variadic)
914       body << llvm::formatv("{0}TypesRef", listName);
915     else if (lengthKind == ArgumentLengthKind::Optional)
916       body << llvm::formatv("{0}TypeRef", listName);
917     else
918       body << formatv("{0}RawTypesRef[0]", listName);
919   } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
920     ArgumentLengthKind lengthKind;
921     StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
922     if (lengthKind == ArgumentLengthKind::Variadic)
923       body << llvm::formatv("{0}Types", listName);
924     else if (lengthKind == ArgumentLengthKind::Optional)
925       body << llvm::formatv("{0}Type", listName);
926     else
927       body << formatv("{0}RawTypes[0]", listName);
928   } else {
929     llvm_unreachable("unknown custom directive parameter");
930   }
931 }
932 
933 /// Generate the parser for a custom directive.
genCustomDirectiveParser(CustomDirective * dir,OpMethodBody & body)934 static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
935   body << "  {\n";
936 
937   // Preprocess the directive variables.
938   // * Add a local variable for optional operands and types. This provides a
939   //   better API to the user defined parser methods.
940   // * Set the location of operand variables.
941   for (Element &param : dir->getArguments()) {
942     if (auto *operand = dyn_cast<OperandVariable>(&param)) {
943       body << "    " << operand->getVar()->name
944            << "OperandsLoc = parser.getCurrentLocation();\n";
945       if (operand->getVar()->isOptional()) {
946         body << llvm::formatv(
947             "    llvm::Optional<::mlir::OpAsmParser::OperandType> "
948             "{0}Operand;\n",
949             operand->getVar()->name);
950       }
951     } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
952       // Reference to an optional which may or may not have been set.
953       // Retrieve from vector if not empty.
954       ArgumentLengthKind lengthKind;
955       StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
956       if (lengthKind == ArgumentLengthKind::Optional)
957         body << llvm::formatv(
958             "    ::mlir::Type {0}TypeRef = {0}TypesRef.empty() "
959             "? Type() : {0}TypesRef[0];\n",
960             listName);
961     } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
962       ArgumentLengthKind lengthKind;
963       StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
964       if (lengthKind == ArgumentLengthKind::Optional)
965         body << llvm::formatv("    ::mlir::Type {0}Type;\n", listName);
966     }
967   }
968 
969   body << "    if (parse" << dir->getName() << "(parser";
970   for (Element &param : dir->getArguments())
971     genCustomParameterParser(param, body);
972 
973   body << "))\n"
974        << "      return ::mlir::failure();\n";
975 
976   // After parsing, add handling for any of the optional constructs.
977   for (Element &param : dir->getArguments()) {
978     if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
979       const NamedAttribute *var = attr->getVar();
980       if (var->attr.isOptional())
981         body << llvm::formatv("    if ({0}Attr)\n  ", var->name);
982 
983       body << llvm::formatv("    result.addAttribute(\"{0}\", {0}Attr);\n",
984                             var->name);
985     } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
986       const NamedTypeConstraint *var = operand->getVar();
987       if (!var->isOptional())
988         continue;
989       body << llvm::formatv("    if ({0}Operand.hasValue())\n"
990                             "      {0}Operands.push_back(*{0}Operand);\n",
991                             var->name);
992     } else if (isa<TypeRefDirective>(&param)) {
993       // In the `type_ref` case, do not parse a new Type that needs to be added.
994       // Just do nothing here.
995     } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
996       ArgumentLengthKind lengthKind;
997       StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
998       if (lengthKind == ArgumentLengthKind::Optional) {
999         body << llvm::formatv("    if ({0}Type)\n"
1000                               "      {0}Types.push_back({0}Type);\n",
1001                               listName);
1002       }
1003     }
1004   }
1005 
1006   body << "  }\n";
1007 }
1008 
genParser(Operator & op,OpClass & opClass)1009 void OperationFormat::genParser(Operator &op, OpClass &opClass) {
1010   llvm::SmallVector<OpMethodParameter, 4> paramList;
1011   paramList.emplace_back("::mlir::OpAsmParser &", "parser");
1012   paramList.emplace_back("::mlir::OperationState &", "result");
1013 
1014   auto *method =
1015       opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
1016                                 OpMethod::MP_Static, std::move(paramList));
1017   auto &body = method->body();
1018 
1019   // Generate variables to store the operands and type within the format. This
1020   // allows for referencing these variables in the presence of optional
1021   // groupings.
1022   for (auto &element : elements)
1023     genElementParserStorage(&*element, body);
1024 
1025   // A format context used when parsing attributes with buildable types.
1026   FmtContext attrTypeCtx;
1027   attrTypeCtx.withBuilder("parser.getBuilder()");
1028 
1029   // Generate parsers for each of the elements.
1030   for (auto &element : elements)
1031     genElementParser(element.get(), body, attrTypeCtx);
1032 
1033   // Generate the code to resolve the operand/result types and successors now
1034   // that they have been parsed.
1035   genParserTypeResolution(op, body);
1036   genParserRegionResolution(op, body);
1037   genParserSuccessorResolution(op, body);
1038   genParserVariadicSegmentResolution(op, body);
1039 
1040   body << "  return ::mlir::success();\n";
1041 }
1042 
genElementParser(Element * element,OpMethodBody & body,FmtContext & attrTypeCtx)1043 void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
1044                                        FmtContext &attrTypeCtx) {
1045   /// Optional Group.
1046   if (auto *optional = dyn_cast<OptionalElement>(element)) {
1047     auto elements =
1048         llvm::drop_begin(optional->getElements(), optional->getParseStart());
1049 
1050     // Generate a special optional parser for the first element to gate the
1051     // parsing of the rest of the elements.
1052     Element *firstElement = &*elements.begin();
1053     if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
1054       genElementParser(attrVar, body, attrTypeCtx);
1055       body << "  if (" << attrVar->getVar()->name << "Attr) {\n";
1056     } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
1057       body << "  if (succeeded(parser.parseOptional";
1058       genLiteralParser(literal->getLiteral(), body);
1059       body << ")) {\n";
1060     } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
1061       genElementParser(opVar, body, attrTypeCtx);
1062       body << "  if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
1063     } else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) {
1064       const NamedRegion *region = regionVar->getVar();
1065       if (region->isVariadic()) {
1066         genElementParser(regionVar, body, attrTypeCtx);
1067         body << "  if (!" << region->name << "Regions.empty()) {\n";
1068       } else {
1069         body << llvm::formatv(optionalRegionParserCode, region->name);
1070         body << "  if (!" << region->name << "Region->empty()) {\n  ";
1071         if (hasImplicitTermTrait)
1072           body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
1073       }
1074     }
1075 
1076     // If the anchor is a unit attribute, we don't need to print it. When
1077     // parsing, we will add this attribute if this group is present.
1078     Element *elidedAnchorElement = nullptr;
1079     auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
1080     if (anchorAttr && anchorAttr != firstElement && anchorAttr->isUnitAttr()) {
1081       elidedAnchorElement = anchorAttr;
1082 
1083       // Add the anchor unit attribute to the operation state.
1084       body << "    result.addAttribute(\"" << anchorAttr->getVar()->name
1085            << "\", parser.getBuilder().getUnitAttr());\n";
1086     }
1087 
1088     // Generate the rest of the elements normally.
1089     for (Element &childElement : llvm::drop_begin(elements, 1)) {
1090       if (&childElement != elidedAnchorElement)
1091         genElementParser(&childElement, body, attrTypeCtx);
1092     }
1093     body << "  }\n";
1094 
1095     /// Literals.
1096   } else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
1097     body << "  if (parser.parse";
1098     genLiteralParser(literal->getLiteral(), body);
1099     body << ")\n    return ::mlir::failure();\n";
1100 
1101     /// Spaces.
1102   } else if (isa<SpaceElement>(element)) {
1103     // Nothing to parse.
1104 
1105     /// Arguments.
1106   } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1107     const NamedAttribute *var = attr->getVar();
1108 
1109     // Check to see if we can parse this as an enum attribute.
1110     if (canFormatEnumAttr(var)) {
1111       Attribute baseAttr = var->attr.getBaseAttr();
1112       const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1113 
1114       // Generate the code for building an attribute for this enum.
1115       std::string attrBuilderStr;
1116       {
1117         llvm::raw_string_ostream os(attrBuilderStr);
1118         os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
1119                     "attrOptional.getValue()");
1120       }
1121 
1122       body << formatv(var->attr.isOptional() ? optionalEnumAttrParserCode
1123                                              : enumAttrParserCode,
1124                       var->name, enumAttr.getCppNamespace(),
1125                       enumAttr.getStringToSymbolFnName(), attrBuilderStr);
1126       return;
1127     }
1128 
1129     // Check to see if we should parse this as a symbol name attribute.
1130     if (shouldFormatSymbolNameAttr(var)) {
1131       body << formatv(var->attr.isOptional() ? optionalSymbolNameAttrParserCode
1132                                              : symbolNameAttrParserCode,
1133                       var->name);
1134       return;
1135     }
1136 
1137     // If this attribute has a buildable type, use that when parsing the
1138     // attribute.
1139     std::string attrTypeStr;
1140     if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
1141       llvm::raw_string_ostream os(attrTypeStr);
1142       os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
1143     }
1144 
1145     body << formatv(var->attr.isOptional() ? optionalAttrParserCode
1146                                            : attrParserCode,
1147                     var->name, attrTypeStr);
1148   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
1149     ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
1150     StringRef name = operand->getVar()->name;
1151     if (lengthKind == ArgumentLengthKind::Variadic)
1152       body << llvm::formatv(variadicOperandParserCode, name);
1153     else if (lengthKind == ArgumentLengthKind::Optional)
1154       body << llvm::formatv(optionalOperandParserCode, name);
1155     else
1156       body << formatv(operandParserCode, name);
1157 
1158   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
1159     bool isVariadic = region->getVar()->isVariadic();
1160     body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
1161                           region->getVar()->name);
1162     if (hasImplicitTermTrait) {
1163       body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
1164                                        : regionEnsureTerminatorParserCode,
1165                             region->getVar()->name);
1166     }
1167 
1168   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
1169     bool isVariadic = successor->getVar()->isVariadic();
1170     body << formatv(isVariadic ? successorListParserCode : successorParserCode,
1171                     successor->getVar()->name);
1172 
1173     /// Directives.
1174   } else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
1175     body << "  if (parser.parseOptionalAttrDict"
1176          << (attrDict->isWithKeyword() ? "WithKeyword" : "")
1177          << "(result.attributes))\n"
1178          << "    return ::mlir::failure();\n";
1179   } else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
1180     genCustomDirectiveParser(customDir, body);
1181 
1182   } else if (isa<OperandsDirective>(element)) {
1183     body << "  ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
1184          << "  if (parser.parseOperandList(allOperands))\n"
1185          << "    return ::mlir::failure();\n";
1186 
1187   } else if (isa<RegionsDirective>(element)) {
1188     body << llvm::formatv(regionListParserCode, "full");
1189     if (hasImplicitTermTrait)
1190       body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
1191 
1192   } else if (isa<SuccessorsDirective>(element)) {
1193     body << llvm::formatv(successorListParserCode, "full");
1194 
1195   } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
1196     ArgumentLengthKind lengthKind;
1197     StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
1198     if (lengthKind == ArgumentLengthKind::Variadic)
1199       body << llvm::formatv(variadicTypeParserCode, listName);
1200     else if (lengthKind == ArgumentLengthKind::Optional)
1201       body << llvm::formatv(optionalTypeParserCode, listName);
1202     else
1203       body << formatv(typeParserCode, listName);
1204   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
1205     ArgumentLengthKind lengthKind;
1206     StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
1207     if (lengthKind == ArgumentLengthKind::Variadic)
1208       body << llvm::formatv(variadicTypeParserCode, listName);
1209     else if (lengthKind == ArgumentLengthKind::Optional)
1210       body << llvm::formatv(optionalTypeParserCode, listName);
1211     else
1212       body << formatv(typeParserCode, listName);
1213   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
1214     ArgumentLengthKind ignored;
1215     body << formatv(functionalTypeParserCode,
1216                     getTypeListName(dir->getInputs(), ignored),
1217                     getTypeListName(dir->getResults(), ignored));
1218   } else {
1219     llvm_unreachable("unknown format element");
1220   }
1221 }
1222 
genParserTypeResolution(Operator & op,OpMethodBody & body)1223 void OperationFormat::genParserTypeResolution(Operator &op,
1224                                               OpMethodBody &body) {
1225   // If any of type resolutions use transformed variables, make sure that the
1226   // types of those variables are resolved.
1227   SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
1228   FmtContext verifierFCtx;
1229   for (TypeResolution &resolver :
1230        llvm::concat<TypeResolution>(resultTypes, operandTypes)) {
1231     Optional<StringRef> transformer = resolver.getVarTransformer();
1232     if (!transformer)
1233       continue;
1234     // Ensure that we don't verify the same variables twice.
1235     const NamedTypeConstraint *variable = resolver.getVariable();
1236     if (!variable || !verifiedVariables.insert(variable).second)
1237       continue;
1238 
1239     auto constraint = variable->constraint;
1240     body << "  for (::mlir::Type type : " << variable->name << "Types) {\n"
1241          << "    (void)type;\n"
1242          << "    if (!("
1243          << tgfmt(constraint.getConditionTemplate(),
1244                   &verifierFCtx.withSelf("type"))
1245          << ")) {\n"
1246          << formatv("      return parser.emitError(parser.getNameLoc()) << "
1247                     "\"'{0}' must be {1}, but got \" << type;\n",
1248                     variable->name, constraint.getDescription())
1249          << "    }\n"
1250          << "  }\n";
1251   }
1252 
1253   // Initialize the set of buildable types.
1254   if (!buildableTypes.empty()) {
1255     FmtContext typeBuilderCtx;
1256     typeBuilderCtx.withBuilder("parser.getBuilder()");
1257     for (auto &it : buildableTypes)
1258       body << "  ::mlir::Type odsBuildableType" << it.second << " = "
1259            << tgfmt(it.first, &typeBuilderCtx) << ";\n";
1260   }
1261 
1262   // Emit the code necessary for a type resolver.
1263   auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
1264     if (Optional<int> val = resolver.getBuilderIdx()) {
1265       body << "odsBuildableType" << *val;
1266     } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
1267       if (Optional<StringRef> tform = resolver.getVarTransformer())
1268         body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]"));
1269       else
1270         body << var->name << "Types";
1271     } else if (const NamedAttribute *attr = resolver.getAttribute()) {
1272       if (Optional<StringRef> tform = resolver.getVarTransformer())
1273         body << tgfmt(*tform,
1274                       &FmtContext().withSelf(attr->name + "Attr.getType()"));
1275       else
1276         body << attr->name << "Attr.getType()";
1277     } else {
1278       body << curVar << "Types";
1279     }
1280   };
1281 
1282   // Resolve each of the result types.
1283   if (allResultTypes) {
1284     body << "  result.addTypes(allResultTypes);\n";
1285   } else {
1286     for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
1287       body << "  result.addTypes(";
1288       emitTypeResolver(resultTypes[i], op.getResultName(i));
1289       body << ");\n";
1290     }
1291   }
1292 
1293   // Early exit if there are no operands.
1294   if (op.getNumOperands() == 0)
1295     return;
1296 
1297   // Handle the case where all operand types are in one group.
1298   if (allOperandTypes) {
1299     // If we have all operands together, use the full operand list directly.
1300     if (allOperands) {
1301       body << "  if (parser.resolveOperands(allOperands, allOperandTypes, "
1302               "allOperandLoc, result.operands))\n"
1303               "    return ::mlir::failure();\n";
1304       return;
1305     }
1306 
1307     // Otherwise, use llvm::concat to merge the disjoint operand lists together.
1308     // llvm::concat does not allow the case of a single range, so guard it here.
1309     body << "  if (parser.resolveOperands(";
1310     if (op.getNumOperands() > 1) {
1311       body << "::llvm::concat<const ::mlir::OpAsmParser::OperandType>(";
1312       llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) {
1313         body << operand.name << "Operands";
1314       });
1315       body << ")";
1316     } else {
1317       body << op.operand_begin()->name << "Operands";
1318     }
1319     body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
1320          << "    return ::mlir::failure();\n";
1321     return;
1322   }
1323   // Handle the case where all of the operands were grouped together.
1324   if (allOperands) {
1325     body << "  if (parser.resolveOperands(allOperands, ";
1326 
1327     // Group all of the operand types together to perform the resolution all at
1328     // once. Use llvm::concat to perform the merge. llvm::concat does not allow
1329     // the case of a single range, so guard it here.
1330     if (op.getNumOperands() > 1) {
1331       body << "::llvm::concat<const Type>(";
1332       llvm::interleaveComma(
1333           llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1334             body << "::llvm::ArrayRef<::mlir::Type>(";
1335             emitTypeResolver(operandTypes[i], op.getOperand(i).name);
1336             body << ")";
1337           });
1338       body << ")";
1339     } else {
1340       emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
1341     }
1342 
1343     body << ", allOperandLoc, result.operands))\n"
1344          << "    return ::mlir::failure();\n";
1345     return;
1346   }
1347 
1348   // The final case is the one where each of the operands types are resolved
1349   // separately.
1350   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
1351     NamedTypeConstraint &operand = op.getOperand(i);
1352     body << "  if (parser.resolveOperands(" << operand.name << "Operands, ";
1353 
1354     // Resolve the type of this operand.
1355     TypeResolution &operandType = operandTypes[i];
1356     emitTypeResolver(operandType, operand.name);
1357 
1358     // If the type is resolved by a non-variadic variable, index into the
1359     // resolved type list. This allows for resolving the types of a variadic
1360     // operand list from a non-variadic variable.
1361     bool verifyOperandAndTypeSize = true;
1362     if (auto *resolverVar = operandType.getVariable()) {
1363       if (!resolverVar->isVariadic() && !operandType.getVarTransformer()) {
1364         body << "[0]";
1365         verifyOperandAndTypeSize = false;
1366       }
1367     } else {
1368       verifyOperandAndTypeSize = !operandType.getBuilderIdx();
1369     }
1370 
1371     // Check to see if the sizes between the types and operands must match. If
1372     // they do, provide the operand location to select the proper resolution
1373     // overload.
1374     if (verifyOperandAndTypeSize)
1375       body << ", " << operand.name << "OperandsLoc";
1376     body << ", result.operands))\n    return ::mlir::failure();\n";
1377   }
1378 }
1379 
genParserRegionResolution(Operator & op,OpMethodBody & body)1380 void OperationFormat::genParserRegionResolution(Operator &op,
1381                                                 OpMethodBody &body) {
1382   // Check for the case where all regions were parsed.
1383   bool hasAllRegions = llvm::any_of(
1384       elements, [](auto &elt) { return isa<RegionsDirective>(elt.get()); });
1385   if (hasAllRegions) {
1386     body << "  result.addRegions(fullRegions);\n";
1387     return;
1388   }
1389 
1390   // Otherwise, handle each region individually.
1391   for (const NamedRegion &region : op.getRegions()) {
1392     if (region.isVariadic())
1393       body << "  result.addRegions(" << region.name << "Regions);\n";
1394     else
1395       body << "  result.addRegion(std::move(" << region.name << "Region));\n";
1396   }
1397 }
1398 
genParserSuccessorResolution(Operator & op,OpMethodBody & body)1399 void OperationFormat::genParserSuccessorResolution(Operator &op,
1400                                                    OpMethodBody &body) {
1401   // Check for the case where all successors were parsed.
1402   bool hasAllSuccessors = llvm::any_of(
1403       elements, [](auto &elt) { return isa<SuccessorsDirective>(elt.get()); });
1404   if (hasAllSuccessors) {
1405     body << "  result.addSuccessors(fullSuccessors);\n";
1406     return;
1407   }
1408 
1409   // Otherwise, handle each successor individually.
1410   for (const NamedSuccessor &successor : op.getSuccessors()) {
1411     if (successor.isVariadic())
1412       body << "  result.addSuccessors(" << successor.name << "Successors);\n";
1413     else
1414       body << "  result.addSuccessors(" << successor.name << "Successor);\n";
1415   }
1416 }
1417 
genParserVariadicSegmentResolution(Operator & op,OpMethodBody & body)1418 void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
1419                                                          OpMethodBody &body) {
1420   if (!allOperands &&
1421       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1422     body << "  result.addAttribute(\"operand_segment_sizes\", "
1423          << "parser.getBuilder().getI32VectorAttr({";
1424     auto interleaveFn = [&](const NamedTypeConstraint &operand) {
1425       // If the operand is variadic emit the parsed size.
1426       if (operand.isVariableLength())
1427         body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
1428       else
1429         body << "1";
1430     };
1431     llvm::interleaveComma(op.getOperands(), body, interleaveFn);
1432     body << "}));\n";
1433   }
1434 
1435   if (!allResultTypes &&
1436       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
1437     body << "  result.addAttribute(\"result_segment_sizes\", "
1438          << "parser.getBuilder().getI32VectorAttr({";
1439     auto interleaveFn = [&](const NamedTypeConstraint &result) {
1440       // If the result is variadic emit the parsed size.
1441       if (result.isVariableLength())
1442         body << "static_cast<int32_t>(" << result.name << "Types.size())";
1443       else
1444         body << "1";
1445     };
1446     llvm::interleaveComma(op.getResults(), body, interleaveFn);
1447     body << "}));\n";
1448   }
1449 }
1450 
//===----------------------------------------------------------------------===//
// PrinterGen
//===----------------------------------------------------------------------===//
1453 
/// The code snippet used to generate a printer call for a region of an
/// operation that has the SingleBlockImplicitTerminator trait.
///
/// The block terminator is printed only when it carries information the
/// parser could not recreate implicitly: a non-empty attribute dictionary,
/// operands, or results. An empty region is printed with the terminator
/// flag left at `true`.
///
/// This is an llvm::formatv template; `{{` is the escape for a literal `{`.
///
/// {0}: The name of the region.
const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
  {
    bool printTerminator = true;
    if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
      printTerminator = !term->getMutableAttrDict().empty() ||
                        term->getNumOperands() != 0 ||
                        term->getNumResults() != 0;
    }
    p.printRegion({0}, /*printEntryBlockArgs=*/true,
                  /*printBlockTerminators=*/printTerminator);
  }
)";
1470 
1471 /// Generate the printer for the 'attr-dict' directive.
genAttrDictPrinter(OperationFormat & fmt,Operator & op,OpMethodBody & body,bool withKeyword)1472 static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
1473                                OpMethodBody &body, bool withKeyword) {
1474   body << "  p.printOptionalAttrDict" << (withKeyword ? "WithKeyword" : "")
1475        << "(getAttrs(), /*elidedAttrs=*/{";
1476   // Elide the variadic segment size attributes if necessary.
1477   if (!fmt.allOperands &&
1478       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
1479     body << "\"operand_segment_sizes\", ";
1480   if (!fmt.allResultTypes &&
1481       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
1482     body << "\"result_segment_sizes\", ";
1483   llvm::interleaveComma(
1484       fmt.usedAttributes, body,
1485       [&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; });
1486   body << "});\n";
1487 }
1488 
1489 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a
1490 /// space should be emitted before this element. `lastWasPunctuation` is true if
1491 /// the previous element was a punctuation literal.
genLiteralPrinter(StringRef value,OpMethodBody & body,bool & shouldEmitSpace,bool & lastWasPunctuation)1492 static void genLiteralPrinter(StringRef value, OpMethodBody &body,
1493                               bool &shouldEmitSpace, bool &lastWasPunctuation) {
1494   body << "  p";
1495 
1496   // Don't insert a space for certain punctuation.
1497   auto shouldPrintSpaceBeforeLiteral = [&] {
1498     if (value.size() != 1 && value != "->")
1499       return true;
1500     if (lastWasPunctuation)
1501       return !StringRef(">)}],").contains(value.front());
1502     return !StringRef("<>(){}[],").contains(value.front());
1503   };
1504   if (shouldEmitSpace && shouldPrintSpaceBeforeLiteral())
1505     body << " << ' '";
1506   body << " << \"" << value << "\";\n";
1507 
1508   // Insert a space after certain literals.
1509   shouldEmitSpace =
1510       value.size() != 1 || !StringRef("<({[").contains(value.front());
1511   lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
1512 }
1513 
1514 /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
1515 /// are set to false.
genSpacePrinter(bool value,OpMethodBody & body,bool & shouldEmitSpace,bool & lastWasPunctuation)1516 static void genSpacePrinter(bool value, OpMethodBody &body,
1517                             bool &shouldEmitSpace, bool &lastWasPunctuation) {
1518   if (value) {
1519     body << "  p << ' ';\n";
1520     lastWasPunctuation = false;
1521   }
1522   shouldEmitSpace = false;
1523 }
1524 
1525 /// Generate the printer for a custom directive.
genCustomDirectivePrinter(CustomDirective * customDir,OpMethodBody & body)1526 static void genCustomDirectivePrinter(CustomDirective *customDir,
1527                                       OpMethodBody &body) {
1528   body << "  print" << customDir->getName() << "(p, *this";
1529   for (Element &param : customDir->getArguments()) {
1530     body << ", ";
1531     if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
1532       body << attr->getVar()->name << "Attr()";
1533 
1534     } else if (isa<AttrDictDirective>(&param)) {
1535       // Enforce the const-ness since getMutableAttrDict() returns a reference
1536       // into the Operations `attr` member.
1537       body << "(const "
1538               "MutableDictionaryAttr&)getOperation()->getMutableAttrDict()";
1539 
1540     } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
1541       body << operand->getVar()->name << "()";
1542 
1543     } else if (auto *region = dyn_cast<RegionVariable>(&param)) {
1544       body << region->getVar()->name << "()";
1545 
1546     } else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
1547       body << successor->getVar()->name << "()";
1548 
1549     } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
1550       auto *typeOperand = dir->getOperand();
1551       auto *operand = dyn_cast<OperandVariable>(typeOperand);
1552       auto *var = operand ? operand->getVar()
1553                           : cast<ResultVariable>(typeOperand)->getVar();
1554       if (var->isVariadic())
1555         body << var->name << "().getTypes()";
1556       else if (var->isOptional())
1557         body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
1558       else
1559         body << var->name << "().getType()";
1560     } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
1561       auto *typeOperand = dir->getOperand();
1562       auto *operand = dyn_cast<OperandVariable>(typeOperand);
1563       auto *var = operand ? operand->getVar()
1564                           : cast<ResultVariable>(typeOperand)->getVar();
1565       if (var->isVariadic())
1566         body << var->name << "().getTypes()";
1567       else if (var->isOptional())
1568         body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
1569       else
1570         body << var->name << "().getType()";
1571     } else {
1572       llvm_unreachable("unknown custom directive parameter");
1573     }
1574   }
1575 
1576   body << ");\n";
1577 }
1578 
1579 /// Generate the printer for a region with the given variable name.
genRegionPrinter(const Twine & regionName,OpMethodBody & body,bool hasImplicitTermTrait)1580 static void genRegionPrinter(const Twine &regionName, OpMethodBody &body,
1581                              bool hasImplicitTermTrait) {
1582   if (hasImplicitTermTrait)
1583     body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
1584                           regionName);
1585   else
1586     body << "  p.printRegion(" << regionName << ");\n";
1587 }
genVariadicRegionPrinter(const Twine & regionListName,OpMethodBody & body,bool hasImplicitTermTrait)1588 static void genVariadicRegionPrinter(const Twine &regionListName,
1589                                      OpMethodBody &body,
1590                                      bool hasImplicitTermTrait) {
1591   body << "    llvm::interleaveComma(" << regionListName
1592        << ", p, [&](::mlir::Region &region) {\n      ";
1593   genRegionPrinter("region", body, hasImplicitTermTrait);
1594   body << "    });\n";
1595 }
1596 
1597 /// Generate the C++ for an operand to a (*-)type directive.
genTypeOperandPrinter(Element * arg,OpMethodBody & body)1598 static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
1599   if (isa<OperandsDirective>(arg))
1600     return body << "getOperation()->getOperandTypes()";
1601   if (isa<ResultsDirective>(arg))
1602     return body << "getOperation()->getResultTypes()";
1603   auto *operand = dyn_cast<OperandVariable>(arg);
1604   auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
1605   if (var->isVariadic())
1606     return body << var->name << "().getTypes()";
1607   if (var->isOptional())
1608     return body << llvm::formatv(
1609                "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
1610                "::llvm::ArrayRef<::mlir::Type>())",
1611                var->name);
1612   return body << "::llvm::ArrayRef<::mlir::Type>(" << var->name
1613               << "().getType())";
1614 }
1615 
/// Generate the printer code for a single format element, recursing into the
/// children of optional groups.
///
/// `shouldEmitSpace` and `lastWasPunctuation` carry spacing state from one
/// element to the next. Each element reads them to decide whether to print a
/// leading space, then updates them for the element that follows.
void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
                                        Operator &op, bool &shouldEmitSpace,
                                        bool &lastWasPunctuation) {
  if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
    return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
                             lastWasPunctuation);

  if (SpaceElement *space = dyn_cast<SpaceElement>(element))
    return genSpacePrinter(space->getValue(), body, shouldEmitSpace,
                           lastWasPunctuation);

  // Emit an optional group.
  if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
    // Emit the check for the presence of the anchor element.
    Element *anchor = optional->getAnchor();
    if (auto *operand = dyn_cast<OperandVariable>(anchor)) {
      const NamedTypeConstraint *var = operand->getVar();
      // NOTE(review): if the anchor operand is neither optional nor variadic,
      // no `if (...) {` is emitted, yet the closing `}` below still is. This
      // presumably relies on the format parser rejecting such anchors; confirm.
      if (var->isOptional())
        body << "  if (" << var->name << "()) {\n";
      else if (var->isVariadic())
        body << "  if (!" << var->name << "().empty()) {\n";
    } else if (auto *region = dyn_cast<RegionVariable>(anchor)) {
      const NamedRegion *var = region->getVar();
      // TODO: Add a check for optional here when ODS supports it.
      body << "  if (!" << var->name << "().empty()) {\n";

    } else {
      // Any other anchor is an attribute; the group is printed when it is set.
      body << "  if (getAttr(\""
           << cast<AttributeVariable>(anchor)->getVar()->name << "\")) {\n";
    }

    // If the anchor is a unit attribute, we don't need to print it. When
    // parsing, we will add this attribute if this group is present.
    // The anchor is still printed when it is the group's first element,
    // presumably because that element is what the parser uses to detect the
    // group; confirm against the parser generator.
    // Note: this local `elements` shadows the OperationFormat member.
    auto elements = optional->getElements();
    Element *elidedAnchorElement = nullptr;
    auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
    if (anchorAttr && anchorAttr != &*elements.begin() &&
        anchorAttr->isUnitAttr()) {
      elidedAnchorElement = anchorAttr;
    }

    // Emit each of the elements.
    for (Element &childElement : elements) {
      if (&childElement != elidedAnchorElement) {
        genElementPrinter(&childElement, body, op, shouldEmitSpace,
                          lastWasPunctuation);
      }
    }
    body << "  }\n";
    return;
  }

  // Emit the attribute dictionary.
  if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
    genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword());
    lastWasPunctuation = false;
    return;
  }

  // Optionally insert a space before the next element. The AttrDict printer
  // already adds a space as necessary.
  // Every element handled below prints a space before itself unless the
  // previous element was punctuation that suppressed the space. Afterwards it
  // requests a space before the next element.
  if (shouldEmitSpace || !lastWasPunctuation)
    body << "  p << ' ';\n";
  lastWasPunctuation = false;
  shouldEmitSpace = true;

  if (auto *attr = dyn_cast<AttributeVariable>(element)) {
    const NamedAttribute *var = attr->getVar();

    // If we are formatting as an enum, symbolize the attribute as a string.
    // NOTE(review): the `*` for optional attributes presumably dereferences an
    // Optional returned by the generated getter; confirm.
    if (canFormatEnumAttr(var)) {
      Attribute baseAttr = var->attr.getBaseAttr();
      const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
      body << "  p << '\"' << " << enumAttr.getSymbolToStringFnName() << "("
           << (var->attr.isOptional() ? "*" : "") << var->name
           << "()) << '\"';\n";
      return;
    }

    // If we are formatting as a symbol name, handle it as a symbol name.
    if (shouldFormatSymbolNameAttr(var)) {
      body << "  p.printSymbolName(" << var->name << "Attr().getValue());\n";
      return;
    }

    // Elide the attribute type if it is buildable.
    if (attr->getTypeBuilder())
      body << "  p.printAttributeWithoutType(" << var->name << "Attr());\n";
    else
      body << "  p.printAttribute(" << var->name << "Attr());\n";
  } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
    // An optional operand prints nothing when it is absent.
    if (operand->getVar()->isOptional()) {
      body << "  if (::mlir::Value value = " << operand->getVar()->name
           << "())\n"
           << "    p << value;\n";
    } else {
      body << "  p << " << operand->getVar()->name << "();\n";
    }
  } else if (auto *region = dyn_cast<RegionVariable>(element)) {
    // `hasImplicitTermTrait` is an OperationFormat member.
    const NamedRegion *var = region->getVar();
    if (var->isVariadic()) {
      genVariadicRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
    } else {
      genRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
    }
  } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
    const NamedSuccessor *var = successor->getVar();
    if (var->isVariadic())
      body << "  ::llvm::interleaveComma(" << var->name << "(), p);\n";
    else
      body << "  p << " << var->name << "();\n";
  } else if (auto *dir = dyn_cast<CustomDirective>(element)) {
    genCustomDirectivePrinter(dir, body);
  } else if (isa<OperandsDirective>(element)) {
    body << "  p << getOperation()->getOperands();\n";
  } else if (isa<RegionsDirective>(element)) {
    genVariadicRegionPrinter("getOperation()->getRegions()", body,
                             hasImplicitTermTrait);
  } else if (isa<SuccessorsDirective>(element)) {
    body << "  ::llvm::interleaveComma(getOperation()->getSuccessors(), p);\n";
  } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
    body << "  p << ";
    genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
  } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
    // Printed exactly like a `type` directive.
    body << "  p << ";
    genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
  } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
    body << "  p.printFunctionalType(";
    genTypeOperandPrinter(dir->getInputs(), body) << ", ";
    genTypeOperandPrinter(dir->getResults(), body) << ");\n";
  } else {
    llvm_unreachable("unknown format element");
  }
}
1750 
genPrinter(Operator & op,OpClass & opClass)1751 void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
1752   auto *method =
1753       opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &p");
1754   auto &body = method->body();
1755 
1756   // Emit the operation name, trimming the prefix if this is the standard
1757   // dialect.
1758   body << "  p << \"";
1759   std::string opName = op.getOperationName();
1760   if (op.getDialectName() == "std")
1761     body << StringRef(opName).drop_front(4);
1762   else
1763     body << opName;
1764   body << "\";\n";
1765 
1766   // Flags for if we should emit a space, and if the last element was
1767   // punctuation.
1768   bool shouldEmitSpace = true, lastWasPunctuation = false;
1769   for (auto &element : elements)
1770     genElementPrinter(element.get(), body, op, shouldEmitSpace,
1771                       lastWasPunctuation);
1772 }
1773 
1774 //===----------------------------------------------------------------------===//
1775 // FormatLexer
1776 //===----------------------------------------------------------------------===//
1777 
namespace {
/// This class represents a specific token in the input format.
class Token {
public:
  enum Kind {
    // Markers.
    eof,
    error,

    // Tokens with no info.
    l_paren,
    r_paren,
    caret,
    comma,
    equal,
    less,
    greater,
    question,

    // Keywords.
    // `keyword_start` and `keyword_end` are sentinels bracketing the keyword
    // kinds; isKeyword() relies on every keyword being declared between them.
    keyword_start,
    kw_attr_dict,
    kw_attr_dict_w_keyword,
    kw_custom,
    kw_functional_type,
    kw_operands,
    kw_regions,
    kw_results,
    kw_successors,
    kw_type,
    kw_type_ref,
    keyword_end,

    // String valued tokens.
    identifier,
    literal,
    variable,
  };
  Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}

  /// Return the bytes that make up this token.
  StringRef getSpelling() const { return spelling; }

  /// Return the kind of this token.
  Kind getKind() const { return kind; }

  /// Return a location for this token.
  llvm::SMLoc getLoc() const {
    return llvm::SMLoc::getFromPointer(spelling.data());
  }

  /// Return if this token is a keyword.
  bool isKeyword() const { return kind > keyword_start && kind < keyword_end; }

private:
  /// Discriminator that indicates the kind of token this is.
  Kind kind;

  /// A reference to the entire token contents; this is always a pointer into
  /// a memory buffer owned by the source manager.
  StringRef spelling;
};

/// This class implements a simple lexer for operation assembly format strings.
class FormatLexer {
public:
  /// Lex the main buffer of `mgr`. `op` is used to attach an
  /// "in custom assembly format for this operation" note to diagnostics.
  FormatLexer(llvm::SourceMgr &mgr, Operator &op);

  /// Lex the next token and return it.
  Token lexToken();

  /// Emit an error to the lexer with the given location and message.
  /// Both overloads return a Token of kind `error`.
  Token emitError(llvm::SMLoc loc, const Twine &msg);
  Token emitError(const char *loc, const Twine &msg);

  /// Emit an error like emitError, followed by an extra note at the same
  /// location.
  Token emitErrorAndNote(llvm::SMLoc loc, const Twine &msg, const Twine &note);

private:
  /// Build a token of the given kind spanning [tokStart, curPtr).
  Token formToken(Token::Kind kind, const char *tokStart) {
    return Token(kind, StringRef(tokStart, curPtr - tokStart));
  }

  /// Return the next character in the stream.
  int getNextChar();

  /// Lex an identifier, literal, or variable.
  Token lexIdentifier(const char *tokStart);
  Token lexLiteral(const char *tokStart);
  Token lexVariable(const char *tokStart);

  llvm::SourceMgr &srcMgr;
  Operator &op;
  /// The format string being lexed (the main buffer of `srcMgr`).
  StringRef curBuffer;
  /// The next character to be lexed within `curBuffer`.
  const char *curPtr;
};
} // end anonymous namespace
1874 
FormatLexer(llvm::SourceMgr & mgr,Operator & op)1875 FormatLexer::FormatLexer(llvm::SourceMgr &mgr, Operator &op)
1876     : srcMgr(mgr), op(op) {
1877   curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
1878   curPtr = curBuffer.begin();
1879 }
1880 
emitError(llvm::SMLoc loc,const Twine & msg)1881 Token FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) {
1882   srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
1883   llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
1884                             "in custom assembly format for this operation");
1885   return formToken(Token::error, loc.getPointer());
1886 }
emitErrorAndNote(llvm::SMLoc loc,const Twine & msg,const Twine & note)1887 Token FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
1888                                     const Twine &note) {
1889   srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
1890   llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
1891                             "in custom assembly format for this operation");
1892   srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
1893   return formToken(Token::error, loc.getPointer());
1894 }
emitError(const char * loc,const Twine & msg)1895 Token FormatLexer::emitError(const char *loc, const Twine &msg) {
1896   return emitError(llvm::SMLoc::getFromPointer(loc), msg);
1897 }
1898 
getNextChar()1899 int FormatLexer::getNextChar() {
1900   char curChar = *curPtr++;
1901   switch (curChar) {
1902   default:
1903     return (unsigned char)curChar;
1904   case 0: {
1905     // A nul character in the stream is either the end of the current buffer or
1906     // a random nul in the file. Disambiguate that here.
1907     if (curPtr - 1 != curBuffer.end())
1908       return 0;
1909 
1910     // Otherwise, return end of file.
1911     --curPtr;
1912     return EOF;
1913   }
1914   case '\n':
1915   case '\r':
1916     // Handle the newline character by ignoring it and incrementing the line
1917     // count. However, be careful about 'dos style' files with \n\r in them.
1918     // Only treat a \n\r or \r\n as a single line.
1919     if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
1920       ++curPtr;
1921     return '\n';
1922   }
1923 }
1924 
lexToken()1925 Token FormatLexer::lexToken() {
1926   const char *tokStart = curPtr;
1927 
1928   // This always consumes at least one character.
1929   int curChar = getNextChar();
1930   switch (curChar) {
1931   default:
1932     // Handle identifiers: [a-zA-Z_]
1933     if (isalpha(curChar) || curChar == '_')
1934       return lexIdentifier(tokStart);
1935 
1936     // Unknown character, emit an error.
1937     return emitError(tokStart, "unexpected character");
1938   case EOF:
1939     // Return EOF denoting the end of lexing.
1940     return formToken(Token::eof, tokStart);
1941 
1942   // Lex punctuation.
1943   case '^':
1944     return formToken(Token::caret, tokStart);
1945   case ',':
1946     return formToken(Token::comma, tokStart);
1947   case '=':
1948     return formToken(Token::equal, tokStart);
1949   case '<':
1950     return formToken(Token::less, tokStart);
1951   case '>':
1952     return formToken(Token::greater, tokStart);
1953   case '?':
1954     return formToken(Token::question, tokStart);
1955   case '(':
1956     return formToken(Token::l_paren, tokStart);
1957   case ')':
1958     return formToken(Token::r_paren, tokStart);
1959 
1960   // Ignore whitespace characters.
1961   case 0:
1962   case ' ':
1963   case '\t':
1964   case '\n':
1965     return lexToken();
1966 
1967   case '`':
1968     return lexLiteral(tokStart);
1969   case '$':
1970     return lexVariable(tokStart);
1971   }
1972 }
1973 
lexLiteral(const char * tokStart)1974 Token FormatLexer::lexLiteral(const char *tokStart) {
1975   assert(curPtr[-1] == '`');
1976 
1977   // Lex a literal surrounded by ``.
1978   while (const char curChar = *curPtr++) {
1979     if (curChar == '`')
1980       return formToken(Token::literal, tokStart);
1981   }
1982   return emitError(curPtr - 1, "unexpected end of file in literal");
1983 }
1984 
lexVariable(const char * tokStart)1985 Token FormatLexer::lexVariable(const char *tokStart) {
1986   if (!isalpha(curPtr[0]) && curPtr[0] != '_')
1987     return emitError(curPtr - 1, "expected variable name");
1988 
1989   // Otherwise, consume the rest of the characters.
1990   while (isalnum(*curPtr) || *curPtr == '_')
1991     ++curPtr;
1992   return formToken(Token::variable, tokStart);
1993 }
1994 
lexIdentifier(const char * tokStart)1995 Token FormatLexer::lexIdentifier(const char *tokStart) {
1996   // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
1997   while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
1998     ++curPtr;
1999 
2000   // Check to see if this identifier is a keyword.
2001   StringRef str(tokStart, curPtr - tokStart);
2002   Token::Kind kind =
2003       StringSwitch<Token::Kind>(str)
2004           .Case("attr-dict", Token::kw_attr_dict)
2005           .Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword)
2006           .Case("custom", Token::kw_custom)
2007           .Case("functional-type", Token::kw_functional_type)
2008           .Case("operands", Token::kw_operands)
2009           .Case("regions", Token::kw_regions)
2010           .Case("results", Token::kw_results)
2011           .Case("successors", Token::kw_successors)
2012           .Case("type", Token::kw_type)
2013           .Case("type_ref", Token::kw_type_ref)
2014           .Default(Token::identifier);
2015   return Token(kind, str);
2016 }
2017 
2018 //===----------------------------------------------------------------------===//
2019 // FormatParser
2020 //===----------------------------------------------------------------------===//
2021 
2022 /// Function to find an element within the given range that has the same name as
2023 /// 'name'.
2024 template <typename RangeT>
findArg(RangeT && range,StringRef name)2025 static auto findArg(RangeT &&range, StringRef name) {
2026   auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
2027   return it != range.end() ? &*it : nullptr;
2028 }
2029 
namespace {
/// This class implements a parser for an instance of an operation assembly
/// format.
class FormatParser {
public:
  /// Member initialization order matters: `lexer` is declared (and therefore
  /// constructed) before `curToken`, which is primed with the first token.
  /// The seen-type bit vectors hold one bit per operand / result of `op`.
  FormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
      : lexer(mgr, op), curToken(lexer.lexToken()), fmt(format), op(op),
        seenOperandTypes(op.getNumOperands()),
        seenResultTypes(op.getNumResults()) {}

  /// Parse the operation assembly format.
  LogicalResult parse();

private:
  /// This struct represents a type resolution instance. It includes a specific
  /// type as well as an optional transformer to apply to that type in order to
  /// properly resolve the type of a variable.
  struct TypeResolutionInstance {
    ConstArgument resolver;
    Optional<StringRef> transformer;
  };

  /// An iterator over the elements of a format group.
  using ElementsIterT = llvm::pointee_iterator<
      std::vector<std::unique_ptr<Element>>::const_iterator>;

  /// Verify the state of operation attributes within the format.
  LogicalResult verifyAttributes(llvm::SMLoc loc);
  /// Verify the attribute elements at the back of the given stack of iterators.
  LogicalResult verifyAttributes(
      llvm::SMLoc loc,
      SmallVectorImpl<std::pair<ElementsIterT, ElementsIterT>> &iteratorStack);

  /// Verify the state of operation operands within the format.
  LogicalResult
  verifyOperands(llvm::SMLoc loc,
                 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);

  /// Verify the state of operation regions within the format.
  LogicalResult verifyRegions(llvm::SMLoc loc);

  /// Verify the state of operation results within the format.
  LogicalResult
  verifyResults(llvm::SMLoc loc,
                llvm::StringMap<TypeResolutionInstance> &variableTyResolver);

  /// Verify the state of operation successors within the format.
  LogicalResult verifySuccessors(llvm::SMLoc loc);

  /// Given the values of an `AllTypesMatch` trait, check for inferable type
  /// resolution.
  void handleAllTypesMatchConstraint(
      ArrayRef<StringRef> values,
      llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
  /// Check for inferable type resolution given all operands, and or results,
  /// have the same type. If 'includeResults' is true, the results also have the
  /// same type as all of the operands.
  void handleSameTypesConstraint(
      llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
      bool includeResults);
  /// Check for inferable type resolution based on another operand, result, or
  /// attribute.
  /// NOTE(review): `def` is taken by value, so every call copies an
  /// llvm::Record (parse() passes a `const llvm::Record &`). This should
  /// probably be `const llvm::Record &`, changed together with the
  /// out-of-line definition.
  void handleTypesMatchConstraint(
      llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
      llvm::Record def);

  /// Returns an argument or attribute with the given name that has been seen
  /// within the format.
  ConstArgument findSeenArg(StringRef name);

  /// Parse a specific element.
  LogicalResult parseElement(std::unique_ptr<Element> &element,
                             bool isTopLevel);
  LogicalResult parseVariable(std::unique_ptr<Element> &element,
                              bool isTopLevel);
  LogicalResult parseDirective(std::unique_ptr<Element> &element,
                               bool isTopLevel);
  LogicalResult parseLiteral(std::unique_ptr<Element> &element);
  LogicalResult parseOptional(std::unique_ptr<Element> &element,
                              bool isTopLevel);
  LogicalResult parseOptionalChildElement(
      std::vector<std::unique_ptr<Element>> &childElements,
      SmallPtrSetImpl<const NamedTypeConstraint *> &seenVariables,
      Optional<unsigned> &anchorIdx);

  /// Parse the various different directives.
  LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element,
                                       llvm::SMLoc loc, bool isTopLevel,
                                       bool withKeyword);
  LogicalResult parseCustomDirective(std::unique_ptr<Element> &element,
                                     llvm::SMLoc loc, bool isTopLevel);
  LogicalResult parseCustomDirectiveParameter(
      std::vector<std::unique_ptr<Element>> &parameters);
  LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
                                             Token tok, bool isTopLevel);
  LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
                                       llvm::SMLoc loc, bool isTopLevel);
  LogicalResult parseRegionsDirective(std::unique_ptr<Element> &element,
                                      llvm::SMLoc loc, bool isTopLevel);
  LogicalResult parseResultsDirective(std::unique_ptr<Element> &element,
                                      llvm::SMLoc loc, bool isTopLevel);
  LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
                                         llvm::SMLoc loc, bool isTopLevel);
  LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
                                   bool isTopLevel, bool isTypeRef = false);
  LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
                                          bool isTypeRef = false);

  //===--------------------------------------------------------------------===//
  // Lexer Utilities
  //===--------------------------------------------------------------------===//

  /// Advance the current lexer onto the next token.
  void consumeToken() {
    assert(curToken.getKind() != Token::eof &&
           curToken.getKind() != Token::error &&
           "shouldn't advance past EOF or errors");
    curToken = lexer.lexToken();
  }
  /// Consume the current token if it has the given kind. Otherwise, emit `msg`
  /// at the current token's location and fail.
  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
    if (curToken.getKind() != kind)
      return emitError(curToken.getLoc(), msg);
    consumeToken();
    return ::mlir::success();
  }
  /// Report an error through the lexer's diagnostics and return failure.
  LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
    lexer.emitError(loc, msg);
    return ::mlir::failure();
  }
  /// Report an error plus an extra note through the lexer; returns failure.
  LogicalResult emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
                                 const Twine &note) {
    lexer.emitErrorAndNote(loc, msg, note);
    return ::mlir::failure();
  }

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  FormatLexer lexer;
  Token curToken;
  OperationFormat &fmt;
  Operator &op;

  // The following are various bits of format state used for verification
  // during parsing.
  bool hasAttrDict = false;
  bool hasAllRegions = false, hasAllSuccessors = false;
  llvm::SmallBitVector seenOperandTypes, seenResultTypes;
  // Ordered set; parse() moves its contents into `fmt.usedAttributes`.
  llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
  llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
  llvm::DenseSet<const NamedRegion *> seenRegions;
  llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
  llvm::DenseSet<const NamedTypeConstraint *> optionalVariables;
};
} // end anonymous namespace
2186 
parse()2187 LogicalResult FormatParser::parse() {
2188   llvm::SMLoc loc = curToken.getLoc();
2189 
2190   // Parse each of the format elements into the main format.
2191   while (curToken.getKind() != Token::eof) {
2192     std::unique_ptr<Element> element;
2193     if (failed(parseElement(element, /*isTopLevel=*/true)))
2194       return ::mlir::failure();
2195     fmt.elements.push_back(std::move(element));
2196   }
2197 
2198   // Check that the attribute dictionary is in the format.
2199   if (!hasAttrDict)
2200     return emitError(loc, "'attr-dict' directive not found in "
2201                           "custom assembly format");
2202 
2203   // Check for any type traits that we can use for inferring types.
2204   llvm::StringMap<TypeResolutionInstance> variableTyResolver;
2205   for (const OpTrait &trait : op.getTraits()) {
2206     const llvm::Record &def = trait.getDef();
2207     if (def.isSubClassOf("AllTypesMatch")) {
2208       handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
2209                                     variableTyResolver);
2210     } else if (def.getName() == "SameTypeOperands") {
2211       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
2212     } else if (def.getName() == "SameOperandsAndResultType") {
2213       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
2214     } else if (def.isSubClassOf("TypesMatchWith")) {
2215       handleTypesMatchConstraint(variableTyResolver, def);
2216     }
2217   }
2218 
2219   // Verify the state of the various operation components.
2220   if (failed(verifyAttributes(loc)) ||
2221       failed(verifyResults(loc, variableTyResolver)) ||
2222       failed(verifyOperands(loc, variableTyResolver)) ||
2223       failed(verifyRegions(loc)) || failed(verifySuccessors(loc)))
2224     return ::mlir::failure();
2225 
2226   // Collect the set of used attributes in the format.
2227   fmt.usedAttributes = seenAttrs.takeVector();
2228   return ::mlir::success();
2229 }
2230 
verifyAttributes(llvm::SMLoc loc)2231 LogicalResult FormatParser::verifyAttributes(llvm::SMLoc loc) {
2232   // Check that there are no `:` literals after an attribute without a constant
2233   // type. The attribute grammar contains an optional trailing colon type, which
2234   // can lead to unexpected and generally unintended behavior. Given that, it is
2235   // better to just error out here instead.
2236   using ElementsIterT = llvm::pointee_iterator<
2237       std::vector<std::unique_ptr<Element>>::const_iterator>;
2238   SmallVector<std::pair<ElementsIterT, ElementsIterT>, 1> iteratorStack;
2239   iteratorStack.emplace_back(fmt.elements.begin(), fmt.elements.end());
2240   while (!iteratorStack.empty())
2241     if (failed(verifyAttributes(loc, iteratorStack)))
2242       return ::mlir::failure();
2243   return ::mlir::success();
2244 }
/// Verify the attribute elements at the back (innermost level) of the given
/// stack of [current, end) element iterators.
///
/// The back entry's iterator is advanced in place. On reaching an optional
/// group, the group's elements are pushed as a new stack level and success is
/// returned immediately, so the calling loop continues with that nested
/// level. The enclosing level's iterator has already been moved past the
/// group at that point. When a level is exhausted, it is popped.
///
/// For each attribute without a type builder, the next element at every
/// stack level is inspected, skipping spaces, attribute dictionaries, and
/// optional groups. If that element is a `:` literal, a format-ambiguity
/// error is emitted.
///
/// NOTE(review): every stack level is scanned, even when an inner level's
/// next element is something other than `:`. A `:` that follows an enclosing
/// optional group is therefore reported even when the attribute is followed
/// by other elements inside the group. This looks conservative; confirm it
/// is intended.
LogicalResult FormatParser::verifyAttributes(
    llvm::SMLoc loc,
    SmallVectorImpl<std::pair<ElementsIterT, ElementsIterT>> &iteratorStack) {
  auto &stackIt = iteratorStack.back();
  // `it` is a reference into the stack entry (so advancing it updates the
  // entry), while `e` is a plain copy of the end iterator.
  ElementsIterT &it = stackIt.first, e = stackIt.second;
  while (it != e) {
    Element *element = &*(it++);

    // Traverse into optional groups.
    if (auto *optional = dyn_cast<OptionalElement>(element)) {
      auto elements = optional->getElements();
      // emplace_back may reallocate and invalidate `stackIt`/`it`; neither is
      // used again because we return right away.
      iteratorStack.emplace_back(elements.begin(), elements.end());
      return ::mlir::success();
    }

    // We are checking for an attribute element followed by a `:`, so there is
    // no need to check the end.
    // Only the end of the outermost level is final; the end of a nested group
    // still has to be checked against what follows in the enclosing levels.
    if (it == e && iteratorStack.size() == 1)
      break;

    // Check for an attribute with a constant type builder, followed by a `:`.
    // Attributes that do have a type builder are not ambiguous and are skipped.
    auto *prevAttr = dyn_cast<AttributeVariable>(element);
    if (!prevAttr || prevAttr->getTypeBuilder())
      continue;

    // Check the next iterator within the stack for literal elements.
    // Levels are visited from the front of the stack (outermost) to the back
    // (innermost); each level's `first` is already past the current element.
    for (auto &nextItPair : iteratorStack) {
      ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second;
      for (; nextIt != nextE; ++nextIt) {
        // Skip any trailing spaces, attribute dictionaries, or optional groups.
        if (isa<SpaceElement>(*nextIt) || isa<AttrDictDirective>(*nextIt) ||
            isa<OptionalElement>(*nextIt))
          continue;

        // We are only interested in `:` literals.
        auto *literal = dyn_cast<LiteralElement>(&*nextIt);
        if (!literal || literal->getLiteral() != ":")
          break;

        // TODO: Use the location of the literal element itself.
        return emitError(
            loc, llvm::formatv("format ambiguity caused by `:` literal found "
                               "after attribute `{0}` which does not have "
                               "a buildable type",
                               prevAttr->getVar()->name));
      }
    }
  }
  // This level is fully verified; resume with the enclosing one.
  iteratorStack.pop_back();
  return ::mlir::success();
}
2297 
verifyOperands(llvm::SMLoc loc,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2298 LogicalResult FormatParser::verifyOperands(
2299     llvm::SMLoc loc,
2300     llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2301   // Check that all of the operands are within the format, and their types can
2302   // be inferred.
2303   auto &buildableTypes = fmt.buildableTypes;
2304   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
2305     NamedTypeConstraint &operand = op.getOperand(i);
2306 
2307     // Check that the operand itself is in the format.
2308     if (!fmt.allOperands && !seenOperands.count(&operand)) {
2309       return emitErrorAndNote(loc,
2310                               "operand #" + Twine(i) + ", named '" +
2311                                   operand.name + "', not found",
2312                               "suggest adding a '$" + operand.name +
2313                                   "' directive to the custom assembly format");
2314     }
2315 
2316     // Check that the operand type is in the format, or that it can be inferred.
2317     if (fmt.allOperandTypes || seenOperandTypes.test(i))
2318       continue;
2319 
2320     // Check to see if we can infer this type from another variable.
2321     auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
2322     if (varResolverIt != variableTyResolver.end()) {
2323       TypeResolutionInstance &resolver = varResolverIt->second;
2324       fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer);
2325       continue;
2326     }
2327 
2328     // Similarly to results, allow a custom builder for resolving the type if
2329     // we aren't using the 'operands' directive.
2330     Optional<StringRef> builder = operand.constraint.getBuilderCall();
2331     if (!builder || (fmt.allOperands && operand.isVariableLength())) {
2332       return emitErrorAndNote(
2333           loc,
2334           "type of operand #" + Twine(i) + ", named '" + operand.name +
2335               "', is not buildable and a buildable type cannot be inferred",
2336           "suggest adding a type constraint to the operation or adding a "
2337           "'type($" +
2338               operand.name + ")' directive to the " + "custom assembly format");
2339     }
2340     auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2341     fmt.operandTypes[i].setBuilderIdx(it.first->second);
2342   }
2343   return ::mlir::success();
2344 }
2345 
verifyRegions(llvm::SMLoc loc)2346 LogicalResult FormatParser::verifyRegions(llvm::SMLoc loc) {
2347   // Check that all of the regions are within the format.
2348   if (hasAllRegions)
2349     return ::mlir::success();
2350 
2351   for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
2352     const NamedRegion &region = op.getRegion(i);
2353     if (!seenRegions.count(&region)) {
2354       return emitErrorAndNote(loc,
2355                               "region #" + Twine(i) + ", named '" +
2356                                   region.name + "', not found",
2357                               "suggest adding a '$" + region.name +
2358                                   "' directive to the custom assembly format");
2359     }
2360   }
2361   return ::mlir::success();
2362 }
2363 
verifyResults(llvm::SMLoc loc,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2364 LogicalResult FormatParser::verifyResults(
2365     llvm::SMLoc loc,
2366     llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2367   // If we format all of the types together, there is nothing to check.
2368   if (fmt.allResultTypes)
2369     return ::mlir::success();
2370 
2371   // Check that all of the result types can be inferred.
2372   auto &buildableTypes = fmt.buildableTypes;
2373   for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
2374     if (seenResultTypes.test(i))
2375       continue;
2376 
2377     // Check to see if we can infer this type from another variable.
2378     auto varResolverIt = variableTyResolver.find(op.getResultName(i));
2379     if (varResolverIt != variableTyResolver.end()) {
2380       TypeResolutionInstance resolver = varResolverIt->second;
2381       fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
2382       continue;
2383     }
2384 
2385     // If the result is not variable length, allow for the case where the type
2386     // has a builder that we can use.
2387     NamedTypeConstraint &result = op.getResult(i);
2388     Optional<StringRef> builder = result.constraint.getBuilderCall();
2389     if (!builder || result.isVariableLength()) {
2390       return emitErrorAndNote(
2391           loc,
2392           "type of result #" + Twine(i) + ", named '" + result.name +
2393               "', is not buildable and a buildable type cannot be inferred",
2394           "suggest adding a type constraint to the operation or adding a "
2395           "'type($" +
2396               result.name + ")' directive to the " + "custom assembly format");
2397     }
2398     // Note in the format that this result uses the custom builder.
2399     auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2400     fmt.resultTypes[i].setBuilderIdx(it.first->second);
2401   }
2402   return ::mlir::success();
2403 }
2404 
verifySuccessors(llvm::SMLoc loc)2405 LogicalResult FormatParser::verifySuccessors(llvm::SMLoc loc) {
2406   // Check that all of the successors are within the format.
2407   if (hasAllSuccessors)
2408     return ::mlir::success();
2409 
2410   for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
2411     const NamedSuccessor &successor = op.getSuccessor(i);
2412     if (!seenSuccessors.count(&successor)) {
2413       return emitErrorAndNote(loc,
2414                               "successor #" + Twine(i) + ", named '" +
2415                                   successor.name + "', not found",
2416                               "suggest adding a '$" + successor.name +
2417                                   "' directive to the custom assembly format");
2418     }
2419   }
2420   return ::mlir::success();
2421 }
2422 
handleAllTypesMatchConstraint(ArrayRef<StringRef> values,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2423 void FormatParser::handleAllTypesMatchConstraint(
2424     ArrayRef<StringRef> values,
2425     llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2426   for (unsigned i = 0, e = values.size(); i != e; ++i) {
2427     // Check to see if this value matches a resolved operand or result type.
2428     ConstArgument arg = findSeenArg(values[i]);
2429     if (!arg)
2430       continue;
2431 
2432     // Mark this value as the type resolver for the other variables.
2433     for (unsigned j = 0; j != i; ++j)
2434       variableTyResolver[values[j]] = {arg, llvm::None};
2435     for (unsigned j = i + 1; j != e; ++j)
2436       variableTyResolver[values[j]] = {arg, llvm::None};
2437   }
2438 }
2439 
handleSameTypesConstraint(llvm::StringMap<TypeResolutionInstance> & variableTyResolver,bool includeResults)2440 void FormatParser::handleSameTypesConstraint(
2441     llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2442     bool includeResults) {
2443   const NamedTypeConstraint *resolver = nullptr;
2444   int resolvedIt = -1;
2445 
2446   // Check to see if there is an operand or result to use for the resolution.
2447   if ((resolvedIt = seenOperandTypes.find_first()) != -1)
2448     resolver = &op.getOperand(resolvedIt);
2449   else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1)
2450     resolver = &op.getResult(resolvedIt);
2451   else
2452     return;
2453 
2454   // Set the resolvers for each operand and result.
2455   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
2456     if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty())
2457       variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None};
2458   if (includeResults) {
2459     for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
2460       if (!seenResultTypes.test(i) && !op.getResultName(i).empty())
2461         variableTyResolver[op.getResultName(i)] = {resolver, llvm::None};
2462   }
2463 }
2464 
handleTypesMatchConstraint(llvm::StringMap<TypeResolutionInstance> & variableTyResolver,llvm::Record def)2465 void FormatParser::handleTypesMatchConstraint(
2466     llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2467     llvm::Record def) {
2468   StringRef lhsName = def.getValueAsString("lhs");
2469   StringRef rhsName = def.getValueAsString("rhs");
2470   StringRef transformer = def.getValueAsString("transformer");
2471   if (ConstArgument arg = findSeenArg(lhsName))
2472     variableTyResolver[rhsName] = {arg, transformer};
2473 }
2474 
findSeenArg(StringRef name)2475 ConstArgument FormatParser::findSeenArg(StringRef name) {
2476   if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
2477     return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
2478   if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
2479     return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
2480   if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
2481     return seenAttrs.count(attr) ? attr : nullptr;
2482   return nullptr;
2483 }
2484 
parseElement(std::unique_ptr<Element> & element,bool isTopLevel)2485 LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
2486                                          bool isTopLevel) {
2487   // Directives.
2488   if (curToken.isKeyword())
2489     return parseDirective(element, isTopLevel);
2490   // Literals.
2491   if (curToken.getKind() == Token::literal)
2492     return parseLiteral(element);
2493   // Optionals.
2494   if (curToken.getKind() == Token::l_paren)
2495     return parseOptional(element, isTopLevel);
2496   // Variables.
2497   if (curToken.getKind() == Token::variable)
2498     return parseVariable(element, isTopLevel);
2499   return emitError(curToken.getLoc(),
2500                    "expected directive, literal, variable, or optional group");
2501 }
2502 
parseVariable(std::unique_ptr<Element> & element,bool isTopLevel)2503 LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
2504                                           bool isTopLevel) {
2505   Token varTok = curToken;
2506   consumeToken();
2507 
2508   StringRef name = varTok.getSpelling().drop_front();
2509   llvm::SMLoc loc = varTok.getLoc();
2510 
2511   // Check that the parsed argument is something actually registered on the
2512   // op.
2513   /// Attributes
2514   if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
2515     if (isTopLevel && !seenAttrs.insert(attr))
2516       return emitError(loc, "attribute '" + name + "' is already bound");
2517     element = std::make_unique<AttributeVariable>(attr);
2518     return ::mlir::success();
2519   }
2520   /// Operands
2521   if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
2522     if (isTopLevel) {
2523       if (fmt.allOperands || !seenOperands.insert(operand).second)
2524         return emitError(loc, "operand '" + name + "' is already bound");
2525     }
2526     element = std::make_unique<OperandVariable>(operand);
2527     return ::mlir::success();
2528   }
2529   /// Regions
2530   if (const NamedRegion *region = findArg(op.getRegions(), name)) {
2531     if (!isTopLevel)
2532       return emitError(loc, "regions can only be used at the top level");
2533     if (hasAllRegions || !seenRegions.insert(region).second)
2534       return emitError(loc, "region '" + name + "' is already bound");
2535     element = std::make_unique<RegionVariable>(region);
2536     return ::mlir::success();
2537   }
2538   /// Results.
2539   if (const auto *result = findArg(op.getResults(), name)) {
2540     if (isTopLevel)
2541       return emitError(loc, "results can not be used at the top level");
2542     element = std::make_unique<ResultVariable>(result);
2543     return ::mlir::success();
2544   }
2545   /// Successors.
2546   if (const auto *successor = findArg(op.getSuccessors(), name)) {
2547     if (!isTopLevel)
2548       return emitError(loc, "successors can only be used at the top level");
2549     if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
2550       return emitError(loc, "successor '" + name + "' is already bound");
2551     element = std::make_unique<SuccessorVariable>(successor);
2552     return ::mlir::success();
2553   }
2554   return emitError(loc, "expected variable to refer to an argument, region, "
2555                         "result, or successor");
2556 }
2557 
parseDirective(std::unique_ptr<Element> & element,bool isTopLevel)2558 LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
2559                                            bool isTopLevel) {
2560   Token dirTok = curToken;
2561   consumeToken();
2562 
2563   switch (dirTok.getKind()) {
2564   case Token::kw_attr_dict:
2565     return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel,
2566                                   /*withKeyword=*/false);
2567   case Token::kw_attr_dict_w_keyword:
2568     return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel,
2569                                   /*withKeyword=*/true);
2570   case Token::kw_custom:
2571     return parseCustomDirective(element, dirTok.getLoc(), isTopLevel);
2572   case Token::kw_functional_type:
2573     return parseFunctionalTypeDirective(element, dirTok, isTopLevel);
2574   case Token::kw_operands:
2575     return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel);
2576   case Token::kw_regions:
2577     return parseRegionsDirective(element, dirTok.getLoc(), isTopLevel);
2578   case Token::kw_results:
2579     return parseResultsDirective(element, dirTok.getLoc(), isTopLevel);
2580   case Token::kw_successors:
2581     return parseSuccessorsDirective(element, dirTok.getLoc(), isTopLevel);
2582   case Token::kw_type_ref:
2583     return parseTypeDirective(element, dirTok, isTopLevel, /*isTypeRef=*/true);
2584   case Token::kw_type:
2585     return parseTypeDirective(element, dirTok, isTopLevel);
2586 
2587   default:
2588     llvm_unreachable("unknown directive token");
2589   }
2590 }
2591 
parseLiteral(std::unique_ptr<Element> & element)2592 LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element) {
2593   Token literalTok = curToken;
2594   consumeToken();
2595 
2596   StringRef value = literalTok.getSpelling().drop_front().drop_back();
2597 
2598   // The parsed literal is a space element (`` or ` `).
2599   if (value.empty() || (value.size() == 1 && value.front() == ' ')) {
2600     element = std::make_unique<SpaceElement>(!value.empty());
2601     return ::mlir::success();
2602   }
2603 
2604   // Check that the parsed literal is valid.
2605   if (!LiteralElement::isValidLiteral(value))
2606     return emitError(literalTok.getLoc(), "expected valid literal");
2607 
2608   element = std::make_unique<LiteralElement>(value);
2609   return ::mlir::success();
2610 }
2611 
parseOptional(std::unique_ptr<Element> & element,bool isTopLevel)2612 LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
2613                                           bool isTopLevel) {
2614   llvm::SMLoc curLoc = curToken.getLoc();
2615   if (!isTopLevel)
2616     return emitError(curLoc, "optional groups can only be used as top-level "
2617                              "elements");
2618   consumeToken();
2619 
2620   // Parse the child elements for this optional group.
2621   std::vector<std::unique_ptr<Element>> elements;
2622   SmallPtrSet<const NamedTypeConstraint *, 8> seenVariables;
2623   Optional<unsigned> anchorIdx;
2624   do {
2625     if (failed(parseOptionalChildElement(elements, seenVariables, anchorIdx)))
2626       return ::mlir::failure();
2627   } while (curToken.getKind() != Token::r_paren);
2628   consumeToken();
2629   if (failed(parseToken(Token::question, "expected '?' after optional group")))
2630     return ::mlir::failure();
2631 
2632   // The optional group is required to have an anchor.
2633   if (!anchorIdx)
2634     return emitError(curLoc, "optional group specified no anchor element");
2635 
2636   // The first parsable element of the group must be able to be parsed in an
2637   // optional fashion.
2638   auto parseBegin = llvm::find_if_not(
2639       elements, [](auto &element) { return isa<SpaceElement>(element.get()); });
2640   Element *firstElement = parseBegin->get();
2641   if (!isa<AttributeVariable>(firstElement) &&
2642       !isa<LiteralElement>(firstElement) &&
2643       !isa<OperandVariable>(firstElement) && !isa<RegionVariable>(firstElement))
2644     return emitError(curLoc,
2645                      "first parsable element of an operand group must be "
2646                      "an attribute, literal, operand, or region");
2647 
2648   // After parsing all of the elements, ensure that all type directives refer
2649   // only to elements within the group.
2650   auto checkTypeOperand = [&](Element *typeEle) {
2651     auto *opVar = dyn_cast<OperandVariable>(typeEle);
2652     const NamedTypeConstraint *var = opVar ? opVar->getVar() : nullptr;
2653     if (!seenVariables.count(var))
2654       return emitError(curLoc, "type directive can only refer to variables "
2655                                "within the optional group");
2656     return ::mlir::success();
2657   };
2658   for (auto &ele : elements) {
2659     if (auto *typeEle = dyn_cast<TypeRefDirective>(ele.get())) {
2660       if (failed(checkTypeOperand(typeEle->getOperand())))
2661         return failure();
2662     } else if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
2663       if (failed(checkTypeOperand(typeEle->getOperand())))
2664         return ::mlir::failure();
2665     } else if (auto *typeEle = dyn_cast<FunctionalTypeDirective>(ele.get())) {
2666       if (failed(checkTypeOperand(typeEle->getInputs())) ||
2667           failed(checkTypeOperand(typeEle->getResults())))
2668         return ::mlir::failure();
2669     }
2670   }
2671 
2672   optionalVariables.insert(seenVariables.begin(), seenVariables.end());
2673   auto parseStart = parseBegin - elements.begin();
2674   element = std::make_unique<OptionalElement>(std::move(elements), *anchorIdx,
2675                                               parseStart);
2676   return ::mlir::success();
2677 }
2678 
/// Parse one child element of an optional group and append it to
/// `childElements`.
///
/// A caret token after the element marks it as the group's anchor. At most
/// one anchor is allowed, and its index is stored in `anchorIdx`. Operand
/// variables parsed here are added to `seenVariables`, which parseOptional
/// uses to check that type directives in the group refer only to them.
///
/// Returns failure if the element may not appear in an optional group, or if
/// it is marked as the anchor but cannot serve as one.
LogicalResult FormatParser::parseOptionalChildElement(
    std::vector<std::unique_ptr<Element>> &childElements,
    SmallPtrSetImpl<const NamedTypeConstraint *> &seenVariables,
    Optional<unsigned> &anchorIdx) {
  llvm::SMLoc childLoc = curToken.getLoc();
  childElements.push_back({});
  // Children are parsed as top-level elements, so variables they reference
  // are bound just like top-level references, and nested optional groups are
  // accepted by parseOptional.
  if (failed(parseElement(childElements.back(), /*isTopLevel=*/true)))
    return ::mlir::failure();

  // Check to see if this element is the anchor of the optional group.
  bool isAnchor = curToken.getKind() == Token::caret;
  if (isAnchor) {
    if (anchorIdx)
      return emitError(childLoc, "only one element can be marked as the anchor "
                                 "of an optional group");
    anchorIdx = childElements.size() - 1;
    consumeToken();
  }

  return TypeSwitch<Element *, LogicalResult>(childElements.back().get())
      // All attributes can be within the optional group, but only optional
      // attributes can be the anchor.
      .Case([&](AttributeVariable *attrEle) {
        if (isAnchor && !attrEle->getVar()->attr.isOptional())
          return emitError(childLoc, "only optional attributes can be used to "
                                     "anchor an optional group");
        return ::mlir::success();
      })
      // Only optional-like(i.e. variadic) operands can be within an optional
      // group.
      .Case<OperandVariable>([&](OperandVariable *ele) {
        if (!ele->getVar()->isVariableLength())
          return emitError(childLoc, "only variable length operands can be "
                                     "used within an optional group");
        seenVariables.insert(ele->getVar());
        return ::mlir::success();
      })
      .Case<RegionVariable>([&](RegionVariable *) {
        // TODO: When ODS has proper support for marking "optional" regions, add
        // a check here.
        return ::mlir::success();
      })
      // Literals, spaces, custom directives, and type directives may be used,
      // but they can't anchor the group.
      // (Nested optional groups fall into this case as well.)
      .Case<LiteralElement, SpaceElement, CustomDirective,
            FunctionalTypeDirective, OptionalElement, TypeRefDirective,
            TypeDirective>([&](Element *) {
        if (isAnchor)
          return emitError(childLoc, "only variables can be used to anchor "
                                     "an optional group");
        return ::mlir::success();
      })
      // Anything else (e.g. results, successors, attr-dict) is rejected.
      .Default([&](Element *) {
        return emitError(childLoc, "only literals, types, and variables can be "
                                   "used within an optional group");
      });
}
2736 
2737 LogicalResult
parseAttrDictDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel,bool withKeyword)2738 FormatParser::parseAttrDictDirective(std::unique_ptr<Element> &element,
2739                                      llvm::SMLoc loc, bool isTopLevel,
2740                                      bool withKeyword) {
2741   if (!isTopLevel)
2742     return emitError(loc, "'attr-dict' directive can only be used as a "
2743                           "top-level directive");
2744   if (hasAttrDict)
2745     return emitError(loc, "'attr-dict' directive has already been seen");
2746 
2747   hasAttrDict = true;
2748   element = std::make_unique<AttrDictDirective>(withKeyword);
2749   return ::mlir::success();
2750 }
2751 
2752 LogicalResult
parseCustomDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)2753 FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
2754                                    llvm::SMLoc loc, bool isTopLevel) {
2755   llvm::SMLoc curLoc = curToken.getLoc();
2756 
2757   // Parse the custom directive name.
2758   if (failed(
2759           parseToken(Token::less, "expected '<' before custom directive name")))
2760     return ::mlir::failure();
2761 
2762   Token nameTok = curToken;
2763   if (failed(parseToken(Token::identifier,
2764                         "expected custom directive name identifier")) ||
2765       failed(parseToken(Token::greater,
2766                         "expected '>' after custom directive name")) ||
2767       failed(parseToken(Token::l_paren,
2768                         "expected '(' before custom directive parameters")))
2769     return ::mlir::failure();
2770 
2771   // Parse the child elements for this optional group.=
2772   std::vector<std::unique_ptr<Element>> elements;
2773   do {
2774     if (failed(parseCustomDirectiveParameter(elements)))
2775       return ::mlir::failure();
2776     if (curToken.getKind() != Token::comma)
2777       break;
2778     consumeToken();
2779   } while (true);
2780 
2781   if (failed(parseToken(Token::r_paren,
2782                         "expected ')' after custom directive parameters")))
2783     return ::mlir::failure();
2784 
2785   // After parsing all of the elements, ensure that all type directives refer
2786   // only to variables.
2787   for (auto &ele : elements) {
2788     if (auto *typeEle = dyn_cast<TypeRefDirective>(ele.get())) {
2789       if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
2790         return emitError(curLoc,
2791                          "type_ref directives within a custom directive "
2792                          "may only refer to variables");
2793       }
2794     }
2795     if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
2796       if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
2797         return emitError(curLoc, "type directives within a custom directive "
2798                                  "may only refer to variables");
2799       }
2800     }
2801   }
2802 
2803   element = std::make_unique<CustomDirective>(nameTok.getSpelling(),
2804                                               std::move(elements));
2805   return ::mlir::success();
2806 }
2807 
parseCustomDirectiveParameter(std::vector<std::unique_ptr<Element>> & parameters)2808 LogicalResult FormatParser::parseCustomDirectiveParameter(
2809     std::vector<std::unique_ptr<Element>> &parameters) {
2810   llvm::SMLoc childLoc = curToken.getLoc();
2811   parameters.push_back({});
2812   if (failed(parseElement(parameters.back(), /*isTopLevel=*/true)))
2813     return ::mlir::failure();
2814 
2815   // Verify that the element can be placed within a custom directive.
2816   if (!isa<TypeRefDirective, TypeDirective, AttrDictDirective,
2817            AttributeVariable, OperandVariable, RegionVariable,
2818            SuccessorVariable>(parameters.back().get())) {
2819     return emitError(childLoc, "only variables and types may be used as "
2820                                "parameters to a custom directive");
2821   }
2822   return ::mlir::success();
2823 }
2824 
2825 LogicalResult
parseFunctionalTypeDirective(std::unique_ptr<Element> & element,Token tok,bool isTopLevel)2826 FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
2827                                            Token tok, bool isTopLevel) {
2828   llvm::SMLoc loc = tok.getLoc();
2829   if (!isTopLevel)
2830     return emitError(
2831         loc, "'functional-type' is only valid as a top-level directive");
2832 
2833   // Parse the main operand.
2834   std::unique_ptr<Element> inputs, results;
2835   if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
2836       failed(parseTypeDirectiveOperand(inputs)) ||
2837       failed(parseToken(Token::comma, "expected ',' after inputs argument")) ||
2838       failed(parseTypeDirectiveOperand(results)) ||
2839       failed(parseToken(Token::r_paren, "expected ')' after argument list")))
2840     return ::mlir::failure();
2841   element = std::make_unique<FunctionalTypeDirective>(std::move(inputs),
2842                                                       std::move(results));
2843   return ::mlir::success();
2844 }
2845 
2846 LogicalResult
parseOperandsDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)2847 FormatParser::parseOperandsDirective(std::unique_ptr<Element> &element,
2848                                      llvm::SMLoc loc, bool isTopLevel) {
2849   if (isTopLevel) {
2850     if (fmt.allOperands || !seenOperands.empty())
2851       return emitError(loc, "'operands' directive creates overlap in format");
2852     fmt.allOperands = true;
2853   }
2854   element = std::make_unique<OperandsDirective>();
2855   return ::mlir::success();
2856 }
2857 
2858 LogicalResult
parseRegionsDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)2859 FormatParser::parseRegionsDirective(std::unique_ptr<Element> &element,
2860                                     llvm::SMLoc loc, bool isTopLevel) {
2861   if (!isTopLevel)
2862     return emitError(loc, "'regions' is only valid as a top-level directive");
2863   if (hasAllRegions || !seenRegions.empty())
2864     return emitError(loc, "'regions' directive creates overlap in format");
2865   hasAllRegions = true;
2866   element = std::make_unique<RegionsDirective>();
2867   return ::mlir::success();
2868 }
2869 
2870 LogicalResult
parseResultsDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)2871 FormatParser::parseResultsDirective(std::unique_ptr<Element> &element,
2872                                     llvm::SMLoc loc, bool isTopLevel) {
2873   if (isTopLevel)
2874     return emitError(loc, "'results' directive can not be used as a "
2875                           "top-level directive");
2876   element = std::make_unique<ResultsDirective>();
2877   return ::mlir::success();
2878 }
2879 
2880 LogicalResult
parseSuccessorsDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)2881 FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
2882                                        llvm::SMLoc loc, bool isTopLevel) {
2883   if (!isTopLevel)
2884     return emitError(loc,
2885                      "'successors' is only valid as a top-level directive");
2886   if (hasAllSuccessors || !seenSuccessors.empty())
2887     return emitError(loc, "'successors' directive creates overlap in format");
2888   hasAllSuccessors = true;
2889   element = std::make_unique<SuccessorsDirective>();
2890   return ::mlir::success();
2891 }
2892 
2893 LogicalResult
parseTypeDirective(std::unique_ptr<Element> & element,Token tok,bool isTopLevel,bool isTypeRef)2894 FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
2895                                  bool isTopLevel, bool isTypeRef) {
2896   llvm::SMLoc loc = tok.getLoc();
2897   if (!isTopLevel)
2898     return emitError(loc, "'type' is only valid as a top-level directive");
2899 
2900   std::unique_ptr<Element> operand;
2901   if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
2902       failed(parseTypeDirectiveOperand(operand, isTypeRef)) ||
2903       failed(parseToken(Token::r_paren, "expected ')' after argument list")))
2904     return ::mlir::failure();
2905   if (isTypeRef)
2906     element = std::make_unique<TypeRefDirective>(std::move(operand));
2907   else
2908     element = std::make_unique<TypeDirective>(std::move(operand));
2909   return ::mlir::success();
2910 }
2911 
2912 LogicalResult
parseTypeDirectiveOperand(std::unique_ptr<Element> & element,bool isTypeRef)2913 FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
2914                                         bool isTypeRef) {
2915   llvm::SMLoc loc = curToken.getLoc();
2916   if (failed(parseElement(element, /*isTopLevel=*/false)))
2917     return ::mlir::failure();
2918   if (isa<LiteralElement>(element.get()))
2919     return emitError(
2920         loc, "'type' directive operand expects variable or directive operand");
2921 
2922   if (auto *var = dyn_cast<OperandVariable>(element.get())) {
2923     unsigned opIdx = var->getVar() - op.operand_begin();
2924     if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
2925       return emitError(loc, "'type' of '" + var->getVar()->name +
2926                                 "' is already bound");
2927     if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
2928       return emitError(loc, "'type_ref' of '" + var->getVar()->name +
2929                                 "' is not bound by a prior 'type' directive");
2930     seenOperandTypes.set(opIdx);
2931   } else if (auto *var = dyn_cast<ResultVariable>(element.get())) {
2932     unsigned resIdx = var->getVar() - op.result_begin();
2933     if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
2934       return emitError(loc, "'type' of '" + var->getVar()->name +
2935                                 "' is already bound");
2936     if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.test(resIdx)))
2937       return emitError(loc, "'type_ref' of '" + var->getVar()->name +
2938                                 "' is not bound by a prior 'type' directive");
2939     seenResultTypes.set(resIdx);
2940   } else if (isa<OperandsDirective>(&*element)) {
2941     if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.any()))
2942       return emitError(loc, "'operands' 'type' is already bound");
2943     if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.all()))
2944       return emitError(
2945           loc,
2946           "'operands' 'type_ref' is not bound by a prior 'type' directive");
2947     fmt.allOperandTypes = true;
2948   } else if (isa<ResultsDirective>(&*element)) {
2949     if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.any()))
2950       return emitError(loc, "'results' 'type' is already bound");
2951     if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.all()))
2952       return emitError(
2953           loc, "'results' 'type_ref' is not bound by a prior 'type' directive");
2954     fmt.allResultTypes = true;
2955   } else {
2956     return emitError(loc, "invalid argument to 'type' directive");
2957   }
2958   return ::mlir::success();
2959 }
2960 
2961 //===----------------------------------------------------------------------===//
2962 // Interface
2963 //===----------------------------------------------------------------------===//
2964 
generateOpFormat(const Operator & constOp,OpClass & opClass)2965 void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) {
2966   // TODO: Operator doesn't expose all necessary functionality via
2967   // the const interface.
2968   Operator &op = const_cast<Operator &>(constOp);
2969   if (!op.hasAssemblyFormat())
2970     return;
2971 
2972   // Parse the format description.
2973   llvm::SourceMgr mgr;
2974   mgr.AddNewSourceBuffer(
2975       llvm::MemoryBuffer::getMemBuffer(op.getAssemblyFormat()), llvm::SMLoc());
2976   OperationFormat format(op);
2977   if (failed(FormatParser(mgr, format, op).parse())) {
2978     // Exit the process if format errors are treated as fatal.
2979     if (formatErrorIsFatal) {
2980       // Invoke the interrupt handlers to run the file cleanup handlers.
2981       llvm::sys::RunInterruptHandlers();
2982       std::exit(1);
2983     }
2984     return;
2985   }
2986 
2987   // Generate the printer and parser based on the parsed format.
2988   format.genParser(op, opClass);
2989   format.genPrinter(op, opClass);
2990 }
2991