1 //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "OpFormatGen.h"
10 #include "mlir/Support/LogicalResult.h"
11 #include "mlir/TableGen/Format.h"
12 #include "mlir/TableGen/GenInfo.h"
13 #include "mlir/TableGen/Interfaces.h"
14 #include "mlir/TableGen/OpClass.h"
15 #include "mlir/TableGen/OpTrait.h"
16 #include "mlir/TableGen/Operator.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Signals.h"
25 #include "llvm/TableGen/Error.h"
26 #include "llvm/TableGen/Record.h"
27
28 #define DEBUG_TYPE "mlir-tblgen-opformatgen"
29
30 using namespace mlir;
31 using namespace mlir::tblgen;
32
/// Boolean command-line option `asmformat-error-is-fatal`, initialized to
/// true. Per its description, it selects whether a format parsing failure is
/// emitted as a fatal error.
static llvm::cl::opt<bool> formatErrorIsFatal(
    "asmformat-error-is-fatal",
    llvm::cl::desc("Emit a fatal error if format parsing fails"),
    llvm::cl::init(true));
37
38 //===----------------------------------------------------------------------===//
39 // Element
40 //===----------------------------------------------------------------------===//
41
42 namespace {
/// Common base of all format elements. Each instance records a `Kind`
/// discriminator that the derived classes test in their `classof` hooks.
class Element {
public:
  /// Discriminator naming the concrete element type.
  enum class Kind {
    // Directive elements.
    AttrDictDirective,
    CustomDirective,
    FunctionalTypeDirective,
    OperandsDirective,
    RegionsDirective,
    ResultsDirective,
    SuccessorsDirective,
    TypeDirective,
    TypeRefDirective,

    // A literal element.
    Literal,

    // An element that prints or omits a space; the parser ignores it.
    Space,

    // Variable elements.
    AttributeVariable,
    OperandVariable,
    RegionVariable,
    ResultVariable,
    SuccessorVariable,

    // An optional group element.
    Optional,
  };

  /// Build an element tagged with `kind`.
  Element(Kind kind) : kind{kind} {}
  virtual ~Element() = default;

  /// The discriminator this element was constructed with.
  Kind getKind() const { return kind; }

private:
  /// Discriminator for the concrete element type.
  Kind kind;
};
84 } // namespace
85
86 //===----------------------------------------------------------------------===//
87 // VariableElement
88
89 namespace {
90 /// This class represents an instance of an variable element. A variable refers
91 /// to something registered on the operation itself, e.g. an argument, result,
92 /// etc.
93 template <typename VarT, Element::Kind kindVal>
94 class VariableElement : public Element {
95 public:
VariableElement(const VarT * var)96 VariableElement(const VarT *var) : Element(kindVal), var(var) {}
classof(const Element * element)97 static bool classof(const Element *element) {
98 return element->getKind() == kindVal;
99 }
getVar()100 const VarT *getVar() { return var; }
101
102 protected:
103 const VarT *var;
104 };
105
106 /// This class represents a variable that refers to an attribute argument.
107 struct AttributeVariable
108 : public VariableElement<NamedAttribute, Element::Kind::AttributeVariable> {
109 using VariableElement<NamedAttribute,
110 Element::Kind::AttributeVariable>::VariableElement;
111
112 /// Return the constant builder call for the type of this attribute, or None
113 /// if it doesn't have one.
getTypeBuilder__anon81548fc80211::AttributeVariable114 Optional<StringRef> getTypeBuilder() const {
115 Optional<Type> attrType = var->attr.getValueType();
116 return attrType ? attrType->getBuilderCall() : llvm::None;
117 }
118
119 /// Return if this attribute refers to a UnitAttr.
isUnitAttr__anon81548fc80211::AttributeVariable120 bool isUnitAttr() const {
121 return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr";
122 }
123 };
124
/// A variable element that refers to an operand argument. It wraps a
/// NamedTypeConstraint and is tagged with Element::Kind::OperandVariable.
using OperandVariable =
    VariableElement<NamedTypeConstraint, Element::Kind::OperandVariable>;

/// A variable element that refers to a region. It wraps a NamedRegion and is
/// tagged with Element::Kind::RegionVariable.
using RegionVariable =
    VariableElement<NamedRegion, Element::Kind::RegionVariable>;

/// A variable element that refers to a result. It wraps a NamedTypeConstraint
/// (like OperandVariable) but is tagged with Element::Kind::ResultVariable.
using ResultVariable =
    VariableElement<NamedTypeConstraint, Element::Kind::ResultVariable>;

/// A variable element that refers to a successor. It wraps a NamedSuccessor
/// and is tagged with Element::Kind::SuccessorVariable.
using SuccessorVariable =
    VariableElement<NamedSuccessor, Element::Kind::SuccessorVariable>;
140 } // end anonymous namespace
141
142 //===----------------------------------------------------------------------===//
143 // DirectiveElement
144
145 namespace {
146 /// This class implements single kind directives.
147 template <Element::Kind type>
148 class DirectiveElement : public Element {
149 public:
DirectiveElement()150 DirectiveElement() : Element(type){};
classof(const Element * ele)151 static bool classof(const Element *ele) { return ele->getKind() == type; }
152 };
153 /// This class represents the `operands` directive. This directive represents
154 /// all of the operands of an operation.
155 using OperandsDirective = DirectiveElement<Element::Kind::OperandsDirective>;
156
157 /// This class represents the `regions` directive. This directive represents
158 /// all of the regions of an operation.
159 using RegionsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
160
161 /// This class represents the `results` directive. This directive represents
162 /// all of the results of an operation.
163 using ResultsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
164
165 /// This class represents the `successors` directive. This directive represents
166 /// all of the successors of an operation.
167 using SuccessorsDirective =
168 DirectiveElement<Element::Kind::SuccessorsDirective>;
169
170 /// This class represents the `attr-dict` directive. This directive represents
171 /// the attribute dictionary of the operation.
172 class AttrDictDirective
173 : public DirectiveElement<Element::Kind::AttrDictDirective> {
174 public:
AttrDictDirective(bool withKeyword)175 explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {}
isWithKeyword() const176 bool isWithKeyword() const { return withKeyword; }
177
178 private:
179 /// If the dictionary should be printed with the 'attributes' keyword.
180 bool withKeyword;
181 };
182
183 /// This class represents a custom format directive that is implemented by the
184 /// user in C++.
185 class CustomDirective : public Element {
186 public:
CustomDirective(StringRef name,std::vector<std::unique_ptr<Element>> && arguments)187 CustomDirective(StringRef name,
188 std::vector<std::unique_ptr<Element>> &&arguments)
189 : Element{Kind::CustomDirective}, name(name),
190 arguments(std::move(arguments)) {}
191
classof(const Element * element)192 static bool classof(const Element *element) {
193 return element->getKind() == Kind::CustomDirective;
194 }
195
196 /// Return the name of this optional element.
getName() const197 StringRef getName() const { return name; }
198
199 /// Return the arguments to the custom directive.
getArguments() const200 auto getArguments() const { return llvm::make_pointee_range(arguments); }
201
202 private:
203 /// The user provided name of the directive.
204 StringRef name;
205
206 /// The arguments to the custom directive.
207 std::vector<std::unique_ptr<Element>> arguments;
208 };
209
210 /// This class represents the `functional-type` directive. This directive takes
211 /// two arguments and formats them, respectively, as the inputs and results of a
212 /// FunctionType.
213 class FunctionalTypeDirective
214 : public DirectiveElement<Element::Kind::FunctionalTypeDirective> {
215 public:
FunctionalTypeDirective(std::unique_ptr<Element> inputs,std::unique_ptr<Element> results)216 FunctionalTypeDirective(std::unique_ptr<Element> inputs,
217 std::unique_ptr<Element> results)
218 : inputs(std::move(inputs)), results(std::move(results)) {}
getInputs() const219 Element *getInputs() const { return inputs.get(); }
getResults() const220 Element *getResults() const { return results.get(); }
221
222 private:
223 /// The input and result arguments.
224 std::unique_ptr<Element> inputs, results;
225 };
226
227 /// This class represents the `type` directive.
228 class TypeDirective : public DirectiveElement<Element::Kind::TypeDirective> {
229 public:
TypeDirective(std::unique_ptr<Element> arg)230 TypeDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
getOperand() const231 Element *getOperand() const { return operand.get(); }
232
233 private:
234 /// The operand that is used to format the directive.
235 std::unique_ptr<Element> operand;
236 };
237
238 /// This class represents the `type_ref` directive.
239 class TypeRefDirective
240 : public DirectiveElement<Element::Kind::TypeRefDirective> {
241 public:
TypeRefDirective(std::unique_ptr<Element> arg)242 TypeRefDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
getOperand() const243 Element *getOperand() const { return operand.get(); }
244
245 private:
246 /// The operand that is used to format the directive.
247 std::unique_ptr<Element> operand;
248 };
249 } // namespace
250
251 //===----------------------------------------------------------------------===//
252 // LiteralElement
253
254 namespace {
255 /// This class represents an instance of a literal element.
256 class LiteralElement : public Element {
257 public:
LiteralElement(StringRef literal)258 LiteralElement(StringRef literal)
259 : Element{Kind::Literal}, literal(literal) {}
classof(const Element * element)260 static bool classof(const Element *element) {
261 return element->getKind() == Kind::Literal;
262 }
263
264 /// Return the literal for this element.
getLiteral() const265 StringRef getLiteral() const { return literal; }
266
267 /// Returns true if the given string is a valid literal.
268 static bool isValidLiteral(StringRef value);
269
270 private:
271 /// The spelling of the literal for this element.
272 StringRef literal;
273 };
274 } // end anonymous namespace
275
isValidLiteral(StringRef value)276 bool LiteralElement::isValidLiteral(StringRef value) {
277 if (value.empty())
278 return false;
279 char front = value.front();
280
281 // If there is only one character, this must either be punctuation or a
282 // single character bare identifier.
283 if (value.size() == 1)
284 return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front);
285
286 // Check the punctuation that are larger than a single character.
287 if (value == "->")
288 return true;
289
290 // Otherwise, this must be an identifier.
291 if (!isalpha(front) && front != '_')
292 return false;
293 return llvm::all_of(value.drop_front(), [](char c) {
294 return isalnum(c) || c == '_' || c == '$' || c == '.';
295 });
296 }
297
298 //===----------------------------------------------------------------------===//
299 // SpaceElement
300
301 namespace {
302 /// This class represents an instance of a space element. It's a literal that
303 /// prints or omits printing a space. It is ignored by the parser.
304 class SpaceElement : public Element {
305 public:
SpaceElement(bool value)306 SpaceElement(bool value) : Element{Kind::Space}, value(value) {}
classof(const Element * element)307 static bool classof(const Element *element) {
308 return element->getKind() == Kind::Space;
309 }
310
311 /// Returns true if this element should print as a space. Otherwise, the
312 /// element should omit printing a space between the surrounding elements.
getValue() const313 bool getValue() const { return value; }
314
315 private:
316 bool value;
317 };
318 } // end anonymous namespace
319
320 //===----------------------------------------------------------------------===//
321 // OptionalElement
322
323 namespace {
324 /// This class represents a group of elements that are optionally emitted based
325 /// upon an optional variable of the operation.
326 class OptionalElement : public Element {
327 public:
OptionalElement(std::vector<std::unique_ptr<Element>> && elements,unsigned anchor,unsigned parseStart)328 OptionalElement(std::vector<std::unique_ptr<Element>> &&elements,
329 unsigned anchor, unsigned parseStart)
330 : Element{Kind::Optional}, elements(std::move(elements)), anchor(anchor),
331 parseStart(parseStart) {}
classof(const Element * element)332 static bool classof(const Element *element) {
333 return element->getKind() == Kind::Optional;
334 }
335
336 /// Return the nested elements of this grouping.
getElements() const337 auto getElements() const { return llvm::make_pointee_range(elements); }
338
339 /// Return the anchor of this optional group.
getAnchor() const340 Element *getAnchor() const { return elements[anchor].get(); }
341
342 /// Return the index of the first element that needs to be parsed.
getParseStart() const343 unsigned getParseStart() const { return parseStart; }
344
345 private:
346 /// The child elements of this optional.
347 std::vector<std::unique_ptr<Element>> elements;
348 /// The index of the element that acts as the anchor for the optional group.
349 unsigned anchor;
350 /// The index of the first element that is parsed (is not a SpaceElement).
351 unsigned parseStart;
352 };
353 } // end anonymous namespace
354
355 //===----------------------------------------------------------------------===//
356 // OperationFormat
357 //===----------------------------------------------------------------------===//
358
359 namespace {
360
/// A pointer that refers to either a constant NamedAttribute or a constant
/// NamedTypeConstraint. It is used below as the storage type for type
/// resolvers.
using ConstArgument =
    llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
363
364 struct OperationFormat {
365 /// This class represents a specific resolver for an operand or result type.
366 class TypeResolution {
367 public:
368 TypeResolution() = default;
369
370 /// Get the index into the buildable types for this type, or None.
getBuilderIdx() const371 Optional<int> getBuilderIdx() const { return builderIdx; }
setBuilderIdx(int idx)372 void setBuilderIdx(int idx) { builderIdx = idx; }
373
374 /// Get the variable this type is resolved to, or nullptr.
getVariable() const375 const NamedTypeConstraint *getVariable() const {
376 return resolver.dyn_cast<const NamedTypeConstraint *>();
377 }
378 /// Get the attribute this type is resolved to, or nullptr.
getAttribute() const379 const NamedAttribute *getAttribute() const {
380 return resolver.dyn_cast<const NamedAttribute *>();
381 }
382 /// Get the transformer for the type of the variable, or None.
getVarTransformer() const383 Optional<StringRef> getVarTransformer() const {
384 return variableTransformer;
385 }
setResolver(ConstArgument arg,Optional<StringRef> transformer)386 void setResolver(ConstArgument arg, Optional<StringRef> transformer) {
387 resolver = arg;
388 variableTransformer = transformer;
389 assert(getVariable() || getAttribute());
390 }
391
392 private:
393 /// If the type is resolved with a buildable type, this is the index into
394 /// 'buildableTypes' in the parent format.
395 Optional<int> builderIdx;
396 /// If the type is resolved based upon another operand or result, this is
397 /// the variable or the attribute that this type is resolved to.
398 ConstArgument resolver;
399 /// If the type is resolved based upon another operand or result, this is
400 /// a transformer to apply to the variable when resolving.
401 Optional<StringRef> variableTransformer;
402 };
403
OperationFormat__anon81548fc80811::OperationFormat404 OperationFormat(const Operator &op)
405 : allOperands(false), allOperandTypes(false), allResultTypes(false) {
406 operandTypes.resize(op.getNumOperands(), TypeResolution());
407 resultTypes.resize(op.getNumResults(), TypeResolution());
408
409 hasImplicitTermTrait =
410 llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
411 return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
412 });
413 }
414
415 /// Generate the operation parser from this format.
416 void genParser(Operator &op, OpClass &opClass);
417 /// Generate the parser code for a specific format element.
418 void genElementParser(Element *element, OpMethodBody &body,
419 FmtContext &attrTypeCtx);
420 /// Generate the c++ to resolve the types of operands and results during
421 /// parsing.
422 void genParserTypeResolution(Operator &op, OpMethodBody &body);
423 /// Generate the c++ to resolve regions during parsing.
424 void genParserRegionResolution(Operator &op, OpMethodBody &body);
425 /// Generate the c++ to resolve successors during parsing.
426 void genParserSuccessorResolution(Operator &op, OpMethodBody &body);
427 /// Generate the c++ to handling variadic segment size traits.
428 void genParserVariadicSegmentResolution(Operator &op, OpMethodBody &body);
429
430 /// Generate the operation printer from this format.
431 void genPrinter(Operator &op, OpClass &opClass);
432
433 /// Generate the printer code for a specific format element.
434 void genElementPrinter(Element *element, OpMethodBody &body, Operator &op,
435 bool &shouldEmitSpace, bool &lastWasPunctuation);
436
437 /// The various elements in this format.
438 std::vector<std::unique_ptr<Element>> elements;
439
440 /// A flag indicating if all operand/result types were seen. If the format
441 /// contains these, it can not contain individual type resolvers.
442 bool allOperands, allOperandTypes, allResultTypes;
443
444 /// A flag indicating if this operation has the SingleBlockImplicitTerminator
445 /// trait.
446 bool hasImplicitTermTrait;
447
448 /// A map of buildable types to indices.
449 llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
450
451 /// The index of the buildable type, if valid, for every operand and result.
452 std::vector<TypeResolution> operandTypes, resultTypes;
453
454 /// The set of attributes explicitly used within the format.
455 SmallVector<const NamedAttribute *, 8> usedAttributes;
456 };
457 } // end anonymous namespace
458
459 //===----------------------------------------------------------------------===//
460 // Parser Gen
461
462 /// Returns true if we can format the given attribute as an EnumAttr in the
463 /// parser format.
canFormatEnumAttr(const NamedAttribute * attr)464 static bool canFormatEnumAttr(const NamedAttribute *attr) {
465 Attribute baseAttr = attr->attr.getBaseAttr();
466 const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&baseAttr);
467 if (!enumAttr)
468 return false;
469
470 // The attribute must have a valid underlying type and a constant builder.
471 return !enumAttr->getUnderlyingType().empty() &&
472 !enumAttr->getConstBuilderTemplate().empty();
473 }
474
475 /// Returns if we should format the given attribute as an SymbolNameAttr.
shouldFormatSymbolNameAttr(const NamedAttribute * attr)476 static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
477 return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
478 }
479
/// The code snippet used to generate a parser call for an attribute.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
///
/// The emitted code parses into the `{0}Attr` variable, records the attribute
/// in `result.attributes`, and returns failure if parsing fails.
const char *const attrParserCode = R"(
if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes))
return ::mlir::failure();
)";
/// Optional variant of the attribute snippet; takes the same placeholders.
/// The emitted code returns failure only when the optional parse produced a
/// result and that result is a failure.
const char *const optionalAttrParserCode = R"(
{
::mlir::OptionalParseResult parseResult =
parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes);
if (parseResult.hasValue() && failed(*parseResult))
return ::mlir::failure();
}
)";
496
/// The code snippet used to generate a parser call for a symbol name attribute.
///
/// {0}: The name of the attribute.
///
/// The emitted code parses into the `{0}Attr` variable and returns failure if
/// parsing fails.
const char *const symbolNameAttrParserCode = R"(
if (parser.parseSymbolName({0}Attr, "{0}", result.attributes))
return ::mlir::failure();
)";
/// Optional variant of the symbol name snippet; takes the same placeholder.
/// The emitted code discards the result of the optional parse.
const char *const optionalSymbolNameAttrParserCode = R"(
// Parsing an optional symbol name doesn't fail, so no need to check the
// result.
(void)parser.parseOptionalSymbolName({0}Attr, "{0}", result.attributes);
)";
509
/// The code snippet used to generate a parser call for an enum attribute.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
///
/// The emitted code parses a StringAttr into scratch attribute storage,
/// symbolizes its value, emits an "invalid ... attribute specification" error
/// if symbolization yields no value, and otherwise assigns `{0}Attr` and adds
/// it to `result`.
const char *const enumAttrParserCode = R"(
{
::mlir::StringAttr attrVal;
::mlir::NamedAttrList attrStorage;
auto loc = parser.getCurrentLocation();
if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
"{0}", attrStorage))
return ::mlir::failure();

auto attrOptional = {1}::{2}(attrVal.getValue());
if (!attrOptional)
return parser.emitError(loc, "invalid ")
<< "{0} attribute specification: " << attrVal;

{0}Attr = {3};
result.addAttribute("{0}", {0}Attr);
}
)";
/// Optional variant of the enum attribute snippet; takes the same
/// placeholders. The symbolization and attribute registration only run when
/// the optional parse produced a result.
const char *const optionalEnumAttrParserCode = R"(
{
::mlir::StringAttr attrVal;
::mlir::NamedAttrList attrStorage;
auto loc = parser.getCurrentLocation();

::mlir::OptionalParseResult parseResult =
parser.parseOptionalAttribute(attrVal, parser.getBuilder().getNoneType(),
"{0}", attrStorage);
if (parseResult.hasValue()) {
if (failed(*parseResult))
return ::mlir::failure();

auto attrOptional = {1}::{2}(attrVal.getValue());
if (!attrOptional)
return parser.emitError(loc, "invalid ")
<< "{0} attribute specification: " << attrVal;

{0}Attr = {3};
result.addAttribute("{0}", {0}Attr);
}
}
)";
557
/// The code snippet used to generate a parser call for an operand.
///
/// {0}: The name of the operand.
///
/// Every variant first records the current parser location in
/// `{0}OperandsLoc`. This variant parses an operand list into `{0}Operands`.
const char *const variadicOperandParserCode = R"(
{0}OperandsLoc = parser.getCurrentLocation();
if (parser.parseOperandList({0}Operands))
return ::mlir::failure();
)";
/// Optional variant: appends to `{0}Operands` only when the optional parse
/// produced a result and that result is a success.
const char *const optionalOperandParserCode = R"(
{
{0}OperandsLoc = parser.getCurrentLocation();
::mlir::OpAsmParser::OperandType operand;
::mlir::OptionalParseResult parseResult =
parser.parseOptionalOperand(operand);
if (parseResult.hasValue()) {
if (failed(*parseResult))
return ::mlir::failure();
{0}Operands.push_back(operand);
}
}
)";
/// Single-operand variant: parses directly into `{0}RawOperands[0]`.
const char *const operandParserCode = R"(
{0}OperandsLoc = parser.getCurrentLocation();
if (parser.parseOperand({0}RawOperands[0]))
return ::mlir::failure();
)";
584
/// The code snippet used to generate a parser call for a type list.
///
/// {0}: The name for the type list.
///
/// This variant parses a whole type list into `{0}Types`.
const char *const variadicTypeParserCode = R"(
if (parser.parseTypeList({0}Types))
return ::mlir::failure();
)";
/// Optional variant: appends to `{0}Types` only when the optional parse
/// produced a result and that result is a success.
const char *const optionalTypeParserCode = R"(
{
::mlir::Type optionalType;
::mlir::OptionalParseResult parseResult =
parser.parseOptionalType(optionalType);
if (parseResult.hasValue()) {
if (failed(*parseResult))
return ::mlir::failure();
{0}Types.push_back(optionalType);
}
}
)";
/// Single-type variant: parses directly into `{0}RawTypes[0]`.
const char *const typeParserCode = R"(
if (parser.parseType({0}RawTypes[0]))
return ::mlir::failure();
)";
608
/// The code snippet used to generate a parser call for a functional type.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
///
/// The emitted code parses a FunctionType into a local whose name is built
/// from both placeholders, then assigns its inputs to `{0}Types` and its
/// results to `{1}Types`.
const char *const functionalTypeParserCode = R"(
::mlir::FunctionType {0}__{1}_functionType;
if (parser.parseType({0}__{1}_functionType))
return ::mlir::failure();
{0}Types = {0}__{1}_functionType.getInputs();
{1}Types = {0}__{1}_functionType.getResults();
)";
620
/// The code snippet used to generate a parser call for a region list.
///
/// {0}: The name for the region list.
///
/// The emitted code optionally parses a first region; if one was present it
/// is appended to `{0}Regions`, and further regions are then parsed for as
/// long as a comma follows.
const char *regionListParserCode = R"(
{
std::unique_ptr<::mlir::Region> region;
auto firstRegionResult = parser.parseOptionalRegion(region);
if (firstRegionResult.hasValue()) {
if (failed(*firstRegionResult))
return ::mlir::failure();
{0}Regions.emplace_back(std::move(region));

// Parse any trailing regions.
while (succeeded(parser.parseOptionalComma())) {
region = std::make_unique<::mlir::Region>();
if (parser.parseRegion(*region))
return ::mlir::failure();
{0}Regions.emplace_back(std::move(region));
}
}
}
)";
643
/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
///
/// The emitted loop binds each entry of `{0}Regions` by reference and calls
/// ensureTerminator on the region it points to. The loop variable must be
/// spelled `&region`; any other spelling yields generated code that does not
/// compile.
const char *regionListEnsureTerminatorParserCode = R"(
for (auto &region : {0}Regions)
ensureTerminator(*region, parser.getBuilder(), result.location);
)";
651
/// The code snippet used to generate a parser call for an optional region.
///
/// {0}: The name of the region.
///
/// The emitted code returns failure only when the optional parse produced a
/// result and that result is a failure.
const char *optionalRegionParserCode = R"(
{
auto parseResult = parser.parseOptionalRegion(*{0}Region);
if (parseResult.hasValue() && failed(*parseResult))
return ::mlir::failure();
}
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
///
/// The emitted code parses into `*{0}Region` and returns failure if parsing
/// fails.
const char *regionParserCode = R"(
if (parser.parseRegion(*{0}Region))
return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *regionEnsureTerminatorParserCode = R"(
ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";
677
/// The code snippet used to generate a parser call for a successor list.
///
/// {0}: The name for the successor list.
///
/// The emitted code optionally parses a first successor; if one was present
/// it is appended to `{0}Successors`, and further successors are then parsed
/// for as long as a comma follows.
const char *successorListParserCode = R"(
{
::mlir::Block *succ;
auto firstSucc = parser.parseOptionalSuccessor(succ);
if (firstSucc.hasValue()) {
if (failed(*firstSucc))
return ::mlir::failure();
{0}Successors.emplace_back(succ);

// Parse any trailing successors.
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseSuccessor(succ))
return ::mlir::failure();
{0}Successors.emplace_back(succ);
}
}
}
)";

/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
///
/// The emitted code parses into `{0}Successor` and returns failure if parsing
/// fails.
const char *successorParserCode = R"(
if (parser.parseSuccessor({0}Successor))
return ::mlir::failure();
)";
707
namespace {
/// Classifies how many elements a given parse argument may hold. The storage
/// and snippet selection code in this file branches on this value.
enum class ArgumentLengthKind {
  /// The argument is variadic, and may contain 0->N elements.
  Variadic,
  /// The argument is optional, and may contain 0 or 1 elements.
  Optional,
  /// The argument is a single element, i.e. always represents 1 element.
  Single
};
} // end anonymous namespace
719
720 /// Get the length kind for the given constraint.
721 static ArgumentLengthKind
getArgumentLengthKind(const NamedTypeConstraint * var)722 getArgumentLengthKind(const NamedTypeConstraint *var) {
723 if (var->isOptional())
724 return ArgumentLengthKind::Optional;
725 if (var->isVariadic())
726 return ArgumentLengthKind::Variadic;
727 return ArgumentLengthKind::Single;
728 }
729
730 /// Get the name used for the type list for the given type directive operand.
731 /// 'lengthKind' to the corresponding kind for the given argument.
getTypeListName(Element * arg,ArgumentLengthKind & lengthKind)732 static StringRef getTypeListName(Element *arg, ArgumentLengthKind &lengthKind) {
733 if (auto *operand = dyn_cast<OperandVariable>(arg)) {
734 lengthKind = getArgumentLengthKind(operand->getVar());
735 return operand->getVar()->name;
736 }
737 if (auto *result = dyn_cast<ResultVariable>(arg)) {
738 lengthKind = getArgumentLengthKind(result->getVar());
739 return result->getVar()->name;
740 }
741 lengthKind = ArgumentLengthKind::Variadic;
742 if (isa<OperandsDirective>(arg))
743 return "allOperand";
744 if (isa<ResultsDirective>(arg))
745 return "allResult";
746 llvm_unreachable("unknown 'type' directive argument");
747 }
748
749 /// Generate the parser for a literal value.
genLiteralParser(StringRef value,OpMethodBody & body)750 static void genLiteralParser(StringRef value, OpMethodBody &body) {
751 // Handle the case of a keyword/identifier.
752 if (value.front() == '_' || isalpha(value.front())) {
753 body << "Keyword(\"" << value << "\")";
754 return;
755 }
756 body << (StringRef)StringSwitch<StringRef>(value)
757 .Case("->", "Arrow()")
758 .Case(":", "Colon()")
759 .Case(",", "Comma()")
760 .Case("=", "Equal()")
761 .Case("<", "Less()")
762 .Case(">", "Greater()")
763 .Case("{", "LBrace()")
764 .Case("}", "RBrace()")
765 .Case("(", "LParen()")
766 .Case(")", "RParen()")
767 .Case("[", "LSquare()")
768 .Case("]", "RSquare()")
769 .Case("?", "Question()")
770 .Case("+", "Plus()")
771 .Case("*", "Star()");
772 }
773
774 /// Generate the storage code required for parsing the given element.
genElementParserStorage(Element * element,OpMethodBody & body)775 static void genElementParserStorage(Element *element, OpMethodBody &body) {
776 if (auto *optional = dyn_cast<OptionalElement>(element)) {
777 auto elements = optional->getElements();
778
779 // If the anchor is a unit attribute, it won't be parsed directly so elide
780 // it.
781 auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
782 Element *elidedAnchorElement = nullptr;
783 if (anchor && anchor != &*elements.begin() && anchor->isUnitAttr())
784 elidedAnchorElement = anchor;
785 for (auto &childElement : elements)
786 if (&childElement != elidedAnchorElement)
787 genElementParserStorage(&childElement, body);
788
789 } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
790 for (auto ¶mElement : custom->getArguments())
791 genElementParserStorage(¶mElement, body);
792
793 } else if (isa<OperandsDirective>(element)) {
794 body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
795 "allOperands;\n";
796
797 } else if (isa<RegionsDirective>(element)) {
798 body << " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
799 "fullRegions;\n";
800
801 } else if (isa<SuccessorsDirective>(element)) {
802 body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
803
804 } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
805 const NamedAttribute *var = attr->getVar();
806 body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(),
807 var->name);
808
809 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
810 StringRef name = operand->getVar()->name;
811 if (operand->getVar()->isVariableLength()) {
812 body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
813 << name << "Operands;\n";
814 } else {
815 body << " ::mlir::OpAsmParser::OperandType " << name
816 << "RawOperands[1];\n"
817 << " ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> " << name
818 << "Operands(" << name << "RawOperands);";
819 }
820 body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
821 " (void){0}OperandsLoc;\n",
822 name);
823
824 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
825 StringRef name = region->getVar()->name;
826 if (region->getVar()->isVariadic()) {
827 body << llvm::formatv(
828 " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
829 "{0}Regions;\n",
830 name);
831 } else {
832 body << llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
833 "std::make_unique<::mlir::Region>();\n",
834 name);
835 }
836
837 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
838 StringRef name = successor->getVar()->name;
839 if (successor->getVar()->isVariadic()) {
840 body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
841 "{0}Successors;\n",
842 name);
843 } else {
844 body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
845 }
846
847 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
848 ArgumentLengthKind lengthKind;
849 StringRef name = getTypeListName(dir->getOperand(), lengthKind);
850 if (lengthKind != ArgumentLengthKind::Single)
851 body << " ::mlir::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
852 else
853 body << llvm::formatv(" ::mlir::Type {0}RawTypes[1];\n", name)
854 << llvm::formatv(
855 " ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n",
856 name);
857 } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
858 ArgumentLengthKind lengthKind;
859 StringRef name = getTypeListName(dir->getOperand(), lengthKind);
860 // Refer to the previously encountered TypeDirective for name.
861 // Take a `const ::mlir::SmallVector<::mlir::Type, 1> &` in the declaration
862 // to properly track the types that will be parsed and pushed later on.
863 if (lengthKind != ArgumentLengthKind::Single)
864 body << " const ::mlir::SmallVector<::mlir::Type, 1> &" << name
865 << "TypesRef(" << name << "Types);\n";
866 else
867 body << llvm::formatv(
868 " ::llvm::ArrayRef<::mlir::Type> {0}RawTypesRef({0}RawTypes);\n",
869 name);
870 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
871 ArgumentLengthKind ignored;
872 body << " ::llvm::ArrayRef<::mlir::Type> "
873 << getTypeListName(dir->getInputs(), ignored) << "Types;\n";
874 body << " ::llvm::ArrayRef<::mlir::Type> "
875 << getTypeListName(dir->getResults(), ignored) << "Types;\n";
876 }
877 }
878
879 /// Generate the parser for a parameter to a custom directive.
genCustomParameterParser(Element & param,OpMethodBody & body)880 static void genCustomParameterParser(Element ¶m, OpMethodBody &body) {
881 body << ", ";
882 if (auto *attr = dyn_cast<AttributeVariable>(¶m)) {
883 body << attr->getVar()->name << "Attr";
884 } else if (isa<AttrDictDirective>(¶m)) {
885 body << "result.attributes";
886 } else if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
887 StringRef name = operand->getVar()->name;
888 ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
889 if (lengthKind == ArgumentLengthKind::Variadic)
890 body << llvm::formatv("{0}Operands", name);
891 else if (lengthKind == ArgumentLengthKind::Optional)
892 body << llvm::formatv("{0}Operand", name);
893 else
894 body << formatv("{0}RawOperands[0]", name);
895
896 } else if (auto *region = dyn_cast<RegionVariable>(¶m)) {
897 StringRef name = region->getVar()->name;
898 if (region->getVar()->isVariadic())
899 body << llvm::formatv("{0}Regions", name);
900 else
901 body << llvm::formatv("*{0}Region", name);
902
903 } else if (auto *successor = dyn_cast<SuccessorVariable>(¶m)) {
904 StringRef name = successor->getVar()->name;
905 if (successor->getVar()->isVariadic())
906 body << llvm::formatv("{0}Successors", name);
907 else
908 body << llvm::formatv("{0}Successor", name);
909
910 } else if (auto *dir = dyn_cast<TypeRefDirective>(¶m)) {
911 ArgumentLengthKind lengthKind;
912 StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
913 if (lengthKind == ArgumentLengthKind::Variadic)
914 body << llvm::formatv("{0}TypesRef", listName);
915 else if (lengthKind == ArgumentLengthKind::Optional)
916 body << llvm::formatv("{0}TypeRef", listName);
917 else
918 body << formatv("{0}RawTypesRef[0]", listName);
919 } else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
920 ArgumentLengthKind lengthKind;
921 StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
922 if (lengthKind == ArgumentLengthKind::Variadic)
923 body << llvm::formatv("{0}Types", listName);
924 else if (lengthKind == ArgumentLengthKind::Optional)
925 body << llvm::formatv("{0}Type", listName);
926 else
927 body << formatv("{0}RawTypes[0]", listName);
928 } else {
929 llvm_unreachable("unknown custom directive parameter");
930 }
931 }
932
933 /// Generate the parser for a custom directive.
genCustomDirectiveParser(CustomDirective * dir,OpMethodBody & body)934 static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
935 body << " {\n";
936
937 // Preprocess the directive variables.
938 // * Add a local variable for optional operands and types. This provides a
939 // better API to the user defined parser methods.
940 // * Set the location of operand variables.
941 for (Element ¶m : dir->getArguments()) {
942 if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
943 body << " " << operand->getVar()->name
944 << "OperandsLoc = parser.getCurrentLocation();\n";
945 if (operand->getVar()->isOptional()) {
946 body << llvm::formatv(
947 " llvm::Optional<::mlir::OpAsmParser::OperandType> "
948 "{0}Operand;\n",
949 operand->getVar()->name);
950 }
951 } else if (auto *dir = dyn_cast<TypeRefDirective>(¶m)) {
952 // Reference to an optional which may or may not have been set.
953 // Retrieve from vector if not empty.
954 ArgumentLengthKind lengthKind;
955 StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
956 if (lengthKind == ArgumentLengthKind::Optional)
957 body << llvm::formatv(
958 " ::mlir::Type {0}TypeRef = {0}TypesRef.empty() "
959 "? Type() : {0}TypesRef[0];\n",
960 listName);
961 } else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
962 ArgumentLengthKind lengthKind;
963 StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
964 if (lengthKind == ArgumentLengthKind::Optional)
965 body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
966 }
967 }
968
969 body << " if (parse" << dir->getName() << "(parser";
970 for (Element ¶m : dir->getArguments())
971 genCustomParameterParser(param, body);
972
973 body << "))\n"
974 << " return ::mlir::failure();\n";
975
976 // After parsing, add handling for any of the optional constructs.
977 for (Element ¶m : dir->getArguments()) {
978 if (auto *attr = dyn_cast<AttributeVariable>(¶m)) {
979 const NamedAttribute *var = attr->getVar();
980 if (var->attr.isOptional())
981 body << llvm::formatv(" if ({0}Attr)\n ", var->name);
982
983 body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
984 var->name);
985 } else if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
986 const NamedTypeConstraint *var = operand->getVar();
987 if (!var->isOptional())
988 continue;
989 body << llvm::formatv(" if ({0}Operand.hasValue())\n"
990 " {0}Operands.push_back(*{0}Operand);\n",
991 var->name);
992 } else if (isa<TypeRefDirective>(¶m)) {
993 // In the `type_ref` case, do not parse a new Type that needs to be added.
994 // Just do nothing here.
995 } else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
996 ArgumentLengthKind lengthKind;
997 StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
998 if (lengthKind == ArgumentLengthKind::Optional) {
999 body << llvm::formatv(" if ({0}Type)\n"
1000 " {0}Types.push_back({0}Type);\n",
1001 listName);
1002 }
1003 }
1004 }
1005
1006 body << " }\n";
1007 }
1008
genParser(Operator & op,OpClass & opClass)1009 void OperationFormat::genParser(Operator &op, OpClass &opClass) {
1010 llvm::SmallVector<OpMethodParameter, 4> paramList;
1011 paramList.emplace_back("::mlir::OpAsmParser &", "parser");
1012 paramList.emplace_back("::mlir::OperationState &", "result");
1013
1014 auto *method =
1015 opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
1016 OpMethod::MP_Static, std::move(paramList));
1017 auto &body = method->body();
1018
1019 // Generate variables to store the operands and type within the format. This
1020 // allows for referencing these variables in the presence of optional
1021 // groupings.
1022 for (auto &element : elements)
1023 genElementParserStorage(&*element, body);
1024
1025 // A format context used when parsing attributes with buildable types.
1026 FmtContext attrTypeCtx;
1027 attrTypeCtx.withBuilder("parser.getBuilder()");
1028
1029 // Generate parsers for each of the elements.
1030 for (auto &element : elements)
1031 genElementParser(element.get(), body, attrTypeCtx);
1032
1033 // Generate the code to resolve the operand/result types and successors now
1034 // that they have been parsed.
1035 genParserTypeResolution(op, body);
1036 genParserRegionResolution(op, body);
1037 genParserSuccessorResolution(op, body);
1038 genParserVariadicSegmentResolution(op, body);
1039
1040 body << " return ::mlir::success();\n";
1041 }
1042
/// Generate the parsing code for a single format `element` into `body`.
///
/// `attrTypeCtx` is the format context used when substituting the builder
/// templates of attributes (enum constant builders and buildable attribute
/// types). Optional groups are handled by recursing on their child elements.
void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
                                       FmtContext &attrTypeCtx) {
  /// Optional Group.
  if (auto *optional = dyn_cast<OptionalElement>(element)) {
    // Skip the first `getParseStart()` elements of the group.
    auto elements =
        llvm::drop_begin(optional->getElements(), optional->getParseStart());

    // Generate a special optional parser for the first element to gate the
    // parsing of the rest of the elements.
    Element *firstElement = &*elements.begin();
    if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
      // Parse the attribute, then enter the group if it was set.
      genElementParser(attrVar, body, attrTypeCtx);
      body << " if (" << attrVar->getVar()->name << "Attr) {\n";
    } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
      // Enter the group only if the optional literal was parsed.
      body << " if (succeeded(parser.parseOptional";
      genLiteralParser(literal->getLiteral(), body);
      body << ")) {\n";
    } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
      // Parse the operand(s), then enter the group if any were found.
      genElementParser(opVar, body, attrTypeCtx);
      body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
    } else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) {
      const NamedRegion *region = regionVar->getVar();
      if (region->isVariadic()) {
        genElementParser(regionVar, body, attrTypeCtx);
        body << " if (!" << region->name << "Regions.empty()) {\n";
      } else {
        // A single region uses the optional-region snippet; the terminator
        // snippet is emitted inside of the guard that checks for a non-empty
        // region.
        body << llvm::formatv(optionalRegionParserCode, region->name);
        body << " if (!" << region->name << "Region->empty()) {\n ";
        if (hasImplicitTermTrait)
          body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
      }
    }

    // If the anchor is a unit attribute, we don't need to print it. When
    // parsing, we will add this attribute if this group is present.
    Element *elidedAnchorElement = nullptr;
    auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
    if (anchorAttr && anchorAttr != firstElement && anchorAttr->isUnitAttr()) {
      elidedAnchorElement = anchorAttr;

      // Add the anchor unit attribute to the operation state.
      body << " result.addAttribute(\"" << anchorAttr->getVar()->name
           << "\", parser.getBuilder().getUnitAttr());\n";
    }

    // Generate the rest of the elements normally.
    for (Element &childElement : llvm::drop_begin(elements, 1)) {
      if (&childElement != elidedAnchorElement)
        genElementParser(&childElement, body, attrTypeCtx);
    }
    // Close the guard opened for the first element.
    body << " }\n";

    /// Literals.
  } else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
    body << " if (parser.parse";
    genLiteralParser(literal->getLiteral(), body);
    body << ")\n return ::mlir::failure();\n";

    /// Spaces.
  } else if (isa<SpaceElement>(element)) {
    // Nothing to parse.

    /// Arguments.
  } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
    const NamedAttribute *var = attr->getVar();

    // Check to see if we can parse this as an enum attribute.
    if (canFormatEnumAttr(var)) {
      Attribute baseAttr = var->attr.getBaseAttr();
      const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);

      // Generate the code for building an attribute for this enum.
      std::string attrBuilderStr;
      {
        llvm::raw_string_ostream os(attrBuilderStr);
        os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
                    "attrOptional.getValue()");
      }

      // Select the optional or required enum snippet and instantiate it.
      body << formatv(var->attr.isOptional() ? optionalEnumAttrParserCode
                                             : enumAttrParserCode,
                      var->name, enumAttr.getCppNamespace(),
                      enumAttr.getStringToSymbolFnName(), attrBuilderStr);
      return;
    }

    // Check to see if we should parse this as a symbol name attribute.
    if (shouldFormatSymbolNameAttr(var)) {
      body << formatv(var->attr.isOptional() ? optionalSymbolNameAttrParserCode
                                             : symbolNameAttrParserCode,
                      var->name);
      return;
    }

    // If this attribute has a buildable type, use that when parsing the
    // attribute.
    std::string attrTypeStr;
    if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
      llvm::raw_string_ostream os(attrTypeStr);
      os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
    }

    body << formatv(var->attr.isOptional() ? optionalAttrParserCode
                                           : attrParserCode,
                    var->name, attrTypeStr);
  } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
    // Pick the operand snippet that matches the length kind of the operand.
    ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
    StringRef name = operand->getVar()->name;
    if (lengthKind == ArgumentLengthKind::Variadic)
      body << llvm::formatv(variadicOperandParserCode, name);
    else if (lengthKind == ArgumentLengthKind::Optional)
      body << llvm::formatv(optionalOperandParserCode, name);
    else
      body << formatv(operandParserCode, name);

  } else if (auto *region = dyn_cast<RegionVariable>(element)) {
    bool isVariadic = region->getVar()->isVariadic();
    body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
                          region->getVar()->name);
    // Follow up with the terminator snippet when the format has the implicit
    // terminator trait flag set.
    if (hasImplicitTermTrait) {
      body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
                                       : regionEnsureTerminatorParserCode,
                            region->getVar()->name);
    }

  } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
    bool isVariadic = successor->getVar()->isVariadic();
    body << formatv(isVariadic ? successorListParserCode : successorParserCode,
                    successor->getVar()->name);

    /// Directives.
  } else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
    body << " if (parser.parseOptionalAttrDict"
         << (attrDict->isWithKeyword() ? "WithKeyword" : "")
         << "(result.attributes))\n"
         << " return ::mlir::failure();\n";
  } else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
    genCustomDirectiveParser(customDir, body);

  } else if (isa<OperandsDirective>(element)) {
    // Record where the full operand list starts before parsing it.
    body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
         << " if (parser.parseOperandList(allOperands))\n"
         << " return ::mlir::failure();\n";

  } else if (isa<RegionsDirective>(element)) {
    // The full region list is parsed under the name `full`.
    body << llvm::formatv(regionListParserCode, "full");
    if (hasImplicitTermTrait)
      body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");

  } else if (isa<SuccessorsDirective>(element)) {
    // The full successor list is parsed under the name `full`.
    body << llvm::formatv(successorListParserCode, "full");

  } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
    // This branch emits the same type snippets as the TypeDirective branch
    // below.
    ArgumentLengthKind lengthKind;
    StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
    if (lengthKind == ArgumentLengthKind::Variadic)
      body << llvm::formatv(variadicTypeParserCode, listName);
    else if (lengthKind == ArgumentLengthKind::Optional)
      body << llvm::formatv(optionalTypeParserCode, listName);
    else
      body << formatv(typeParserCode, listName);
  } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
    // Pick the type snippet that matches the length kind of the type list.
    ArgumentLengthKind lengthKind;
    StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
    if (lengthKind == ArgumentLengthKind::Variadic)
      body << llvm::formatv(variadicTypeParserCode, listName);
    else if (lengthKind == ArgumentLengthKind::Optional)
      body << llvm::formatv(optionalTypeParserCode, listName);
    else
      body << formatv(typeParserCode, listName);
  } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
    // The length kinds of the input/result lists are not needed here.
    ArgumentLengthKind ignored;
    body << formatv(functionalTypeParserCode,
                    getTypeListName(dir->getInputs(), ignored),
                    getTypeListName(dir->getResults(), ignored));
  } else {
    llvm_unreachable("unknown format element");
  }
}
1222
genParserTypeResolution(Operator & op,OpMethodBody & body)1223 void OperationFormat::genParserTypeResolution(Operator &op,
1224 OpMethodBody &body) {
1225 // If any of type resolutions use transformed variables, make sure that the
1226 // types of those variables are resolved.
1227 SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
1228 FmtContext verifierFCtx;
1229 for (TypeResolution &resolver :
1230 llvm::concat<TypeResolution>(resultTypes, operandTypes)) {
1231 Optional<StringRef> transformer = resolver.getVarTransformer();
1232 if (!transformer)
1233 continue;
1234 // Ensure that we don't verify the same variables twice.
1235 const NamedTypeConstraint *variable = resolver.getVariable();
1236 if (!variable || !verifiedVariables.insert(variable).second)
1237 continue;
1238
1239 auto constraint = variable->constraint;
1240 body << " for (::mlir::Type type : " << variable->name << "Types) {\n"
1241 << " (void)type;\n"
1242 << " if (!("
1243 << tgfmt(constraint.getConditionTemplate(),
1244 &verifierFCtx.withSelf("type"))
1245 << ")) {\n"
1246 << formatv(" return parser.emitError(parser.getNameLoc()) << "
1247 "\"'{0}' must be {1}, but got \" << type;\n",
1248 variable->name, constraint.getDescription())
1249 << " }\n"
1250 << " }\n";
1251 }
1252
1253 // Initialize the set of buildable types.
1254 if (!buildableTypes.empty()) {
1255 FmtContext typeBuilderCtx;
1256 typeBuilderCtx.withBuilder("parser.getBuilder()");
1257 for (auto &it : buildableTypes)
1258 body << " ::mlir::Type odsBuildableType" << it.second << " = "
1259 << tgfmt(it.first, &typeBuilderCtx) << ";\n";
1260 }
1261
1262 // Emit the code necessary for a type resolver.
1263 auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
1264 if (Optional<int> val = resolver.getBuilderIdx()) {
1265 body << "odsBuildableType" << *val;
1266 } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
1267 if (Optional<StringRef> tform = resolver.getVarTransformer())
1268 body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]"));
1269 else
1270 body << var->name << "Types";
1271 } else if (const NamedAttribute *attr = resolver.getAttribute()) {
1272 if (Optional<StringRef> tform = resolver.getVarTransformer())
1273 body << tgfmt(*tform,
1274 &FmtContext().withSelf(attr->name + "Attr.getType()"));
1275 else
1276 body << attr->name << "Attr.getType()";
1277 } else {
1278 body << curVar << "Types";
1279 }
1280 };
1281
1282 // Resolve each of the result types.
1283 if (allResultTypes) {
1284 body << " result.addTypes(allResultTypes);\n";
1285 } else {
1286 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
1287 body << " result.addTypes(";
1288 emitTypeResolver(resultTypes[i], op.getResultName(i));
1289 body << ");\n";
1290 }
1291 }
1292
1293 // Early exit if there are no operands.
1294 if (op.getNumOperands() == 0)
1295 return;
1296
1297 // Handle the case where all operand types are in one group.
1298 if (allOperandTypes) {
1299 // If we have all operands together, use the full operand list directly.
1300 if (allOperands) {
1301 body << " if (parser.resolveOperands(allOperands, allOperandTypes, "
1302 "allOperandLoc, result.operands))\n"
1303 " return ::mlir::failure();\n";
1304 return;
1305 }
1306
1307 // Otherwise, use llvm::concat to merge the disjoint operand lists together.
1308 // llvm::concat does not allow the case of a single range, so guard it here.
1309 body << " if (parser.resolveOperands(";
1310 if (op.getNumOperands() > 1) {
1311 body << "::llvm::concat<const ::mlir::OpAsmParser::OperandType>(";
1312 llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) {
1313 body << operand.name << "Operands";
1314 });
1315 body << ")";
1316 } else {
1317 body << op.operand_begin()->name << "Operands";
1318 }
1319 body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
1320 << " return ::mlir::failure();\n";
1321 return;
1322 }
1323 // Handle the case where all of the operands were grouped together.
1324 if (allOperands) {
1325 body << " if (parser.resolveOperands(allOperands, ";
1326
1327 // Group all of the operand types together to perform the resolution all at
1328 // once. Use llvm::concat to perform the merge. llvm::concat does not allow
1329 // the case of a single range, so guard it here.
1330 if (op.getNumOperands() > 1) {
1331 body << "::llvm::concat<const Type>(";
1332 llvm::interleaveComma(
1333 llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1334 body << "::llvm::ArrayRef<::mlir::Type>(";
1335 emitTypeResolver(operandTypes[i], op.getOperand(i).name);
1336 body << ")";
1337 });
1338 body << ")";
1339 } else {
1340 emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
1341 }
1342
1343 body << ", allOperandLoc, result.operands))\n"
1344 << " return ::mlir::failure();\n";
1345 return;
1346 }
1347
1348 // The final case is the one where each of the operands types are resolved
1349 // separately.
1350 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
1351 NamedTypeConstraint &operand = op.getOperand(i);
1352 body << " if (parser.resolveOperands(" << operand.name << "Operands, ";
1353
1354 // Resolve the type of this operand.
1355 TypeResolution &operandType = operandTypes[i];
1356 emitTypeResolver(operandType, operand.name);
1357
1358 // If the type is resolved by a non-variadic variable, index into the
1359 // resolved type list. This allows for resolving the types of a variadic
1360 // operand list from a non-variadic variable.
1361 bool verifyOperandAndTypeSize = true;
1362 if (auto *resolverVar = operandType.getVariable()) {
1363 if (!resolverVar->isVariadic() && !operandType.getVarTransformer()) {
1364 body << "[0]";
1365 verifyOperandAndTypeSize = false;
1366 }
1367 } else {
1368 verifyOperandAndTypeSize = !operandType.getBuilderIdx();
1369 }
1370
1371 // Check to see if the sizes between the types and operands must match. If
1372 // they do, provide the operand location to select the proper resolution
1373 // overload.
1374 if (verifyOperandAndTypeSize)
1375 body << ", " << operand.name << "OperandsLoc";
1376 body << ", result.operands))\n return ::mlir::failure();\n";
1377 }
1378 }
1379
genParserRegionResolution(Operator & op,OpMethodBody & body)1380 void OperationFormat::genParserRegionResolution(Operator &op,
1381 OpMethodBody &body) {
1382 // Check for the case where all regions were parsed.
1383 bool hasAllRegions = llvm::any_of(
1384 elements, [](auto &elt) { return isa<RegionsDirective>(elt.get()); });
1385 if (hasAllRegions) {
1386 body << " result.addRegions(fullRegions);\n";
1387 return;
1388 }
1389
1390 // Otherwise, handle each region individually.
1391 for (const NamedRegion ®ion : op.getRegions()) {
1392 if (region.isVariadic())
1393 body << " result.addRegions(" << region.name << "Regions);\n";
1394 else
1395 body << " result.addRegion(std::move(" << region.name << "Region));\n";
1396 }
1397 }
1398
genParserSuccessorResolution(Operator & op,OpMethodBody & body)1399 void OperationFormat::genParserSuccessorResolution(Operator &op,
1400 OpMethodBody &body) {
1401 // Check for the case where all successors were parsed.
1402 bool hasAllSuccessors = llvm::any_of(
1403 elements, [](auto &elt) { return isa<SuccessorsDirective>(elt.get()); });
1404 if (hasAllSuccessors) {
1405 body << " result.addSuccessors(fullSuccessors);\n";
1406 return;
1407 }
1408
1409 // Otherwise, handle each successor individually.
1410 for (const NamedSuccessor &successor : op.getSuccessors()) {
1411 if (successor.isVariadic())
1412 body << " result.addSuccessors(" << successor.name << "Successors);\n";
1413 else
1414 body << " result.addSuccessors(" << successor.name << "Successor);\n";
1415 }
1416 }
1417
genParserVariadicSegmentResolution(Operator & op,OpMethodBody & body)1418 void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
1419 OpMethodBody &body) {
1420 if (!allOperands &&
1421 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1422 body << " result.addAttribute(\"operand_segment_sizes\", "
1423 << "parser.getBuilder().getI32VectorAttr({";
1424 auto interleaveFn = [&](const NamedTypeConstraint &operand) {
1425 // If the operand is variadic emit the parsed size.
1426 if (operand.isVariableLength())
1427 body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
1428 else
1429 body << "1";
1430 };
1431 llvm::interleaveComma(op.getOperands(), body, interleaveFn);
1432 body << "}));\n";
1433 }
1434
1435 if (!allResultTypes &&
1436 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
1437 body << " result.addAttribute(\"result_segment_sizes\", "
1438 << "parser.getBuilder().getI32VectorAttr({";
1439 auto interleaveFn = [&](const NamedTypeConstraint &result) {
1440 // If the result is variadic emit the parsed size.
1441 if (result.isVariableLength())
1442 body << "static_cast<int32_t>(" << result.name << "Types.size())";
1443 else
1444 body << "1";
1445 };
1446 llvm::interleaveComma(op.getResults(), body, interleaveFn);
1447 body << "}));\n";
1448 }
1449 }
1450
//===----------------------------------------------------------------------===//
// PrinterGen
//===----------------------------------------------------------------------===//
1453
/// The code snippet used to generate a printer call for a region of an
/// operation that has the SingleBlockImplicitTerminator trait.
///
/// The snippet is used as an llvm::formatv format string, so `{{` stands for
/// a literal `{`. The generated code prints the block terminator of the
/// region only if the terminator has attributes, operands, or results.
///
/// {0}: The name of the region.
const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
{
bool printTerminator = true;
if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
printTerminator = !term->getMutableAttrDict().empty() ||
term->getNumOperands() != 0 ||
term->getNumResults() != 0;
}
p.printRegion({0}, /*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/printTerminator);
}
)";
1470
1471 /// Generate the printer for the 'attr-dict' directive.
genAttrDictPrinter(OperationFormat & fmt,Operator & op,OpMethodBody & body,bool withKeyword)1472 static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
1473 OpMethodBody &body, bool withKeyword) {
1474 body << " p.printOptionalAttrDict" << (withKeyword ? "WithKeyword" : "")
1475 << "(getAttrs(), /*elidedAttrs=*/{";
1476 // Elide the variadic segment size attributes if necessary.
1477 if (!fmt.allOperands &&
1478 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
1479 body << "\"operand_segment_sizes\", ";
1480 if (!fmt.allResultTypes &&
1481 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
1482 body << "\"result_segment_sizes\", ";
1483 llvm::interleaveComma(
1484 fmt.usedAttributes, body,
1485 [&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; });
1486 body << "});\n";
1487 }
1488
1489 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a
1490 /// space should be emitted before this element. `lastWasPunctuation` is true if
1491 /// the previous element was a punctuation literal.
genLiteralPrinter(StringRef value,OpMethodBody & body,bool & shouldEmitSpace,bool & lastWasPunctuation)1492 static void genLiteralPrinter(StringRef value, OpMethodBody &body,
1493 bool &shouldEmitSpace, bool &lastWasPunctuation) {
1494 body << " p";
1495
1496 // Don't insert a space for certain punctuation.
1497 auto shouldPrintSpaceBeforeLiteral = [&] {
1498 if (value.size() != 1 && value != "->")
1499 return true;
1500 if (lastWasPunctuation)
1501 return !StringRef(">)}],").contains(value.front());
1502 return !StringRef("<>(){}[],").contains(value.front());
1503 };
1504 if (shouldEmitSpace && shouldPrintSpaceBeforeLiteral())
1505 body << " << ' '";
1506 body << " << \"" << value << "\";\n";
1507
1508 // Insert a space after certain literals.
1509 shouldEmitSpace =
1510 value.size() != 1 || !StringRef("<({[").contains(value.front());
1511 lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
1512 }
1513
1514 /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
1515 /// are set to false.
genSpacePrinter(bool value,OpMethodBody & body,bool & shouldEmitSpace,bool & lastWasPunctuation)1516 static void genSpacePrinter(bool value, OpMethodBody &body,
1517 bool &shouldEmitSpace, bool &lastWasPunctuation) {
1518 if (value) {
1519 body << " p << ' ';\n";
1520 lastWasPunctuation = false;
1521 }
1522 shouldEmitSpace = false;
1523 }
1524
1525 /// Generate the printer for a custom directive.
genCustomDirectivePrinter(CustomDirective * customDir,OpMethodBody & body)1526 static void genCustomDirectivePrinter(CustomDirective *customDir,
1527 OpMethodBody &body) {
1528 body << " print" << customDir->getName() << "(p, *this";
1529 for (Element ¶m : customDir->getArguments()) {
1530 body << ", ";
1531 if (auto *attr = dyn_cast<AttributeVariable>(¶m)) {
1532 body << attr->getVar()->name << "Attr()";
1533
1534 } else if (isa<AttrDictDirective>(¶m)) {
1535 // Enforce the const-ness since getMutableAttrDict() returns a reference
1536 // into the Operations `attr` member.
1537 body << "(const "
1538 "MutableDictionaryAttr&)getOperation()->getMutableAttrDict()";
1539
1540 } else if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
1541 body << operand->getVar()->name << "()";
1542
1543 } else if (auto *region = dyn_cast<RegionVariable>(¶m)) {
1544 body << region->getVar()->name << "()";
1545
1546 } else if (auto *successor = dyn_cast<SuccessorVariable>(¶m)) {
1547 body << successor->getVar()->name << "()";
1548
1549 } else if (auto *dir = dyn_cast<TypeRefDirective>(¶m)) {
1550 auto *typeOperand = dir->getOperand();
1551 auto *operand = dyn_cast<OperandVariable>(typeOperand);
1552 auto *var = operand ? operand->getVar()
1553 : cast<ResultVariable>(typeOperand)->getVar();
1554 if (var->isVariadic())
1555 body << var->name << "().getTypes()";
1556 else if (var->isOptional())
1557 body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
1558 else
1559 body << var->name << "().getType()";
1560 } else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
1561 auto *typeOperand = dir->getOperand();
1562 auto *operand = dyn_cast<OperandVariable>(typeOperand);
1563 auto *var = operand ? operand->getVar()
1564 : cast<ResultVariable>(typeOperand)->getVar();
1565 if (var->isVariadic())
1566 body << var->name << "().getTypes()";
1567 else if (var->isOptional())
1568 body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
1569 else
1570 body << var->name << "().getType()";
1571 } else {
1572 llvm_unreachable("unknown custom directive parameter");
1573 }
1574 }
1575
1576 body << ");\n";
1577 }
1578
1579 /// Generate the printer for a region with the given variable name.
genRegionPrinter(const Twine & regionName,OpMethodBody & body,bool hasImplicitTermTrait)1580 static void genRegionPrinter(const Twine ®ionName, OpMethodBody &body,
1581 bool hasImplicitTermTrait) {
1582 if (hasImplicitTermTrait)
1583 body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
1584 regionName);
1585 else
1586 body << " p.printRegion(" << regionName << ");\n";
1587 }
genVariadicRegionPrinter(const Twine & regionListName,OpMethodBody & body,bool hasImplicitTermTrait)1588 static void genVariadicRegionPrinter(const Twine ®ionListName,
1589 OpMethodBody &body,
1590 bool hasImplicitTermTrait) {
1591 body << " llvm::interleaveComma(" << regionListName
1592 << ", p, [&](::mlir::Region ®ion) {\n ";
1593 genRegionPrinter("region", body, hasImplicitTermTrait);
1594 body << " });\n";
1595 }
1596
1597 /// Generate the C++ for an operand to a (*-)type directive.
genTypeOperandPrinter(Element * arg,OpMethodBody & body)1598 static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
1599 if (isa<OperandsDirective>(arg))
1600 return body << "getOperation()->getOperandTypes()";
1601 if (isa<ResultsDirective>(arg))
1602 return body << "getOperation()->getResultTypes()";
1603 auto *operand = dyn_cast<OperandVariable>(arg);
1604 auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
1605 if (var->isVariadic())
1606 return body << var->name << "().getTypes()";
1607 if (var->isOptional())
1608 return body << llvm::formatv(
1609 "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
1610 "::llvm::ArrayRef<::mlir::Type>())",
1611 var->name);
1612 return body << "::llvm::ArrayRef<::mlir::Type>(" << var->name
1613 << "().getType())";
1614 }
1615
genElementPrinter(Element * element,OpMethodBody & body,Operator & op,bool & shouldEmitSpace,bool & lastWasPunctuation)1616 void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
1617 Operator &op, bool &shouldEmitSpace,
1618 bool &lastWasPunctuation) {
1619 if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
1620 return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
1621 lastWasPunctuation);
1622
1623 if (SpaceElement *space = dyn_cast<SpaceElement>(element))
1624 return genSpacePrinter(space->getValue(), body, shouldEmitSpace,
1625 lastWasPunctuation);
1626
1627 // Emit an optional group.
1628 if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
1629 // Emit the check for the presence of the anchor element.
1630 Element *anchor = optional->getAnchor();
1631 if (auto *operand = dyn_cast<OperandVariable>(anchor)) {
1632 const NamedTypeConstraint *var = operand->getVar();
1633 if (var->isOptional())
1634 body << " if (" << var->name << "()) {\n";
1635 else if (var->isVariadic())
1636 body << " if (!" << var->name << "().empty()) {\n";
1637 } else if (auto *region = dyn_cast<RegionVariable>(anchor)) {
1638 const NamedRegion *var = region->getVar();
1639 // TODO: Add a check for optional here when ODS supports it.
1640 body << " if (!" << var->name << "().empty()) {\n";
1641
1642 } else {
1643 body << " if (getAttr(\""
1644 << cast<AttributeVariable>(anchor)->getVar()->name << "\")) {\n";
1645 }
1646
1647 // If the anchor is a unit attribute, we don't need to print it. When
1648 // parsing, we will add this attribute if this group is present.
1649 auto elements = optional->getElements();
1650 Element *elidedAnchorElement = nullptr;
1651 auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
1652 if (anchorAttr && anchorAttr != &*elements.begin() &&
1653 anchorAttr->isUnitAttr()) {
1654 elidedAnchorElement = anchorAttr;
1655 }
1656
1657 // Emit each of the elements.
1658 for (Element &childElement : elements) {
1659 if (&childElement != elidedAnchorElement) {
1660 genElementPrinter(&childElement, body, op, shouldEmitSpace,
1661 lastWasPunctuation);
1662 }
1663 }
1664 body << " }\n";
1665 return;
1666 }
1667
1668 // Emit the attribute dictionary.
1669 if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
1670 genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword());
1671 lastWasPunctuation = false;
1672 return;
1673 }
1674
1675 // Optionally insert a space before the next element. The AttrDict printer
1676 // already adds a space as necessary.
1677 if (shouldEmitSpace || !lastWasPunctuation)
1678 body << " p << ' ';\n";
1679 lastWasPunctuation = false;
1680 shouldEmitSpace = true;
1681
1682 if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1683 const NamedAttribute *var = attr->getVar();
1684
1685 // If we are formatting as an enum, symbolize the attribute as a string.
1686 if (canFormatEnumAttr(var)) {
1687 Attribute baseAttr = var->attr.getBaseAttr();
1688 const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1689 body << " p << '\"' << " << enumAttr.getSymbolToStringFnName() << "("
1690 << (var->attr.isOptional() ? "*" : "") << var->name
1691 << "()) << '\"';\n";
1692 return;
1693 }
1694
1695 // If we are formatting as a symbol name, handle it as a symbol name.
1696 if (shouldFormatSymbolNameAttr(var)) {
1697 body << " p.printSymbolName(" << var->name << "Attr().getValue());\n";
1698 return;
1699 }
1700
1701 // Elide the attribute type if it is buildable.
1702 if (attr->getTypeBuilder())
1703 body << " p.printAttributeWithoutType(" << var->name << "Attr());\n";
1704 else
1705 body << " p.printAttribute(" << var->name << "Attr());\n";
1706 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
1707 if (operand->getVar()->isOptional()) {
1708 body << " if (::mlir::Value value = " << operand->getVar()->name
1709 << "())\n"
1710 << " p << value;\n";
1711 } else {
1712 body << " p << " << operand->getVar()->name << "();\n";
1713 }
1714 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
1715 const NamedRegion *var = region->getVar();
1716 if (var->isVariadic()) {
1717 genVariadicRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
1718 } else {
1719 genRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
1720 }
1721 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
1722 const NamedSuccessor *var = successor->getVar();
1723 if (var->isVariadic())
1724 body << " ::llvm::interleaveComma(" << var->name << "(), p);\n";
1725 else
1726 body << " p << " << var->name << "();\n";
1727 } else if (auto *dir = dyn_cast<CustomDirective>(element)) {
1728 genCustomDirectivePrinter(dir, body);
1729 } else if (isa<OperandsDirective>(element)) {
1730 body << " p << getOperation()->getOperands();\n";
1731 } else if (isa<RegionsDirective>(element)) {
1732 genVariadicRegionPrinter("getOperation()->getRegions()", body,
1733 hasImplicitTermTrait);
1734 } else if (isa<SuccessorsDirective>(element)) {
1735 body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), p);\n";
1736 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
1737 body << " p << ";
1738 genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
1739 } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
1740 body << " p << ";
1741 genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
1742 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
1743 body << " p.printFunctionalType(";
1744 genTypeOperandPrinter(dir->getInputs(), body) << ", ";
1745 genTypeOperandPrinter(dir->getResults(), body) << ");\n";
1746 } else {
1747 llvm_unreachable("unknown format element");
1748 }
1749 }
1750
genPrinter(Operator & op,OpClass & opClass)1751 void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
1752 auto *method =
1753 opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &p");
1754 auto &body = method->body();
1755
1756 // Emit the operation name, trimming the prefix if this is the standard
1757 // dialect.
1758 body << " p << \"";
1759 std::string opName = op.getOperationName();
1760 if (op.getDialectName() == "std")
1761 body << StringRef(opName).drop_front(4);
1762 else
1763 body << opName;
1764 body << "\";\n";
1765
1766 // Flags for if we should emit a space, and if the last element was
1767 // punctuation.
1768 bool shouldEmitSpace = true, lastWasPunctuation = false;
1769 for (auto &element : elements)
1770 genElementPrinter(element.get(), body, op, shouldEmitSpace,
1771 lastWasPunctuation);
1772 }
1773
1774 //===----------------------------------------------------------------------===//
1775 // FormatLexer
1776 //===----------------------------------------------------------------------===//
1777
1778 namespace {
1779 /// This class represents a specific token in the input format.
1780 class Token {
1781 public:
1782 enum Kind {
1783 // Markers.
1784 eof,
1785 error,
1786
1787 // Tokens with no info.
1788 l_paren,
1789 r_paren,
1790 caret,
1791 comma,
1792 equal,
1793 less,
1794 greater,
1795 question,
1796
1797 // Keywords.
1798 keyword_start,
1799 kw_attr_dict,
1800 kw_attr_dict_w_keyword,
1801 kw_custom,
1802 kw_functional_type,
1803 kw_operands,
1804 kw_regions,
1805 kw_results,
1806 kw_successors,
1807 kw_type,
1808 kw_type_ref,
1809 keyword_end,
1810
1811 // String valued tokens.
1812 identifier,
1813 literal,
1814 variable,
1815 };
Token(Kind kind,StringRef spelling)1816 Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
1817
1818 /// Return the bytes that make up this token.
getSpelling() const1819 StringRef getSpelling() const { return spelling; }
1820
1821 /// Return the kind of this token.
getKind() const1822 Kind getKind() const { return kind; }
1823
1824 /// Return a location for this token.
getLoc() const1825 llvm::SMLoc getLoc() const {
1826 return llvm::SMLoc::getFromPointer(spelling.data());
1827 }
1828
1829 /// Return if this token is a keyword.
isKeyword() const1830 bool isKeyword() const { return kind > keyword_start && kind < keyword_end; }
1831
1832 private:
1833 /// Discriminator that indicates the kind of token this is.
1834 Kind kind;
1835
1836 /// A reference to the entire token contents; this is always a pointer into
1837 /// a memory buffer owned by the source manager.
1838 StringRef spelling;
1839 };
1840
1841 /// This class implements a simple lexer for operation assembly format strings.
1842 class FormatLexer {
1843 public:
1844 FormatLexer(llvm::SourceMgr &mgr, Operator &op);
1845
1846 /// Lex the next token and return it.
1847 Token lexToken();
1848
1849 /// Emit an error to the lexer with the given location and message.
1850 Token emitError(llvm::SMLoc loc, const Twine &msg);
1851 Token emitError(const char *loc, const Twine &msg);
1852
1853 Token emitErrorAndNote(llvm::SMLoc loc, const Twine &msg, const Twine ¬e);
1854
1855 private:
formToken(Token::Kind kind,const char * tokStart)1856 Token formToken(Token::Kind kind, const char *tokStart) {
1857 return Token(kind, StringRef(tokStart, curPtr - tokStart));
1858 }
1859
1860 /// Return the next character in the stream.
1861 int getNextChar();
1862
1863 /// Lex an identifier, literal, or variable.
1864 Token lexIdentifier(const char *tokStart);
1865 Token lexLiteral(const char *tokStart);
1866 Token lexVariable(const char *tokStart);
1867
1868 llvm::SourceMgr &srcMgr;
1869 Operator &op;
1870 StringRef curBuffer;
1871 const char *curPtr;
1872 };
1873 } // end anonymous namespace
1874
FormatLexer(llvm::SourceMgr & mgr,Operator & op)1875 FormatLexer::FormatLexer(llvm::SourceMgr &mgr, Operator &op)
1876 : srcMgr(mgr), op(op) {
1877 curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
1878 curPtr = curBuffer.begin();
1879 }
1880
emitError(llvm::SMLoc loc,const Twine & msg)1881 Token FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) {
1882 srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
1883 llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
1884 "in custom assembly format for this operation");
1885 return formToken(Token::error, loc.getPointer());
1886 }
emitErrorAndNote(llvm::SMLoc loc,const Twine & msg,const Twine & note)1887 Token FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
1888 const Twine ¬e) {
1889 srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
1890 llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
1891 "in custom assembly format for this operation");
1892 srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
1893 return formToken(Token::error, loc.getPointer());
1894 }
emitError(const char * loc,const Twine & msg)1895 Token FormatLexer::emitError(const char *loc, const Twine &msg) {
1896 return emitError(llvm::SMLoc::getFromPointer(loc), msg);
1897 }
1898
getNextChar()1899 int FormatLexer::getNextChar() {
1900 char curChar = *curPtr++;
1901 switch (curChar) {
1902 default:
1903 return (unsigned char)curChar;
1904 case 0: {
1905 // A nul character in the stream is either the end of the current buffer or
1906 // a random nul in the file. Disambiguate that here.
1907 if (curPtr - 1 != curBuffer.end())
1908 return 0;
1909
1910 // Otherwise, return end of file.
1911 --curPtr;
1912 return EOF;
1913 }
1914 case '\n':
1915 case '\r':
1916 // Handle the newline character by ignoring it and incrementing the line
1917 // count. However, be careful about 'dos style' files with \n\r in them.
1918 // Only treat a \n\r or \r\n as a single line.
1919 if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
1920 ++curPtr;
1921 return '\n';
1922 }
1923 }
1924
lexToken()1925 Token FormatLexer::lexToken() {
1926 const char *tokStart = curPtr;
1927
1928 // This always consumes at least one character.
1929 int curChar = getNextChar();
1930 switch (curChar) {
1931 default:
1932 // Handle identifiers: [a-zA-Z_]
1933 if (isalpha(curChar) || curChar == '_')
1934 return lexIdentifier(tokStart);
1935
1936 // Unknown character, emit an error.
1937 return emitError(tokStart, "unexpected character");
1938 case EOF:
1939 // Return EOF denoting the end of lexing.
1940 return formToken(Token::eof, tokStart);
1941
1942 // Lex punctuation.
1943 case '^':
1944 return formToken(Token::caret, tokStart);
1945 case ',':
1946 return formToken(Token::comma, tokStart);
1947 case '=':
1948 return formToken(Token::equal, tokStart);
1949 case '<':
1950 return formToken(Token::less, tokStart);
1951 case '>':
1952 return formToken(Token::greater, tokStart);
1953 case '?':
1954 return formToken(Token::question, tokStart);
1955 case '(':
1956 return formToken(Token::l_paren, tokStart);
1957 case ')':
1958 return formToken(Token::r_paren, tokStart);
1959
1960 // Ignore whitespace characters.
1961 case 0:
1962 case ' ':
1963 case '\t':
1964 case '\n':
1965 return lexToken();
1966
1967 case '`':
1968 return lexLiteral(tokStart);
1969 case '$':
1970 return lexVariable(tokStart);
1971 }
1972 }
1973
lexLiteral(const char * tokStart)1974 Token FormatLexer::lexLiteral(const char *tokStart) {
1975 assert(curPtr[-1] == '`');
1976
1977 // Lex a literal surrounded by ``.
1978 while (const char curChar = *curPtr++) {
1979 if (curChar == '`')
1980 return formToken(Token::literal, tokStart);
1981 }
1982 return emitError(curPtr - 1, "unexpected end of file in literal");
1983 }
1984
lexVariable(const char * tokStart)1985 Token FormatLexer::lexVariable(const char *tokStart) {
1986 if (!isalpha(curPtr[0]) && curPtr[0] != '_')
1987 return emitError(curPtr - 1, "expected variable name");
1988
1989 // Otherwise, consume the rest of the characters.
1990 while (isalnum(*curPtr) || *curPtr == '_')
1991 ++curPtr;
1992 return formToken(Token::variable, tokStart);
1993 }
1994
lexIdentifier(const char * tokStart)1995 Token FormatLexer::lexIdentifier(const char *tokStart) {
1996 // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
1997 while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
1998 ++curPtr;
1999
2000 // Check to see if this identifier is a keyword.
2001 StringRef str(tokStart, curPtr - tokStart);
2002 Token::Kind kind =
2003 StringSwitch<Token::Kind>(str)
2004 .Case("attr-dict", Token::kw_attr_dict)
2005 .Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword)
2006 .Case("custom", Token::kw_custom)
2007 .Case("functional-type", Token::kw_functional_type)
2008 .Case("operands", Token::kw_operands)
2009 .Case("regions", Token::kw_regions)
2010 .Case("results", Token::kw_results)
2011 .Case("successors", Token::kw_successors)
2012 .Case("type", Token::kw_type)
2013 .Case("type_ref", Token::kw_type_ref)
2014 .Default(Token::identifier);
2015 return Token(kind, str);
2016 }
2017
2018 //===----------------------------------------------------------------------===//
2019 // FormatParser
2020 //===----------------------------------------------------------------------===//
2021
2022 /// Function to find an element within the given range that has the same name as
2023 /// 'name'.
2024 template <typename RangeT>
findArg(RangeT && range,StringRef name)2025 static auto findArg(RangeT &&range, StringRef name) {
2026 auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
2027 return it != range.end() ? &*it : nullptr;
2028 }
2029
2030 namespace {
2031 /// This class implements a parser for an instance of an operation assembly
2032 /// format.
2033 class FormatParser {
2034 public:
FormatParser(llvm::SourceMgr & mgr,OperationFormat & format,Operator & op)2035 FormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
2036 : lexer(mgr, op), curToken(lexer.lexToken()), fmt(format), op(op),
2037 seenOperandTypes(op.getNumOperands()),
2038 seenResultTypes(op.getNumResults()) {}
2039
2040 /// Parse the operation assembly format.
2041 LogicalResult parse();
2042
2043 private:
2044 /// This struct represents a type resolution instance. It includes a specific
2045 /// type as well as an optional transformer to apply to that type in order to
2046 /// properly resolve the type of a variable.
2047 struct TypeResolutionInstance {
2048 ConstArgument resolver;
2049 Optional<StringRef> transformer;
2050 };
2051
2052 /// An iterator over the elements of a format group.
2053 using ElementsIterT = llvm::pointee_iterator<
2054 std::vector<std::unique_ptr<Element>>::const_iterator>;
2055
2056 /// Verify the state of operation attributes within the format.
2057 LogicalResult verifyAttributes(llvm::SMLoc loc);
2058 /// Verify the attribute elements at the back of the given stack of iterators.
2059 LogicalResult verifyAttributes(
2060 llvm::SMLoc loc,
2061 SmallVectorImpl<std::pair<ElementsIterT, ElementsIterT>> &iteratorStack);
2062
2063 /// Verify the state of operation operands within the format.
2064 LogicalResult
2065 verifyOperands(llvm::SMLoc loc,
2066 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2067
2068 /// Verify the state of operation regions within the format.
2069 LogicalResult verifyRegions(llvm::SMLoc loc);
2070
2071 /// Verify the state of operation results within the format.
2072 LogicalResult
2073 verifyResults(llvm::SMLoc loc,
2074 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2075
2076 /// Verify the state of operation successors within the format.
2077 LogicalResult verifySuccessors(llvm::SMLoc loc);
2078
2079 /// Given the values of an `AllTypesMatch` trait, check for inferable type
2080 /// resolution.
2081 void handleAllTypesMatchConstraint(
2082 ArrayRef<StringRef> values,
2083 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2084 /// Check for inferable type resolution given all operands, and or results,
2085 /// have the same type. If 'includeResults' is true, the results also have the
2086 /// same type as all of the operands.
2087 void handleSameTypesConstraint(
2088 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2089 bool includeResults);
2090 /// Check for inferable type resolution based on another operand, result, or
2091 /// attribute.
2092 void handleTypesMatchConstraint(
2093 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2094 llvm::Record def);
2095
2096 /// Returns an argument or attribute with the given name that has been seen
2097 /// within the format.
2098 ConstArgument findSeenArg(StringRef name);
2099
2100 /// Parse a specific element.
2101 LogicalResult parseElement(std::unique_ptr<Element> &element,
2102 bool isTopLevel);
2103 LogicalResult parseVariable(std::unique_ptr<Element> &element,
2104 bool isTopLevel);
2105 LogicalResult parseDirective(std::unique_ptr<Element> &element,
2106 bool isTopLevel);
2107 LogicalResult parseLiteral(std::unique_ptr<Element> &element);
2108 LogicalResult parseOptional(std::unique_ptr<Element> &element,
2109 bool isTopLevel);
2110 LogicalResult parseOptionalChildElement(
2111 std::vector<std::unique_ptr<Element>> &childElements,
2112 SmallPtrSetImpl<const NamedTypeConstraint *> &seenVariables,
2113 Optional<unsigned> &anchorIdx);
2114
2115 /// Parse the various different directives.
2116 LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element,
2117 llvm::SMLoc loc, bool isTopLevel,
2118 bool withKeyword);
2119 LogicalResult parseCustomDirective(std::unique_ptr<Element> &element,
2120 llvm::SMLoc loc, bool isTopLevel);
2121 LogicalResult parseCustomDirectiveParameter(
2122 std::vector<std::unique_ptr<Element>> ¶meters);
2123 LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
2124 Token tok, bool isTopLevel);
2125 LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
2126 llvm::SMLoc loc, bool isTopLevel);
2127 LogicalResult parseRegionsDirective(std::unique_ptr<Element> &element,
2128 llvm::SMLoc loc, bool isTopLevel);
2129 LogicalResult parseResultsDirective(std::unique_ptr<Element> &element,
2130 llvm::SMLoc loc, bool isTopLevel);
2131 LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
2132 llvm::SMLoc loc, bool isTopLevel);
2133 LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
2134 bool isTopLevel, bool isTypeRef = false);
2135 LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
2136 bool isTypeRef = false);
2137
2138 //===--------------------------------------------------------------------===//
2139 // Lexer Utilities
2140 //===--------------------------------------------------------------------===//
2141
2142 /// Advance the current lexer onto the next token.
consumeToken()2143 void consumeToken() {
2144 assert(curToken.getKind() != Token::eof &&
2145 curToken.getKind() != Token::error &&
2146 "shouldn't advance past EOF or errors");
2147 curToken = lexer.lexToken();
2148 }
parseToken(Token::Kind kind,const Twine & msg)2149 LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
2150 if (curToken.getKind() != kind)
2151 return emitError(curToken.getLoc(), msg);
2152 consumeToken();
2153 return ::mlir::success();
2154 }
emitError(llvm::SMLoc loc,const Twine & msg)2155 LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
2156 lexer.emitError(loc, msg);
2157 return ::mlir::failure();
2158 }
emitErrorAndNote(llvm::SMLoc loc,const Twine & msg,const Twine & note)2159 LogicalResult emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
2160 const Twine ¬e) {
2161 lexer.emitErrorAndNote(loc, msg, note);
2162 return ::mlir::failure();
2163 }
2164
2165 //===--------------------------------------------------------------------===//
2166 // Fields
2167 //===--------------------------------------------------------------------===//
2168
2169 FormatLexer lexer;
2170 Token curToken;
2171 OperationFormat &fmt;
2172 Operator &op;
2173
2174 // The following are various bits of format state used for verification
2175 // during parsing.
2176 bool hasAttrDict = false;
2177 bool hasAllRegions = false, hasAllSuccessors = false;
2178 llvm::SmallBitVector seenOperandTypes, seenResultTypes;
2179 llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
2180 llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
2181 llvm::DenseSet<const NamedRegion *> seenRegions;
2182 llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
2183 llvm::DenseSet<const NamedTypeConstraint *> optionalVariables;
2184 };
2185 } // end anonymous namespace
2186
parse()2187 LogicalResult FormatParser::parse() {
2188 llvm::SMLoc loc = curToken.getLoc();
2189
2190 // Parse each of the format elements into the main format.
2191 while (curToken.getKind() != Token::eof) {
2192 std::unique_ptr<Element> element;
2193 if (failed(parseElement(element, /*isTopLevel=*/true)))
2194 return ::mlir::failure();
2195 fmt.elements.push_back(std::move(element));
2196 }
2197
2198 // Check that the attribute dictionary is in the format.
2199 if (!hasAttrDict)
2200 return emitError(loc, "'attr-dict' directive not found in "
2201 "custom assembly format");
2202
2203 // Check for any type traits that we can use for inferring types.
2204 llvm::StringMap<TypeResolutionInstance> variableTyResolver;
2205 for (const OpTrait &trait : op.getTraits()) {
2206 const llvm::Record &def = trait.getDef();
2207 if (def.isSubClassOf("AllTypesMatch")) {
2208 handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
2209 variableTyResolver);
2210 } else if (def.getName() == "SameTypeOperands") {
2211 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
2212 } else if (def.getName() == "SameOperandsAndResultType") {
2213 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
2214 } else if (def.isSubClassOf("TypesMatchWith")) {
2215 handleTypesMatchConstraint(variableTyResolver, def);
2216 }
2217 }
2218
2219 // Verify the state of the various operation components.
2220 if (failed(verifyAttributes(loc)) ||
2221 failed(verifyResults(loc, variableTyResolver)) ||
2222 failed(verifyOperands(loc, variableTyResolver)) ||
2223 failed(verifyRegions(loc)) || failed(verifySuccessors(loc)))
2224 return ::mlir::failure();
2225
2226 // Collect the set of used attributes in the format.
2227 fmt.usedAttributes = seenAttrs.takeVector();
2228 return ::mlir::success();
2229 }
2230
verifyAttributes(llvm::SMLoc loc)2231 LogicalResult FormatParser::verifyAttributes(llvm::SMLoc loc) {
2232 // Check that there are no `:` literals after an attribute without a constant
2233 // type. The attribute grammar contains an optional trailing colon type, which
2234 // can lead to unexpected and generally unintended behavior. Given that, it is
2235 // better to just error out here instead.
2236 using ElementsIterT = llvm::pointee_iterator<
2237 std::vector<std::unique_ptr<Element>>::const_iterator>;
2238 SmallVector<std::pair<ElementsIterT, ElementsIterT>, 1> iteratorStack;
2239 iteratorStack.emplace_back(fmt.elements.begin(), fmt.elements.end());
2240 while (!iteratorStack.empty())
2241 if (failed(verifyAttributes(loc, iteratorStack)))
2242 return ::mlir::failure();
2243 return ::mlir::success();
2244 }
2245 /// Verify the attribute elements at the back of the given stack of iterators.
verifyAttributes(llvm::SMLoc loc,SmallVectorImpl<std::pair<ElementsIterT,ElementsIterT>> & iteratorStack)2246 LogicalResult FormatParser::verifyAttributes(
2247 llvm::SMLoc loc,
2248 SmallVectorImpl<std::pair<ElementsIterT, ElementsIterT>> &iteratorStack) {
2249 auto &stackIt = iteratorStack.back();
2250 ElementsIterT &it = stackIt.first, e = stackIt.second;
2251 while (it != e) {
2252 Element *element = &*(it++);
2253
2254 // Traverse into optional groups.
2255 if (auto *optional = dyn_cast<OptionalElement>(element)) {
2256 auto elements = optional->getElements();
2257 iteratorStack.emplace_back(elements.begin(), elements.end());
2258 return ::mlir::success();
2259 }
2260
2261 // We are checking for an attribute element followed by a `:`, so there is
2262 // no need to check the end.
2263 if (it == e && iteratorStack.size() == 1)
2264 break;
2265
2266 // Check for an attribute with a constant type builder, followed by a `:`.
2267 auto *prevAttr = dyn_cast<AttributeVariable>(element);
2268 if (!prevAttr || prevAttr->getTypeBuilder())
2269 continue;
2270
2271 // Check the next iterator within the stack for literal elements.
2272 for (auto &nextItPair : iteratorStack) {
2273 ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second;
2274 for (; nextIt != nextE; ++nextIt) {
2275 // Skip any trailing spaces, attribute dictionaries, or optional groups.
2276 if (isa<SpaceElement>(*nextIt) || isa<AttrDictDirective>(*nextIt) ||
2277 isa<OptionalElement>(*nextIt))
2278 continue;
2279
2280 // We are only interested in `:` literals.
2281 auto *literal = dyn_cast<LiteralElement>(&*nextIt);
2282 if (!literal || literal->getLiteral() != ":")
2283 break;
2284
2285 // TODO: Use the location of the literal element itself.
2286 return emitError(
2287 loc, llvm::formatv("format ambiguity caused by `:` literal found "
2288 "after attribute `{0}` which does not have "
2289 "a buildable type",
2290 prevAttr->getVar()->name));
2291 }
2292 }
2293 }
2294 iteratorStack.pop_back();
2295 return ::mlir::success();
2296 }
2297
verifyOperands(llvm::SMLoc loc,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2298 LogicalResult FormatParser::verifyOperands(
2299 llvm::SMLoc loc,
2300 llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2301 // Check that all of the operands are within the format, and their types can
2302 // be inferred.
2303 auto &buildableTypes = fmt.buildableTypes;
2304 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
2305 NamedTypeConstraint &operand = op.getOperand(i);
2306
2307 // Check that the operand itself is in the format.
2308 if (!fmt.allOperands && !seenOperands.count(&operand)) {
2309 return emitErrorAndNote(loc,
2310 "operand #" + Twine(i) + ", named '" +
2311 operand.name + "', not found",
2312 "suggest adding a '$" + operand.name +
2313 "' directive to the custom assembly format");
2314 }
2315
2316 // Check that the operand type is in the format, or that it can be inferred.
2317 if (fmt.allOperandTypes || seenOperandTypes.test(i))
2318 continue;
2319
2320 // Check to see if we can infer this type from another variable.
2321 auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
2322 if (varResolverIt != variableTyResolver.end()) {
2323 TypeResolutionInstance &resolver = varResolverIt->second;
2324 fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer);
2325 continue;
2326 }
2327
2328 // Similarly to results, allow a custom builder for resolving the type if
2329 // we aren't using the 'operands' directive.
2330 Optional<StringRef> builder = operand.constraint.getBuilderCall();
2331 if (!builder || (fmt.allOperands && operand.isVariableLength())) {
2332 return emitErrorAndNote(
2333 loc,
2334 "type of operand #" + Twine(i) + ", named '" + operand.name +
2335 "', is not buildable and a buildable type cannot be inferred",
2336 "suggest adding a type constraint to the operation or adding a "
2337 "'type($" +
2338 operand.name + ")' directive to the " + "custom assembly format");
2339 }
2340 auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2341 fmt.operandTypes[i].setBuilderIdx(it.first->second);
2342 }
2343 return ::mlir::success();
2344 }
2345
verifyRegions(llvm::SMLoc loc)2346 LogicalResult FormatParser::verifyRegions(llvm::SMLoc loc) {
2347 // Check that all of the regions are within the format.
2348 if (hasAllRegions)
2349 return ::mlir::success();
2350
2351 for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
2352 const NamedRegion ®ion = op.getRegion(i);
2353 if (!seenRegions.count(®ion)) {
2354 return emitErrorAndNote(loc,
2355 "region #" + Twine(i) + ", named '" +
2356 region.name + "', not found",
2357 "suggest adding a '$" + region.name +
2358 "' directive to the custom assembly format");
2359 }
2360 }
2361 return ::mlir::success();
2362 }
2363
verifyResults(llvm::SMLoc loc,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2364 LogicalResult FormatParser::verifyResults(
2365 llvm::SMLoc loc,
2366 llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2367 // If we format all of the types together, there is nothing to check.
2368 if (fmt.allResultTypes)
2369 return ::mlir::success();
2370
2371 // Check that all of the result types can be inferred.
2372 auto &buildableTypes = fmt.buildableTypes;
2373 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
2374 if (seenResultTypes.test(i))
2375 continue;
2376
2377 // Check to see if we can infer this type from another variable.
2378 auto varResolverIt = variableTyResolver.find(op.getResultName(i));
2379 if (varResolverIt != variableTyResolver.end()) {
2380 TypeResolutionInstance resolver = varResolverIt->second;
2381 fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
2382 continue;
2383 }
2384
2385 // If the result is not variable length, allow for the case where the type
2386 // has a builder that we can use.
2387 NamedTypeConstraint &result = op.getResult(i);
2388 Optional<StringRef> builder = result.constraint.getBuilderCall();
2389 if (!builder || result.isVariableLength()) {
2390 return emitErrorAndNote(
2391 loc,
2392 "type of result #" + Twine(i) + ", named '" + result.name +
2393 "', is not buildable and a buildable type cannot be inferred",
2394 "suggest adding a type constraint to the operation or adding a "
2395 "'type($" +
2396 result.name + ")' directive to the " + "custom assembly format");
2397 }
2398 // Note in the format that this result uses the custom builder.
2399 auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2400 fmt.resultTypes[i].setBuilderIdx(it.first->second);
2401 }
2402 return ::mlir::success();
2403 }
2404
verifySuccessors(llvm::SMLoc loc)2405 LogicalResult FormatParser::verifySuccessors(llvm::SMLoc loc) {
2406 // Check that all of the successors are within the format.
2407 if (hasAllSuccessors)
2408 return ::mlir::success();
2409
2410 for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
2411 const NamedSuccessor &successor = op.getSuccessor(i);
2412 if (!seenSuccessors.count(&successor)) {
2413 return emitErrorAndNote(loc,
2414 "successor #" + Twine(i) + ", named '" +
2415 successor.name + "', not found",
2416 "suggest adding a '$" + successor.name +
2417 "' directive to the custom assembly format");
2418 }
2419 }
2420 return ::mlir::success();
2421 }
2422
handleAllTypesMatchConstraint(ArrayRef<StringRef> values,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2423 void FormatParser::handleAllTypesMatchConstraint(
2424 ArrayRef<StringRef> values,
2425 llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2426 for (unsigned i = 0, e = values.size(); i != e; ++i) {
2427 // Check to see if this value matches a resolved operand or result type.
2428 ConstArgument arg = findSeenArg(values[i]);
2429 if (!arg)
2430 continue;
2431
2432 // Mark this value as the type resolver for the other variables.
2433 for (unsigned j = 0; j != i; ++j)
2434 variableTyResolver[values[j]] = {arg, llvm::None};
2435 for (unsigned j = i + 1; j != e; ++j)
2436 variableTyResolver[values[j]] = {arg, llvm::None};
2437 }
2438 }
2439
handleSameTypesConstraint(llvm::StringMap<TypeResolutionInstance> & variableTyResolver,bool includeResults)2440 void FormatParser::handleSameTypesConstraint(
2441 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2442 bool includeResults) {
2443 const NamedTypeConstraint *resolver = nullptr;
2444 int resolvedIt = -1;
2445
2446 // Check to see if there is an operand or result to use for the resolution.
2447 if ((resolvedIt = seenOperandTypes.find_first()) != -1)
2448 resolver = &op.getOperand(resolvedIt);
2449 else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1)
2450 resolver = &op.getResult(resolvedIt);
2451 else
2452 return;
2453
2454 // Set the resolvers for each operand and result.
2455 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
2456 if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty())
2457 variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None};
2458 if (includeResults) {
2459 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
2460 if (!seenResultTypes.test(i) && !op.getResultName(i).empty())
2461 variableTyResolver[op.getResultName(i)] = {resolver, llvm::None};
2462 }
2463 }
2464
handleTypesMatchConstraint(llvm::StringMap<TypeResolutionInstance> & variableTyResolver,llvm::Record def)2465 void FormatParser::handleTypesMatchConstraint(
2466 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2467 llvm::Record def) {
2468 StringRef lhsName = def.getValueAsString("lhs");
2469 StringRef rhsName = def.getValueAsString("rhs");
2470 StringRef transformer = def.getValueAsString("transformer");
2471 if (ConstArgument arg = findSeenArg(lhsName))
2472 variableTyResolver[rhsName] = {arg, transformer};
2473 }
2474
findSeenArg(StringRef name)2475 ConstArgument FormatParser::findSeenArg(StringRef name) {
2476 if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
2477 return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
2478 if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
2479 return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
2480 if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
2481 return seenAttrs.count(attr) ? attr : nullptr;
2482 return nullptr;
2483 }
2484
parseElement(std::unique_ptr<Element> & element,bool isTopLevel)2485 LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
2486 bool isTopLevel) {
2487 // Directives.
2488 if (curToken.isKeyword())
2489 return parseDirective(element, isTopLevel);
2490 // Literals.
2491 if (curToken.getKind() == Token::literal)
2492 return parseLiteral(element);
2493 // Optionals.
2494 if (curToken.getKind() == Token::l_paren)
2495 return parseOptional(element, isTopLevel);
2496 // Variables.
2497 if (curToken.getKind() == Token::variable)
2498 return parseVariable(element, isTopLevel);
2499 return emitError(curToken.getLoc(),
2500 "expected directive, literal, variable, or optional group");
2501 }
2502
parseVariable(std::unique_ptr<Element> & element,bool isTopLevel)2503 LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
2504 bool isTopLevel) {
2505 Token varTok = curToken;
2506 consumeToken();
2507
2508 StringRef name = varTok.getSpelling().drop_front();
2509 llvm::SMLoc loc = varTok.getLoc();
2510
2511 // Check that the parsed argument is something actually registered on the
2512 // op.
2513 /// Attributes
2514 if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
2515 if (isTopLevel && !seenAttrs.insert(attr))
2516 return emitError(loc, "attribute '" + name + "' is already bound");
2517 element = std::make_unique<AttributeVariable>(attr);
2518 return ::mlir::success();
2519 }
2520 /// Operands
2521 if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
2522 if (isTopLevel) {
2523 if (fmt.allOperands || !seenOperands.insert(operand).second)
2524 return emitError(loc, "operand '" + name + "' is already bound");
2525 }
2526 element = std::make_unique<OperandVariable>(operand);
2527 return ::mlir::success();
2528 }
2529 /// Regions
2530 if (const NamedRegion *region = findArg(op.getRegions(), name)) {
2531 if (!isTopLevel)
2532 return emitError(loc, "regions can only be used at the top level");
2533 if (hasAllRegions || !seenRegions.insert(region).second)
2534 return emitError(loc, "region '" + name + "' is already bound");
2535 element = std::make_unique<RegionVariable>(region);
2536 return ::mlir::success();
2537 }
2538 /// Results.
2539 if (const auto *result = findArg(op.getResults(), name)) {
2540 if (isTopLevel)
2541 return emitError(loc, "results can not be used at the top level");
2542 element = std::make_unique<ResultVariable>(result);
2543 return ::mlir::success();
2544 }
2545 /// Successors.
2546 if (const auto *successor = findArg(op.getSuccessors(), name)) {
2547 if (!isTopLevel)
2548 return emitError(loc, "successors can only be used at the top level");
2549 if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
2550 return emitError(loc, "successor '" + name + "' is already bound");
2551 element = std::make_unique<SuccessorVariable>(successor);
2552 return ::mlir::success();
2553 }
2554 return emitError(loc, "expected variable to refer to an argument, region, "
2555 "result, or successor");
2556 }
2557
parseDirective(std::unique_ptr<Element> & element,bool isTopLevel)2558 LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
2559 bool isTopLevel) {
2560 Token dirTok = curToken;
2561 consumeToken();
2562
2563 switch (dirTok.getKind()) {
2564 case Token::kw_attr_dict:
2565 return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel,
2566 /*withKeyword=*/false);
2567 case Token::kw_attr_dict_w_keyword:
2568 return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel,
2569 /*withKeyword=*/true);
2570 case Token::kw_custom:
2571 return parseCustomDirective(element, dirTok.getLoc(), isTopLevel);
2572 case Token::kw_functional_type:
2573 return parseFunctionalTypeDirective(element, dirTok, isTopLevel);
2574 case Token::kw_operands:
2575 return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel);
2576 case Token::kw_regions:
2577 return parseRegionsDirective(element, dirTok.getLoc(), isTopLevel);
2578 case Token::kw_results:
2579 return parseResultsDirective(element, dirTok.getLoc(), isTopLevel);
2580 case Token::kw_successors:
2581 return parseSuccessorsDirective(element, dirTok.getLoc(), isTopLevel);
2582 case Token::kw_type_ref:
2583 return parseTypeDirective(element, dirTok, isTopLevel, /*isTypeRef=*/true);
2584 case Token::kw_type:
2585 return parseTypeDirective(element, dirTok, isTopLevel);
2586
2587 default:
2588 llvm_unreachable("unknown directive token");
2589 }
2590 }
2591
parseLiteral(std::unique_ptr<Element> & element)2592 LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element) {
2593 Token literalTok = curToken;
2594 consumeToken();
2595
2596 StringRef value = literalTok.getSpelling().drop_front().drop_back();
2597
2598 // The parsed literal is a space element (`` or ` `).
2599 if (value.empty() || (value.size() == 1 && value.front() == ' ')) {
2600 element = std::make_unique<SpaceElement>(!value.empty());
2601 return ::mlir::success();
2602 }
2603
2604 // Check that the parsed literal is valid.
2605 if (!LiteralElement::isValidLiteral(value))
2606 return emitError(literalTok.getLoc(), "expected valid literal");
2607
2608 element = std::make_unique<LiteralElement>(value);
2609 return ::mlir::success();
2610 }
2611
parseOptional(std::unique_ptr<Element> & element,bool isTopLevel)2612 LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
2613 bool isTopLevel) {
2614 llvm::SMLoc curLoc = curToken.getLoc();
2615 if (!isTopLevel)
2616 return emitError(curLoc, "optional groups can only be used as top-level "
2617 "elements");
2618 consumeToken();
2619
2620 // Parse the child elements for this optional group.
2621 std::vector<std::unique_ptr<Element>> elements;
2622 SmallPtrSet<const NamedTypeConstraint *, 8> seenVariables;
2623 Optional<unsigned> anchorIdx;
2624 do {
2625 if (failed(parseOptionalChildElement(elements, seenVariables, anchorIdx)))
2626 return ::mlir::failure();
2627 } while (curToken.getKind() != Token::r_paren);
2628 consumeToken();
2629 if (failed(parseToken(Token::question, "expected '?' after optional group")))
2630 return ::mlir::failure();
2631
2632 // The optional group is required to have an anchor.
2633 if (!anchorIdx)
2634 return emitError(curLoc, "optional group specified no anchor element");
2635
2636 // The first parsable element of the group must be able to be parsed in an
2637 // optional fashion.
2638 auto parseBegin = llvm::find_if_not(
2639 elements, [](auto &element) { return isa<SpaceElement>(element.get()); });
2640 Element *firstElement = parseBegin->get();
2641 if (!isa<AttributeVariable>(firstElement) &&
2642 !isa<LiteralElement>(firstElement) &&
2643 !isa<OperandVariable>(firstElement) && !isa<RegionVariable>(firstElement))
2644 return emitError(curLoc,
2645 "first parsable element of an operand group must be "
2646 "an attribute, literal, operand, or region");
2647
2648 // After parsing all of the elements, ensure that all type directives refer
2649 // only to elements within the group.
2650 auto checkTypeOperand = [&](Element *typeEle) {
2651 auto *opVar = dyn_cast<OperandVariable>(typeEle);
2652 const NamedTypeConstraint *var = opVar ? opVar->getVar() : nullptr;
2653 if (!seenVariables.count(var))
2654 return emitError(curLoc, "type directive can only refer to variables "
2655 "within the optional group");
2656 return ::mlir::success();
2657 };
2658 for (auto &ele : elements) {
2659 if (auto *typeEle = dyn_cast<TypeRefDirective>(ele.get())) {
2660 if (failed(checkTypeOperand(typeEle->getOperand())))
2661 return failure();
2662 } else if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
2663 if (failed(checkTypeOperand(typeEle->getOperand())))
2664 return ::mlir::failure();
2665 } else if (auto *typeEle = dyn_cast<FunctionalTypeDirective>(ele.get())) {
2666 if (failed(checkTypeOperand(typeEle->getInputs())) ||
2667 failed(checkTypeOperand(typeEle->getResults())))
2668 return ::mlir::failure();
2669 }
2670 }
2671
2672 optionalVariables.insert(seenVariables.begin(), seenVariables.end());
2673 auto parseStart = parseBegin - elements.begin();
2674 element = std::make_unique<OptionalElement>(std::move(elements), *anchorIdx,
2675 parseStart);
2676 return ::mlir::success();
2677 }
2678
parseOptionalChildElement(std::vector<std::unique_ptr<Element>> & childElements,SmallPtrSetImpl<const NamedTypeConstraint * > & seenVariables,Optional<unsigned> & anchorIdx)2679 LogicalResult FormatParser::parseOptionalChildElement(
2680 std::vector<std::unique_ptr<Element>> &childElements,
2681 SmallPtrSetImpl<const NamedTypeConstraint *> &seenVariables,
2682 Optional<unsigned> &anchorIdx) {
2683 llvm::SMLoc childLoc = curToken.getLoc();
2684 childElements.push_back({});
2685 if (failed(parseElement(childElements.back(), /*isTopLevel=*/true)))
2686 return ::mlir::failure();
2687
2688 // Check to see if this element is the anchor of the optional group.
2689 bool isAnchor = curToken.getKind() == Token::caret;
2690 if (isAnchor) {
2691 if (anchorIdx)
2692 return emitError(childLoc, "only one element can be marked as the anchor "
2693 "of an optional group");
2694 anchorIdx = childElements.size() - 1;
2695 consumeToken();
2696 }
2697
2698 return TypeSwitch<Element *, LogicalResult>(childElements.back().get())
2699 // All attributes can be within the optional group, but only optional
2700 // attributes can be the anchor.
2701 .Case([&](AttributeVariable *attrEle) {
2702 if (isAnchor && !attrEle->getVar()->attr.isOptional())
2703 return emitError(childLoc, "only optional attributes can be used to "
2704 "anchor an optional group");
2705 return ::mlir::success();
2706 })
2707 // Only optional-like(i.e. variadic) operands can be within an optional
2708 // group.
2709 .Case<OperandVariable>([&](OperandVariable *ele) {
2710 if (!ele->getVar()->isVariableLength())
2711 return emitError(childLoc, "only variable length operands can be "
2712 "used within an optional group");
2713 seenVariables.insert(ele->getVar());
2714 return ::mlir::success();
2715 })
2716 .Case<RegionVariable>([&](RegionVariable *) {
2717 // TODO: When ODS has proper support for marking "optional" regions, add
2718 // a check here.
2719 return ::mlir::success();
2720 })
2721 // Literals, spaces, custom directives, and type directives may be used,
2722 // but they can't anchor the group.
2723 .Case<LiteralElement, SpaceElement, CustomDirective,
2724 FunctionalTypeDirective, OptionalElement, TypeRefDirective,
2725 TypeDirective>([&](Element *) {
2726 if (isAnchor)
2727 return emitError(childLoc, "only variables can be used to anchor "
2728 "an optional group");
2729 return ::mlir::success();
2730 })
2731 .Default([&](Element *) {
2732 return emitError(childLoc, "only literals, types, and variables can be "
2733 "used within an optional group");
2734 });
2735 }
2736
2737 LogicalResult
parseAttrDictDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel,bool withKeyword)2738 FormatParser::parseAttrDictDirective(std::unique_ptr<Element> &element,
2739 llvm::SMLoc loc, bool isTopLevel,
2740 bool withKeyword) {
2741 if (!isTopLevel)
2742 return emitError(loc, "'attr-dict' directive can only be used as a "
2743 "top-level directive");
2744 if (hasAttrDict)
2745 return emitError(loc, "'attr-dict' directive has already been seen");
2746
2747 hasAttrDict = true;
2748 element = std::make_unique<AttrDictDirective>(withKeyword);
2749 return ::mlir::success();
2750 }
2751
2752 LogicalResult
parseCustomDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)2753 FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
2754 llvm::SMLoc loc, bool isTopLevel) {
2755 llvm::SMLoc curLoc = curToken.getLoc();
2756
2757 // Parse the custom directive name.
2758 if (failed(
2759 parseToken(Token::less, "expected '<' before custom directive name")))
2760 return ::mlir::failure();
2761
2762 Token nameTok = curToken;
2763 if (failed(parseToken(Token::identifier,
2764 "expected custom directive name identifier")) ||
2765 failed(parseToken(Token::greater,
2766 "expected '>' after custom directive name")) ||
2767 failed(parseToken(Token::l_paren,
2768 "expected '(' before custom directive parameters")))
2769 return ::mlir::failure();
2770
2771 // Parse the child elements for this optional group.=
2772 std::vector<std::unique_ptr<Element>> elements;
2773 do {
2774 if (failed(parseCustomDirectiveParameter(elements)))
2775 return ::mlir::failure();
2776 if (curToken.getKind() != Token::comma)
2777 break;
2778 consumeToken();
2779 } while (true);
2780
2781 if (failed(parseToken(Token::r_paren,
2782 "expected ')' after custom directive parameters")))
2783 return ::mlir::failure();
2784
2785 // After parsing all of the elements, ensure that all type directives refer
2786 // only to variables.
2787 for (auto &ele : elements) {
2788 if (auto *typeEle = dyn_cast<TypeRefDirective>(ele.get())) {
2789 if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
2790 return emitError(curLoc,
2791 "type_ref directives within a custom directive "
2792 "may only refer to variables");
2793 }
2794 }
2795 if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
2796 if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
2797 return emitError(curLoc, "type directives within a custom directive "
2798 "may only refer to variables");
2799 }
2800 }
2801 }
2802
2803 element = std::make_unique<CustomDirective>(nameTok.getSpelling(),
2804 std::move(elements));
2805 return ::mlir::success();
2806 }
2807
parseCustomDirectiveParameter(std::vector<std::unique_ptr<Element>> & parameters)2808 LogicalResult FormatParser::parseCustomDirectiveParameter(
2809 std::vector<std::unique_ptr<Element>> ¶meters) {
2810 llvm::SMLoc childLoc = curToken.getLoc();
2811 parameters.push_back({});
2812 if (failed(parseElement(parameters.back(), /*isTopLevel=*/true)))
2813 return ::mlir::failure();
2814
2815 // Verify that the element can be placed within a custom directive.
2816 if (!isa<TypeRefDirective, TypeDirective, AttrDictDirective,
2817 AttributeVariable, OperandVariable, RegionVariable,
2818 SuccessorVariable>(parameters.back().get())) {
2819 return emitError(childLoc, "only variables and types may be used as "
2820 "parameters to a custom directive");
2821 }
2822 return ::mlir::success();
2823 }
2824
2825 LogicalResult
parseFunctionalTypeDirective(std::unique_ptr<Element> & element,Token tok,bool isTopLevel)2826 FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
2827 Token tok, bool isTopLevel) {
2828 llvm::SMLoc loc = tok.getLoc();
2829 if (!isTopLevel)
2830 return emitError(
2831 loc, "'functional-type' is only valid as a top-level directive");
2832
2833 // Parse the main operand.
2834 std::unique_ptr<Element> inputs, results;
2835 if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
2836 failed(parseTypeDirectiveOperand(inputs)) ||
2837 failed(parseToken(Token::comma, "expected ',' after inputs argument")) ||
2838 failed(parseTypeDirectiveOperand(results)) ||
2839 failed(parseToken(Token::r_paren, "expected ')' after argument list")))
2840 return ::mlir::failure();
2841 element = std::make_unique<FunctionalTypeDirective>(std::move(inputs),
2842 std::move(results));
2843 return ::mlir::success();
2844 }
2845
2846 LogicalResult
parseOperandsDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)2847 FormatParser::parseOperandsDirective(std::unique_ptr<Element> &element,
2848 llvm::SMLoc loc, bool isTopLevel) {
2849 if (isTopLevel) {
2850 if (fmt.allOperands || !seenOperands.empty())
2851 return emitError(loc, "'operands' directive creates overlap in format");
2852 fmt.allOperands = true;
2853 }
2854 element = std::make_unique<OperandsDirective>();
2855 return ::mlir::success();
2856 }
2857
2858 LogicalResult
parseRegionsDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)2859 FormatParser::parseRegionsDirective(std::unique_ptr<Element> &element,
2860 llvm::SMLoc loc, bool isTopLevel) {
2861 if (!isTopLevel)
2862 return emitError(loc, "'regions' is only valid as a top-level directive");
2863 if (hasAllRegions || !seenRegions.empty())
2864 return emitError(loc, "'regions' directive creates overlap in format");
2865 hasAllRegions = true;
2866 element = std::make_unique<RegionsDirective>();
2867 return ::mlir::success();
2868 }
2869
2870 LogicalResult
parseResultsDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)2871 FormatParser::parseResultsDirective(std::unique_ptr<Element> &element,
2872 llvm::SMLoc loc, bool isTopLevel) {
2873 if (isTopLevel)
2874 return emitError(loc, "'results' directive can not be used as a "
2875 "top-level directive");
2876 element = std::make_unique<ResultsDirective>();
2877 return ::mlir::success();
2878 }
2879
2880 LogicalResult
parseSuccessorsDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)2881 FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
2882 llvm::SMLoc loc, bool isTopLevel) {
2883 if (!isTopLevel)
2884 return emitError(loc,
2885 "'successors' is only valid as a top-level directive");
2886 if (hasAllSuccessors || !seenSuccessors.empty())
2887 return emitError(loc, "'successors' directive creates overlap in format");
2888 hasAllSuccessors = true;
2889 element = std::make_unique<SuccessorsDirective>();
2890 return ::mlir::success();
2891 }
2892
2893 LogicalResult
parseTypeDirective(std::unique_ptr<Element> & element,Token tok,bool isTopLevel,bool isTypeRef)2894 FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
2895 bool isTopLevel, bool isTypeRef) {
2896 llvm::SMLoc loc = tok.getLoc();
2897 if (!isTopLevel)
2898 return emitError(loc, "'type' is only valid as a top-level directive");
2899
2900 std::unique_ptr<Element> operand;
2901 if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
2902 failed(parseTypeDirectiveOperand(operand, isTypeRef)) ||
2903 failed(parseToken(Token::r_paren, "expected ')' after argument list")))
2904 return ::mlir::failure();
2905 if (isTypeRef)
2906 element = std::make_unique<TypeRefDirective>(std::move(operand));
2907 else
2908 element = std::make_unique<TypeDirective>(std::move(operand));
2909 return ::mlir::success();
2910 }
2911
2912 LogicalResult
parseTypeDirectiveOperand(std::unique_ptr<Element> & element,bool isTypeRef)2913 FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
2914 bool isTypeRef) {
2915 llvm::SMLoc loc = curToken.getLoc();
2916 if (failed(parseElement(element, /*isTopLevel=*/false)))
2917 return ::mlir::failure();
2918 if (isa<LiteralElement>(element.get()))
2919 return emitError(
2920 loc, "'type' directive operand expects variable or directive operand");
2921
2922 if (auto *var = dyn_cast<OperandVariable>(element.get())) {
2923 unsigned opIdx = var->getVar() - op.operand_begin();
2924 if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
2925 return emitError(loc, "'type' of '" + var->getVar()->name +
2926 "' is already bound");
2927 if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
2928 return emitError(loc, "'type_ref' of '" + var->getVar()->name +
2929 "' is not bound by a prior 'type' directive");
2930 seenOperandTypes.set(opIdx);
2931 } else if (auto *var = dyn_cast<ResultVariable>(element.get())) {
2932 unsigned resIdx = var->getVar() - op.result_begin();
2933 if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
2934 return emitError(loc, "'type' of '" + var->getVar()->name +
2935 "' is already bound");
2936 if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.test(resIdx)))
2937 return emitError(loc, "'type_ref' of '" + var->getVar()->name +
2938 "' is not bound by a prior 'type' directive");
2939 seenResultTypes.set(resIdx);
2940 } else if (isa<OperandsDirective>(&*element)) {
2941 if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.any()))
2942 return emitError(loc, "'operands' 'type' is already bound");
2943 if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.all()))
2944 return emitError(
2945 loc,
2946 "'operands' 'type_ref' is not bound by a prior 'type' directive");
2947 fmt.allOperandTypes = true;
2948 } else if (isa<ResultsDirective>(&*element)) {
2949 if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.any()))
2950 return emitError(loc, "'results' 'type' is already bound");
2951 if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.all()))
2952 return emitError(
2953 loc, "'results' 'type_ref' is not bound by a prior 'type' directive");
2954 fmt.allResultTypes = true;
2955 } else {
2956 return emitError(loc, "invalid argument to 'type' directive");
2957 }
2958 return ::mlir::success();
2959 }
2960
2961 //===----------------------------------------------------------------------===//
2962 // Interface
2963 //===----------------------------------------------------------------------===//
2964
generateOpFormat(const Operator & constOp,OpClass & opClass)2965 void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) {
2966 // TODO: Operator doesn't expose all necessary functionality via
2967 // the const interface.
2968 Operator &op = const_cast<Operator &>(constOp);
2969 if (!op.hasAssemblyFormat())
2970 return;
2971
2972 // Parse the format description.
2973 llvm::SourceMgr mgr;
2974 mgr.AddNewSourceBuffer(
2975 llvm::MemoryBuffer::getMemBuffer(op.getAssemblyFormat()), llvm::SMLoc());
2976 OperationFormat format(op);
2977 if (failed(FormatParser(mgr, format, op).parse())) {
2978 // Exit the process if format errors are treated as fatal.
2979 if (formatErrorIsFatal) {
2980 // Invoke the interrupt handlers to run the file cleanup handlers.
2981 llvm::sys::RunInterruptHandlers();
2982 std::exit(1);
2983 }
2984 return;
2985 }
2986
2987 // Generate the printer and parser based on the parsed format.
2988 format.genParser(op, opClass);
2989 format.genPrinter(op, opClass);
2990 }
2991