• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- mlir-linalg-ods-gen.cpp - Linalg ODS generation from math form -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation for the Tensor Comprehension-inspired
10 // parser and ODS pretty-printer for specifying Linalg "named ops" from a
11 // mathematical form.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/MLIRContext.h"
18 #include "mlir/IR/OpImplementation.h"
19 #include "mlir/Support/FileUtilities.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Support/LogicalResult.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/ToolOutputFile.h"
27 
28 #define DEBUG_TYPE "linalg-ods-gen"
29 
30 static llvm::cl::OptionCategory ODSGenCat("Linalg ODS Gen");
31 
32 // Commandline options
33 static llvm::cl::opt<std::string>
34     inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"),
35                   llvm::cl::init("-"), llvm::cl::value_desc("filename"));
36 
37 static llvm::cl::opt<std::string>
38     outputFilename("o", llvm::cl::desc("Output filename"),
39                    llvm::cl::value_desc("filename"), llvm::cl::init("-"));
40 
41 static llvm::cl::opt<bool>
42     genODSDecl("gen-ods-decl", llvm::cl::desc("Emit the ODS ops declarations."),
43                llvm::cl::cat(ODSGenCat));
44 
45 static llvm::cl::opt<bool>
46     genODSImpl("gen-impl", llvm::cl::desc("Emit the ops implementations"),
47                llvm::cl::init(false), llvm::cl::cat(ODSGenCat));
48 
49 static llvm::cl::opt<bool> testEmitIncludeTdHeader(
50     "test-emit-include-td-header",
51     llvm::cl::desc("Include LinalgStructuredOps.td for end-to-end "
52                    "tblgen testing."),
53     llvm::cl::init(false), llvm::cl::cat(ODSGenCat));
54 
55 using llvm::SetVector;
56 using llvm::SMLoc;
57 using llvm::StringRef;
58 using llvm::Twine;
59 
60 using namespace mlir;
61 
62 //===----------------------------------------------------------------------===//
63 // Lexer
64 //===----------------------------------------------------------------------===//
65 
66 namespace {
67 /// This class represents a specific token in the input format.
68 class Token {
69 public:
70   enum class Kind {
71     // Markers.
72     eof,
73     error,
74 
75     // Tokens with no info.
76     colon,
77     comma,
78     equal,
79     gt,
80     l_brace,
81     l_paren,
82     lt,
83     minus,
84     plus,
85     r_brace,
86     r_paren,
87     semicolon,
88     star,
89 
90     // Keywords.
91     kw_def,
92     FIRST_KEYWORD = kw_def,
93     kw_ods_def,
94     kw_floordiv,
95     kw_ceildiv,
96     kw_mod,
97     LAST_KEYWORD = kw_mod,
98 
99     // String valued tokens.
100     id,
101     integer,
102   };
103 
Token(Kind kind,StringRef spelling)104   Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
105 
106   /// Return the bytes that make up this token.
getSpelling() const107   StringRef getSpelling() const { return spelling; }
108 
109   /// Return the kind of this token.
getKind() const110   Kind getKind() const { return kind; }
111 
112   /// Return a location for this token.
getLoc() const113   llvm::SMLoc getLoc() const {
114     return llvm::SMLoc::getFromPointer(spelling.data());
115   }
116 
117   /// Return if this token is a keyword.
isKeyword() const118   bool isKeyword() const {
119     return kind >= Kind::FIRST_KEYWORD && kind <= Kind::LAST_KEYWORD;
120   }
is(Kind k) const121   bool is(Kind k) const { return kind == k; }
isNot(Kind k) const122   bool isNot(Kind k) const { return kind != k; }
123 
getUInt64IntegerValue() const124   Optional<uint64_t> getUInt64IntegerValue() const {
125     bool isHex = spelling.size() > 1 && spelling[1] == 'x';
126 
127     uint64_t result = 0;
128     if (spelling.getAsInteger(isHex ? 0 : 10, result))
129       return None;
130     return result;
131   }
132 
133 private:
134   /// Discriminator that indicates the kind of token this is.
135   Kind kind;
136 
137   /// A reference to the entire token contents; this is always a pointer into
138   /// a memory buffer owned by the source manager.
139   StringRef spelling;
140 };
141 
142 /// This class implements a simple lexer.
143 class Lexer {
144 public:
145   Lexer(llvm::SourceMgr &mgr);
146 
147   /// Lex the next token and return it.
148   Token lexToken();
149 
150   /// Emit an error to the lexer with the given location and message.
151   Token emitError(llvm::SMLoc loc, const Twine &msg);
152   Token emitError(const char *loc, const Twine &msg);
153 
154 private:
formToken(Token::Kind kind,const char * tokStart)155   Token formToken(Token::Kind kind, const char *tokStart) {
156     return Token(kind, StringRef(tokStart, curPtr - tokStart));
157   }
158 
159   /// Return the next character in the stream.
160   int getNextChar();
161 
162   /// Lex an identifier.
163   Token lexIdentifier(const char *tokStart);
164 
165   // Lex an integer.
166   Token lexInteger(const char *tokStart);
167 
168   // Skip a comment line, starting with a '//'.
169   void skipComment();
170 
171   llvm::SourceMgr &srcMgr;
172   StringRef curBuffer;
173   const char *curPtr;
174 };
175 } // end anonymous namespace
176 
Lexer(llvm::SourceMgr & mgr)177 Lexer::Lexer(llvm::SourceMgr &mgr) : srcMgr(mgr) {
178   curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
179   curPtr = curBuffer.begin();
180 }
181 
emitError(llvm::SMLoc loc,const Twine & msg)182 Token Lexer::emitError(llvm::SMLoc loc, const Twine &msg) {
183   srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
184   return formToken(Token::Kind::error, loc.getPointer());
185 }
emitError(const char * loc,const Twine & msg)186 Token Lexer::emitError(const char *loc, const Twine &msg) {
187   return emitError(llvm::SMLoc::getFromPointer(loc), msg);
188 }
189 
getNextChar()190 int Lexer::getNextChar() {
191   char curChar = *curPtr++;
192   switch (curChar) {
193   default:
194     return (unsigned char)curChar;
195   case 0: {
196     // A nul character in the stream is either the end of the current buffer
197     // or a random nul in the file. Disambiguate that here.
198     if (curPtr - 1 != curBuffer.end())
199       return 0;
200 
201     // Otherwise, return end of file.
202     --curPtr;
203     return EOF;
204   }
205   case '\n':
206   case '\r':
207     // Handle the newline character by ignoring it and incrementing the line
208     // count. However, be careful about 'dos style' files with \n\r in them.
209     // Only treat a \n\r or \r\n as a single line.
210     if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
211       ++curPtr;
212     return '\n';
213   }
214 }
215 
lexToken()216 Token Lexer::lexToken() {
217   while (true) {
218     const char *tokStart = curPtr;
219 
220     // This always consumes at least one character.
221     int curChar = getNextChar();
222     switch (curChar) {
223     default:
224       // Handle identifiers: [a-zA-Z_]
225       if (isalpha(curChar) || curChar == '_')
226         return lexIdentifier(tokStart);
227 
228       // Handle integers: [0-9]
229       if (isdigit(curChar))
230         return lexInteger(tokStart);
231 
232       // Unknown character, emit an error.
233       return emitError(tokStart, "unexpected character");
234 
235     case EOF:
236       // Return EOF denoting the end of lexing.
237       return formToken(Token::Kind::eof, tokStart);
238 
239     // Lex punctuation.
240     case ':':
241       return formToken(Token::Kind::colon, tokStart);
242     case ',':
243       return formToken(Token::Kind::comma, tokStart);
244     case '=':
245       return formToken(Token::Kind::equal, tokStart);
246     case '{':
247       return formToken(Token::Kind::l_brace, tokStart);
248     case '(':
249       return formToken(Token::Kind::l_paren, tokStart);
250     case '}':
251       return formToken(Token::Kind::r_brace, tokStart);
252     case ')':
253       return formToken(Token::Kind::r_paren, tokStart);
254     case '<':
255       return formToken(Token::Kind::lt, tokStart);
256     case '>':
257       return formToken(Token::Kind::gt, tokStart);
258     case '+':
259       return formToken(Token::Kind::plus, tokStart);
260     case '-':
261       return formToken(Token::Kind::minus, tokStart);
262     case ';':
263       return formToken(Token::Kind::semicolon, tokStart);
264     case '*':
265       return formToken(Token::Kind::star, tokStart);
266     case '/':
267       if (*curPtr == '/') {
268         skipComment();
269         continue;
270       }
271       // Unknown character, emit an error.
272       return emitError(tokStart, "unexpected character: not a comment");
273 
274     // Ignore whitespace characters.
275     case 0:
276     case ' ':
277     case '\t':
278     case '\n':
279       return lexToken();
280     }
281   }
282 }
283 
lexIdentifier(const char * tokStart)284 Token Lexer::lexIdentifier(const char *tokStart) {
285   // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
286   while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
287     ++curPtr;
288 
289   // Check to see if this identifier is a keyword.
290   StringRef str(tokStart, curPtr - tokStart);
291   Token::Kind kind = StringSwitch<Token::Kind>(str)
292                          .Case("def", Token::Kind::kw_def)
293                          .Case("ods_def", Token::Kind::kw_ods_def)
294                          .Case("floordiv", Token::Kind::kw_floordiv)
295                          .Case("ceildiv", Token::Kind::kw_ceildiv)
296                          .Case("mod", Token::Kind::kw_mod)
297                          .Default(Token::Kind::id);
298 
299   return Token(kind, str);
300 }
301 
lexInteger(const char * tokStart)302 Token Lexer::lexInteger(const char *tokStart) {
303   // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
304   while (isdigit(*curPtr))
305     ++curPtr;
306 
307   StringRef str(tokStart, curPtr - tokStart);
308   return Token(Token::Kind::integer, str);
309 }
310 
311 /// Skip a comment line, starting with a '//'.
skipComment()312 void Lexer::skipComment() {
313   // Advance over the second '/' in a '//' comment.
314   assert(*curPtr == '/');
315   ++curPtr;
316 
317   while (true) {
318     switch (*curPtr++) {
319     case '\n':
320     case '\r':
321       // Newline is end of comment.
322       return;
323     case 0:
324       // If this is the end of the buffer, end the comment.
325       if (curPtr - 1 == curBuffer.end()) {
326         --curPtr;
327         return;
328       }
329       LLVM_FALLTHROUGH;
330     default:
331       // Skip over other characters.
332       break;
333     }
334   }
335 }
336 
337 namespace {
338 
339 class Parser {
340 public:
Parser(llvm::SourceMgr & mgr,MLIRContext * ctx)341   Parser(llvm::SourceMgr &mgr, MLIRContext *ctx)
342       : lexer(mgr), curToken(lexer.lexToken()), context(ctx) {}
343 
344   //===--------------------------------------------------------------------===//
345   // Lexer Utilities
346   //===--------------------------------------------------------------------===//
347 
348   /// Advance the current lexer onto the next token.
consumeToken()349   void consumeToken() {
350     assert(curToken.getKind() != Token::Kind::eof &&
351            curToken.getKind() != Token::Kind::error &&
352            "shouldn't advance past EOF or errors");
353     curToken = lexer.lexToken();
354   }
consumeToken(Token::Kind kind)355   void consumeToken(Token::Kind kind) {
356     assert(curToken.getKind() == kind && "unexpected token");
357     curToken = lexer.lexToken();
358   }
parseToken(Token::Kind kind,const Twine & msg)359   LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
360     if (curToken.getKind() != kind)
361       return emitError(curToken.getLoc(), msg);
362     consumeToken();
363     return success();
364   }
emitError(llvm::SMLoc loc,const Twine & msg)365   LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
366     lexer.emitError(loc, msg);
367     return failure();
368   }
emitError(const Twine & msg)369   LogicalResult emitError(const Twine &msg) {
370     return emitError(curToken.getLoc(), msg);
371   }
consumeIf(Token::Kind kind)372   bool consumeIf(Token::Kind kind) {
373     if (curToken.isNot(kind))
374       return false;
375     consumeToken(kind);
376     return true;
377   }
378   LogicalResult
parseCommaSeparatedList(llvm::function_ref<ParseResult ()> parseElement)379   parseCommaSeparatedList(llvm::function_ref<ParseResult()> parseElement) {
380     // Non-empty case starts with an element.
381     if (parseElement())
382       return failure();
383 
384     // Otherwise we have a list of comma separated elements.
385     while (consumeIf(Token::Kind::comma)) {
386       if (parseElement())
387         return failure();
388     }
389     return success();
390   }
391   LogicalResult
parseCommaSeparatedListUntil(Token::Kind rightToken,llvm::function_ref<ParseResult ()> parseElement,bool allowEmptyList)392   parseCommaSeparatedListUntil(Token::Kind rightToken,
393                                llvm::function_ref<ParseResult()> parseElement,
394                                bool allowEmptyList) {
395     // Handle the empty case.
396     if (curToken.is(rightToken)) {
397       if (!allowEmptyList)
398         return emitError("expected list element");
399       consumeToken(rightToken);
400       return success();
401     }
402 
403     if (failed(parseCommaSeparatedList(parseElement)) ||
404         failed(
405             parseToken(rightToken, "expected ',' or right-terminating token")))
406       return failure();
407 
408     return success();
409   }
410 
411   Lexer lexer;
412   Token curToken;
413   MLIRContext *context;
414 };
415 } // namespace
416 
417 //===----------------------------------------------------------------------===//
418 // Affine parsing.
419 //===----------------------------------------------------------------------===//
420 
421 namespace {
422 
/// Binary operators of the lower precedence level (`+` and `-`), which bind
/// equally tightly. LNoOp is zero so that it converts to `false` in
/// conditions such as `if (AffineLowPrecOp op = ...)`.
enum AffineLowPrecOp {
  /// Sentinel: no low-precedence operator present.
  LNoOp = 0,
  Add,
  Sub
};

/// Binary operators of the higher precedence level (`*`, `floordiv`,
/// `ceildiv`, `mod`), which bind equally tightly. HNoOp is zero so that it
/// converts to `false` in conditions.
enum AffineHighPrecOp {
  /// Sentinel: no high-precedence operator present.
  HNoOp = 0,
  Mul,
  FloorDiv,
  CeilDiv,
  Mod
};
442 
443 using AffineDimList = SmallVector<std::pair<StringRef, AffineExpr>, 4>;
444 using AffineSymbolList = SmallVector<std::pair<StringRef, AffineExpr>, 4>;
445 
446 /// This is a specialized parser for affine expressions.
447 class AffineParser {
448 public:
AffineParser(Parser & p,std::function<AffineExpr (StringRef)> bareIdParsingHook,AffineDimList & dimList,AffineSymbolList & symbolList)449   explicit AffineParser(Parser &p,
450                         std::function<AffineExpr(StringRef)> bareIdParsingHook,
451                         AffineDimList &dimList, AffineSymbolList &symbolList)
452       : parser(p), bareIdFallback(bareIdParsingHook), dims(dimList),
453         symbols(symbolList) {}
454 
455   /// Parse a comma-separated list of affine exprs.
456   SmallVector<AffineExpr, 4>
457   parseAffineExprs(Token::Kind lDelim = Token::Kind::l_paren,
458                    Token::Kind rDelim = Token::Kind::r_paren);
459 
460   /// Parse a single affine expr.`.
461   AffineExpr parseAffineExpr();
462 
463 private:
464   // Binary affine op parsing.
465   AffineLowPrecOp consumeIfLowPrecOp();
466   AffineHighPrecOp consumeIfHighPrecOp();
467 
468   // AffineExpr parsing.
469   AffineExpr parseParentheticalExpr();
470   AffineExpr parseNegateExpression(AffineExpr lhs);
471   AffineExpr parseIntegerExpr();
472   AffineExpr parseBareIdExpr();
473 
474   AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
475                                    AffineExpr rhs, SMLoc opLoc);
476   AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs,
477                                    AffineExpr rhs);
478   AffineExpr parseAffineOperandExpr(AffineExpr lhs);
479   AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp);
480   AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp,
481                                        SMLoc llhsOpLoc);
482 
483   Parser &parser;
484   std::function<AffineExpr(StringRef)> bareIdFallback;
485   AffineDimList &dims;
486   AffineSymbolList &symbols;
487 };
488 } // end anonymous namespace
489 
490 /// Create an affine binary high precedence op expression (mul's, div's, mod).
491 /// opLoc is the location of the op token to be used to report errors
492 /// for non-conforming expressions.
getAffineBinaryOpExpr(AffineHighPrecOp op,AffineExpr lhs,AffineExpr rhs,SMLoc opLoc)493 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op,
494                                                AffineExpr lhs, AffineExpr rhs,
495                                                SMLoc opLoc) {
496   switch (op) {
497   case Mul:
498     if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) {
499       parser.emitError(opLoc,
500                        "non-affine expression: at least one of the multiply "
501                        "operands has to be either a constant or symbolic");
502       return nullptr;
503     }
504     return lhs * rhs;
505   case FloorDiv:
506     if (!rhs.isSymbolicOrConstant()) {
507       parser.emitError(opLoc,
508                        "non-affine expression: right operand of floordiv "
509                        "has to be either a constant or symbolic");
510       return nullptr;
511     }
512     return lhs.floorDiv(rhs);
513   case CeilDiv:
514     if (!rhs.isSymbolicOrConstant()) {
515       parser.emitError(opLoc, "non-affine expression: right operand of ceildiv "
516                               "has to be either a constant or symbolic");
517       return nullptr;
518     }
519     return lhs.ceilDiv(rhs);
520   case Mod:
521     if (!rhs.isSymbolicOrConstant()) {
522       parser.emitError(opLoc, "non-affine expression: right operand of mod "
523                               "has to be either a constant or symbolic");
524       return nullptr;
525     }
526     return lhs % rhs;
527   case HNoOp:
528     llvm_unreachable("can't create affine expression for null high prec op");
529     return nullptr;
530   }
531   llvm_unreachable("Unknown AffineHighPrecOp");
532 }
533 
534 /// Create an affine binary low precedence op expression (add, sub).
getAffineBinaryOpExpr(AffineLowPrecOp op,AffineExpr lhs,AffineExpr rhs)535 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op,
536                                                AffineExpr lhs, AffineExpr rhs) {
537   switch (op) {
538   case AffineLowPrecOp::Add:
539     return lhs + rhs;
540   case AffineLowPrecOp::Sub:
541     return lhs - rhs;
542   case AffineLowPrecOp::LNoOp:
543     llvm_unreachable("can't create affine expression for null low prec op");
544     return nullptr;
545   }
546   llvm_unreachable("Unknown AffineLowPrecOp");
547 }
548 
549 /// Consume this token if it is a lower precedence affine op (there are only
550 /// two precedence levels).
consumeIfLowPrecOp()551 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
552   switch (parser.curToken.getKind()) {
553   case Token::Kind::plus:
554     parser.consumeToken();
555     return AffineLowPrecOp::Add;
556   case Token::Kind::minus:
557     parser.consumeToken();
558     return AffineLowPrecOp::Sub;
559   default:
560     return AffineLowPrecOp::LNoOp;
561   }
562 }
563 
564 /// Consume this token if it is a higher precedence affine op (there are only
565 /// two precedence levels)
consumeIfHighPrecOp()566 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
567   switch (parser.curToken.getKind()) {
568   case Token::Kind::star:
569     parser.consumeToken(Token::Kind::star);
570     return Mul;
571   case Token::Kind::kw_floordiv:
572     parser.consumeToken(Token::Kind::kw_floordiv);
573     return FloorDiv;
574   case Token::Kind::kw_ceildiv:
575     parser.consumeToken(Token::Kind::kw_ceildiv);
576     return CeilDiv;
577   case Token::Kind::kw_mod:
578     parser.consumeToken(Token::Kind::kw_mod);
579     return Mod;
580   default:
581     return HNoOp;
582   }
583 }
584 
585 /// Parse a high precedence op expression list: mul, div, and mod are high
586 /// precedence binary ops, i.e., parse a
587 ///   expr_1 op_1 expr_2 op_2 ... expr_n
588 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod).
589 /// All affine binary ops are left associative.
590 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is
591 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
592 /// null. llhsOpLoc is the location of the llhsOp token that will be used to
593 /// report an error for non-conforming expressions.
parseAffineHighPrecOpExpr(AffineExpr llhs,AffineHighPrecOp llhsOp,SMLoc llhsOpLoc)594 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs,
595                                                    AffineHighPrecOp llhsOp,
596                                                    SMLoc llhsOpLoc) {
597   AffineExpr lhs = parseAffineOperandExpr(llhs);
598   if (!lhs)
599     return nullptr;
600 
601   // Found an LHS. Parse the remaining expression.
602   auto opLoc = parser.curToken.getLoc();
603   if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
604     if (llhs) {
605       AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc);
606       if (!expr)
607         return nullptr;
608       return parseAffineHighPrecOpExpr(expr, op, opLoc);
609     }
610     // No LLHS, get RHS
611     return parseAffineHighPrecOpExpr(lhs, op, opLoc);
612   }
613 
614   // This is the last operand in this expression.
615   if (llhs)
616     return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
617 
618   // No llhs, 'lhs' itself is the expression.
619   return lhs;
620 }
621 
622 /// Parse an affine expression inside parentheses.
623 ///
624 ///   affine-expr ::= `(` affine-expr `)`
parseParentheticalExpr()625 AffineExpr AffineParser::parseParentheticalExpr() {
626   if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
627     return nullptr;
628   if (parser.curToken.is(Token::Kind::r_paren))
629     return (parser.emitError("no expression inside parentheses"), nullptr);
630 
631   auto expr = parseAffineExpr();
632   if (!expr)
633     return nullptr;
634   if (failed(parser.parseToken(Token::Kind::r_paren, "expected ')'")))
635     return nullptr;
636 
637   return expr;
638 }
639 
640 /// Parse the negation expression.
641 ///
642 ///   affine-expr ::= `-` affine-expr
parseNegateExpression(AffineExpr lhs)643 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) {
644   if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")))
645     return nullptr;
646 
647   AffineExpr operand = parseAffineOperandExpr(lhs);
648   // Since negation has the highest precedence of all ops (including high
649   // precedence ops) but lower than parentheses, we are only going to use
650   // parseAffineOperandExpr instead of parseAffineExpr here.
651   if (!operand)
652     // Extra error message although parseAffineOperandExpr would have
653     // complained. Leads to a better diagnostic.
654     return (parser.emitError("missing operand of negation"), nullptr);
655   return (-1) * operand;
656 }
657 
658 /// Parse a bare id that may appear in an affine expression.
659 ///
660 ///   affine-expr ::= bare-id
parseBareIdExpr()661 AffineExpr AffineParser::parseBareIdExpr() {
662   if (parser.curToken.isNot(Token::Kind::id))
663     return (parser.emitError("expected id"), nullptr);
664 
665   StringRef sRef = parser.curToken.getSpelling();
666   for (auto &list : {dims, symbols}) {
667     for (auto entry : list) {
668       if (entry.first == sRef) {
669         parser.consumeToken(Token::Kind::id);
670         return entry.second;
671       }
672     }
673   }
674 
675   // Not found, check fallback path.
676   AffineExpr expr = bareIdFallback(sRef);
677   if (expr) {
678     parser.consumeToken(Token::Kind::id);
679     return expr;
680   }
681 
682   return (parser.emitError("use of undeclared id"), nullptr);
683 }
684 
685 /// Parse a positive integral constant appearing in an affine expression.
686 ///
687 ///   affine-expr ::= integer-literal
parseIntegerExpr()688 AffineExpr AffineParser::parseIntegerExpr() {
689   auto val = parser.curToken.getUInt64IntegerValue();
690   if (!val.hasValue() || (int64_t)val.getValue() < 0)
691     return (parser.emitError("constant too large for index"), nullptr);
692 
693   parser.consumeToken(Token::Kind::integer);
694   return getAffineConstantExpr((int64_t)val.getValue(), parser.context);
695 }
696 
697 /// Parses an expression that can be a valid operand of an affine expression.
698 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary
699 /// operator, the rhs of which is being parsed. This is used to determine
700 /// whether an error should be emitted for a missing right operand.
701 //  Eg: for an expression without parentheses (like i + j + k + l), each
702 //  of the four identifiers is an operand. For i + j*k + l, j*k is not an
703 //  operand expression, it's an op expression and will be parsed via
704 //  parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and
705 //  -l are valid operands that will be parsed by this function.
parseAffineOperandExpr(AffineExpr lhs)706 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
707   switch (parser.curToken.getKind()) {
708   case Token::Kind::id:
709     return parseBareIdExpr();
710   case Token::Kind::integer:
711     return parseIntegerExpr();
712   case Token::Kind::l_paren:
713     return parseParentheticalExpr();
714   case Token::Kind::minus:
715     return parseNegateExpression(lhs);
716   case Token::Kind::kw_ceildiv:
717   case Token::Kind::kw_floordiv:
718   case Token::Kind::kw_mod:
719   case Token::Kind::plus:
720   case Token::Kind::star:
721     if (lhs)
722       parser.emitError("missing right operand of binary operator");
723     else
724       parser.emitError("missing left operand of binary operator");
725     return nullptr;
726   default:
727     if (lhs)
728       parser.emitError("missing right operand of binary operator");
729     else
730       parser.emitError("expected affine expression");
731     return nullptr;
732   }
733 }
734 
735 /// Parse affine expressions that are bare-id's, integer constants,
736 /// parenthetical affine expressions, and affine op expressions that are a
737 /// composition of those.
738 ///
739 /// All binary op's associate from left to right.
740 ///
741 /// {add, sub} have lower precedence than {mul, div, and mod}.
742 ///
743 /// Add, sub'are themselves at the same precedence level. Mul, floordiv,
744 /// ceildiv, and mod are at the same higher precedence level. Negation has
745 /// higher precedence than any binary op.
746 ///
747 /// llhs: the affine expression appearing on the left of the one being parsed.
748 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null,
749 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned
750 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left
751 /// associativity.
752 ///
753 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
754 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where
755 /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr().
parseAffineLowPrecOpExpr(AffineExpr llhs,AffineLowPrecOp llhsOp)756 AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
757                                                   AffineLowPrecOp llhsOp) {
758   AffineExpr lhs;
759   if (!(lhs = parseAffineOperandExpr(llhs)))
760     return nullptr;
761 
762   // Found an LHS. Deal with the ops.
763   if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
764     if (llhs) {
765       AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs);
766       return parseAffineLowPrecOpExpr(sum, lOp);
767     }
768     // No LLHS, get RHS and form the expression.
769     return parseAffineLowPrecOpExpr(lhs, lOp);
770   }
771   auto opLoc = parser.curToken.getLoc();
772   if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
773     // We have a higher precedence op here. Get the rhs operand for the llhs
774     // through parseAffineHighPrecOpExpr.
775     AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
776     if (!highRes)
777       return nullptr;
778 
779     // If llhs is null, the product forms the first operand of the yet to be
780     // found expression. If non-null, the op to associate with llhs is llhsOp.
781     AffineExpr expr =
782         llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes;
783 
784     // Recurse for subsequent low prec op's after the affine high prec op
785     // expression.
786     if (AffineLowPrecOp nextOp = consumeIfLowPrecOp())
787       return parseAffineLowPrecOpExpr(expr, nextOp);
788     return expr;
789   }
790   // Last operand in the expression list.
791   if (llhs)
792     return getAffineBinaryOpExpr(llhsOp, llhs, lhs);
793   // No llhs, 'lhs' itself is the expression.
794   return lhs;
795 }
796 
797 /// Parse an affine expression.
798 ///  affine-expr ::= `(` affine-expr `)`
799 ///                | `-` affine-expr
800 ///                | affine-expr `+` affine-expr
801 ///                | affine-expr `-` affine-expr
802 ///                | affine-expr `*` affine-expr
803 ///                | affine-expr `floordiv` affine-expr
804 ///                | affine-expr `ceildiv` affine-expr
805 ///                | affine-expr `mod` affine-expr
806 ///                | bare-id
807 ///                | integer-literal
808 ///
809 /// Additional conditions are checked depending on the production. For eg.,
810 /// one of the operands for `*` has to be either constant/symbolic; the second
811 /// operand for floordiv, ceildiv, and mod has to be a positive integer.
parseAffineExpr()812 AffineExpr AffineParser::parseAffineExpr() {
813   return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
814 }
815 
parseAffineExprs(Token::Kind lDelim,Token::Kind rDelim)816 SmallVector<AffineExpr, 4> AffineParser::parseAffineExprs(Token::Kind lDelim,
817                                                           Token::Kind rDelim) {
818   parser.parseToken(lDelim, "expected lDelim at start of affine expr list");
819 
820   SmallVector<AffineExpr, 4> exprs;
821   auto parseElt = [&]() -> LogicalResult {
822     auto elt = parseAffineExpr();
823     exprs.push_back(elt);
824     return elt ? success() : failure();
825   };
826 
827   if (failed(parser.parseCommaSeparatedListUntil(rDelim, parseElt,
828                                                  /*allowEmptyList=*/true)))
829     llvm_unreachable("Failed AffineExpr parsing");
830 
831   return exprs;
832 }
833 
834 //===----------------------------------------------------------------------===//
835 // TC parsing.
836 //===----------------------------------------------------------------------===//
837 
838 namespace {
839 
/// Common base class of the expression nodes built while parsing a TC
/// definition. `kind` identifies the concrete subclass; subclasses test it in
/// their `classof`.
struct Expression {
  enum class Kind {
    Uninitialized = 0,
    TensorExpr = 1,
    TensorUse = 2,
  };

  explicit Expression(Kind k = Kind::Uninitialized) : kind(k) {}
  virtual ~Expression() = default;

  /// True once the node has been given a concrete kind.
  operator bool() const { return !(kind == Kind::Uninitialized); }

  Kind kind;
};
855 
856 /// Encodes a tensor use of the form:
857 ///
858 ///   affine-expr-list ::= affine-expr (`,` affine-expr)*
859 ///   tensor-use ::= bare-id `(` `)`
860 ///                | bare-id `(` affine-expr-list `)`
861 ///
862 /// The affine-expr-list is stored as an AffineMap.
863 struct TensorUse : public Expression {
TensorUse__anon084c004d0511::TensorUse864   TensorUse() : TensorUse("", AffineMap()) {}
TensorUse__anon084c004d0511::TensorUse865   TensorUse(StringRef name, AffineMap map)
866       : Expression(Kind::TensorUse), tensorId(name), indexingMap(map) {}
867   TensorUse(const TensorUse &use) = default;
868 
classof__anon084c004d0511::TensorUse869   static bool classof(const Expression *e) {
870     return e->kind == Kind::TensorUse;
871   }
872 
operator ==__anon084c004d0511::TensorUse873   bool operator==(const TensorUse &other) const {
874     return tensorId == other.tensorId && indexingMap == other.indexingMap;
875   }
876 
877   /// Visitation function. Performs preorder or postorder traversal depending on
878   /// `PreOrder` and applies `callback` on each node.
879   template <typename Lambda, bool PreOrder>
880   void visit(Lambda callback) const;
881 
882   StringRef tensorId;
883   AffineMap indexingMap;
884 };
885 
886 /// Encodes a tensor expression of the form:
887 ///
888 ///   op-spec ::= bare-id `<` reduction-dims-list `>`
889 ///             | bare-id
890 ///   op-arg ::= tensor-expr
891 ///            | tensor-use
892 ///   op-arg-list ::= op-arg (`,` op-arg)*
893 ///   tensor-expr ::= op-spec `(` op-arg-list `)`
894 ///
895 /// Underlying op-arg are stored by unique_ptr to base class.
896 struct TensorExpr : public Expression {
TensorExpr__anon084c004d0511::TensorExpr897   TensorExpr(StringRef name,
898              SmallVectorImpl<std::unique_ptr<Expression>> &&exprs,
899              ArrayRef<unsigned> reductionDims)
900       : Expression(Kind::TensorExpr), operationName(name),
901         expressions(std::move(exprs)),
902         reductionDimensions(reductionDims.begin(), reductionDims.end()) {}
903 
classof__anon084c004d0511::TensorExpr904   static bool classof(const Expression *e) {
905     return e->kind == Kind::TensorExpr;
906   }
907 
operator ==__anon084c004d0511::TensorExpr908   bool operator==(const TensorExpr &other) const {
909     if (operationName != other.operationName)
910       return false;
911     if (expressions.size() != other.expressions.size())
912       return false;
913     for (unsigned i = 0, e = expressions.size(); i < e; ++i)
914       if (*expressions[i] != *other.expressions[i])
915         return false;
916     for (unsigned i = 0, e = reductionDimensions.size(); i < e; ++i)
917       if (reductionDimensions[i] != other.reductionDimensions[i])
918         return false;
919     return true;
920   }
921 
922   /// Visitation function. Performs preorder or postorder traversal depending on
923   /// `PreOrder` and applies `callback` on each node.
924   template <typename Lambda, bool PreOrder>
925   void visit(Lambda callback) const;
926 
927   StringRef operationName;
928   SmallVector<std::unique_ptr<Expression>, 4> expressions;
929   SetVector<unsigned> reductionDimensions;
930 };
931 
932 /// This is a specialized parser for a TCDef.
933 /// This maintains the dims it finds in an eager fashion.
934 class TCParser {
935   enum class EagerDiscoveryMode { None = 0, Symbols, Dimensions };
936 
937 public:
938   explicit TCParser(Parser &p);
939 
940   /// Uses the AffineParser to parse the affine exprs used in a tensor
941   /// definition. If `discoveryMode` is set to Symbols (resp. Dimensions), new
942   /// symbols (resp. dimensions) are added eagerly. Otherwise, an error is
943   /// emitted on new identifiers.
944   SmallVector<AffineExpr, 4>
945   parseAffineExprs(EagerDiscoveryMode discoveryMode, AffineDimList &dims,
946                    Token::Kind lDelim = Token::Kind::l_paren,
947                    Token::Kind rDelim = Token::Kind::r_paren);
948 
949   /// Parse the information for a tensor def.
950   /// All the affine-expr must be dimensionless (i.e. contain only expressions
951   /// involving symbols and constants), but can otherwise contain arbitrary
952   /// affine expressions.
953   LogicalResult parseTensorDef(bool isOutput);
954 
955   /// Parses a tensor use.
956   struct ComprehensionParsingState {
957     AffineDimList dims;
958     SmallVector<std::unique_ptr<Expression>, 4> expressions;
959     llvm::DenseMap<TensorUse, unsigned> orderedTensorArgs;
960   };
961   LogicalResult parseTensorUse(TensorUse &result,
962                                ComprehensionParsingState &state);
963 
964   /// Parses a tensor expression.
965   LogicalResult parseExpression(TensorUse currentDefinition,
966                                 std::unique_ptr<Expression> &result,
967                                 ComprehensionParsingState &state);
968 
969   /// Parse a single comprehension.
970   LogicalResult parseOneComprehension(StringRef cppOpName,
971                                       StringRef linalgOpName,
972                                       ComprehensionParsingState &state);
973 
974   /// Parse and print the information for a TC def.
975   /// When `gen-ods-decl` is used, this prints the ODS declaration for the TC.
976   /// When `gen-impl` is used, this prints the C++ implementation for the extra
977   /// methods defined in ODS (`iterator_types`, `indexing_maps` and
978   /// `regionBuilder`).
979   LogicalResult parseAndEmitODSDef(llvm::raw_ostream &os);
980 
981   /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
982   void printODS(llvm::raw_ostream &os, StringRef cppOpName,
983                 StringRef linalgOpName, ComprehensionParsingState &state);
984 
985   /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
986   void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName,
987                                ComprehensionParsingState &state);
988 
989   /// Print the C++ StructuredOpsInterface impl of `indexing_maps`.
990   void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName,
991                                   ComprehensionParsingState &state);
992 
993   /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
994   void printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
995                           ComprehensionParsingState &state);
996 
997   /// Print the C++ impl for named ops canonicalizers and fodlers.
998   void printCanonicalizersAndFolders(llvm::raw_ostream &os,
999                                      StringRef cppOpName);
1000 
1001 private:
1002   //===--------------------------------------------------------------------===//
1003   // Internal bookkeeping of tensors.
1004   //===--------------------------------------------------------------------===//
1005   struct RegisteredTensor {
1006     StringRef type;
1007     AffineMap shape;
1008     bool isOutput;
1009     AffineMap indexingMap;
1010     unsigned index;
1011   };
1012 
1013   //===--------------------------------------------------------------------===//
1014   // Per-TC def state.
1015   //===--------------------------------------------------------------------===//
1016   /// Symbols are per TC def.
1017   AffineSymbolList symbols;
1018   /// Tensors are per TC def.
1019   llvm::StringMap<RegisteredTensor> registeredTensors;
1020   unsigned nextRegisteredTensorIndex;
1021 
1022   Parser &parser;
1023 };
1024 } // namespace
1025 
1026 namespace llvm {
1027 
1028 template <>
1029 struct DenseMapInfo<TensorUse> {
getEmptyKeyllvm::DenseMapInfo1030   static TensorUse getEmptyKey() { return TensorUse("", AffineMap()); }
getTombstoneKeyllvm::DenseMapInfo1031   static TensorUse getTombstoneKey() {
1032     return TensorUse(DenseMapInfo<StringRef>::getTombstoneKey(),
1033                      DenseMapInfo<AffineMap>::getTombstoneKey());
1034   }
getHashValuellvm::DenseMapInfo1035   static unsigned getHashValue(const TensorUse &val) {
1036     return ::llvm::hash_value(val.tensorId); // don't care about collisions.
1037   }
isEqualllvm::DenseMapInfo1038   static bool isEqual(const TensorUse &LHS, const TensorUse &RHS) {
1039     return LHS == RHS;
1040   }
1041 };
1042 
1043 } // namespace llvm
1044 
1045 //===----------------------------------------------------------------------===//
1046 // Visitation functions.
1047 //===----------------------------------------------------------------------===//
1048 
1049 template <typename Lambda, bool PreOrder>
visit(const Expression & expr,Lambda callback)1050 void visit(const Expression &expr, Lambda callback) {
1051   switch (expr.kind) {
1052   default:
1053     llvm_unreachable("Unexpected kind");
1054   case Expression::Kind::TensorExpr:
1055     static_cast<const TensorExpr &>(expr).visit<Lambda, PreOrder>(callback);
1056     break;
1057   case Expression::Kind::TensorUse:
1058     static_cast<const TensorUse &>(expr).visit<Lambda, PreOrder>(callback);
1059     break;
1060   }
1061 }
1062 
1063 template <typename Lambda>
visitPreorder(const Expression & expr,Lambda callback)1064 void visitPreorder(const Expression &expr, Lambda callback) {
1065   visit<Lambda, false>(expr, callback);
1066 }
1067 
1068 template <typename Lambda>
visitPostorder(Expression & expr,Lambda callback)1069 void visitPostorder(Expression &expr, Lambda callback) {
1070   visit<Lambda, true>(expr, callback);
1071 }
1072 
1073 template <typename Lambda, bool PreOrder>
visit(Lambda callback) const1074 void TensorExpr::visit(Lambda callback) const {
1075   if (!PreOrder)
1076     callback(*this);
1077   for (auto &e : expressions)
1078     ::visit<Lambda, PreOrder>(*e, callback);
1079   if (PreOrder)
1080     callback(*this);
1081 }
1082 
1083 template <typename Lambda, bool PreOrder>
visit(Lambda callback) const1084 void TensorUse::visit(Lambda callback) const {
1085   callback(*this);
1086 }
1087 
1088 //===----------------------------------------------------------------------===//
1089 // TC parsing functions.
1090 //===----------------------------------------------------------------------===//
TCParser(Parser & p)1091 TCParser::TCParser(Parser &p)
1092     : symbols(), registeredTensors(), nextRegisteredTensorIndex(0), parser(p) {}
1093 
1094 /// Uses the AffineParser to parse the affine exprs used in a tensor
1095 /// definition. All identifiers are interpreted as symbols, new symbols are
1096 /// added eagerly.
1097 SmallVector<AffineExpr, 4>
parseAffineExprs(EagerDiscoveryMode discoveryMode,AffineDimList & dims,Token::Kind lDelim,Token::Kind rDelim)1098 TCParser::parseAffineExprs(EagerDiscoveryMode discoveryMode,
1099                            AffineDimList &dims, Token::Kind lDelim,
1100                            Token::Kind rDelim) {
1101   AffineParser affineParser(
1102       parser,
1103       [&](StringRef sRef) {
1104         AffineExpr expr;
1105         if (discoveryMode == EagerDiscoveryMode::Symbols) {
1106           expr = getAffineSymbolExpr(symbols.size(), parser.context);
1107           symbols.emplace_back(sRef, expr);
1108         } else if (discoveryMode == EagerDiscoveryMode::Dimensions) {
1109           expr = getAffineDimExpr(dims.size(), parser.context);
1110           dims.emplace_back(sRef, expr);
1111         }
1112         return expr;
1113       },
1114       dims, symbols);
1115   return affineParser.parseAffineExprs(lDelim, rDelim);
1116 }
1117 
1118 /// Parse the information for a tensor def of the form:
1119 ///
1120 ///   affine-expr-list ::= affine-expr (`,` affine-expr )*
1121 ///   tensor-typedef ::= type `(` `)`
1122 ///                    | type `(` affine-expr-list `)`
1123 ///   tensor-def ::= bare-id `:` tensor-typedef
parseTensorDef(bool isOutput)1124 LogicalResult TCParser::parseTensorDef(bool isOutput) {
1125   StringRef tensorId = parser.curToken.getSpelling();
1126   if (failed(parser.parseToken(Token::Kind::id, "expected an id")) ||
1127       failed(parser.parseToken(Token::Kind::colon, "expected colon")))
1128     return failure();
1129 
1130   StringRef tensorType = parser.curToken.getSpelling();
1131   if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
1132     return failure();
1133 
1134   AffineDimList emptyDims;
1135   auto exprs = parseAffineExprs(EagerDiscoveryMode::Symbols, emptyDims);
1136   assert(emptyDims.empty() && "Unexpected dimension in tensor def");
1137   AffineMap map =
1138       AffineMap::get(/*dimCount=*/0, symbols.size(), exprs, parser.context);
1139 
1140   auto iterBoolPair = registeredTensors.try_emplace(
1141       tensorId, RegisteredTensor{tensorType, map, isOutput, AffineMap(),
1142                                  nextRegisteredTensorIndex++});
1143   (void)iterBoolPair;
1144   assert(iterBoolPair.second && "Could not emplace tensor registration");
1145   LLVM_DEBUG(llvm::dbgs() << "Recorded: " << tensorId << " "
1146                           << "with typeString: " << tensorType << " "
1147                           << "and shape: " << map << "\n");
1148 
1149   return success();
1150 }
1151 
1152 /// Parses a tensor use of the form:
1153 ///
1154 ///   affine-expr-list ::= affine-expr (`,` affine-expr)*
1155 ///   tensor-use ::= bare-id `(` `)`
1156 ///                | bare-id `(` affine-expr-list `)`
parseTensorUse(TensorUse & result,ComprehensionParsingState & state)1157 LogicalResult TCParser::parseTensorUse(TensorUse &result,
1158                                        ComprehensionParsingState &state) {
1159   StringRef tensorId = parser.curToken.getSpelling();
1160   if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
1161     return failure();
1162 
1163   auto exprs = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims);
1164   AffineMap map =
1165       AffineMap::get(state.dims.size(), symbols.size(), exprs, parser.context);
1166   LLVM_DEBUG(llvm::dbgs() << "Use of tensor: " << tensorId << " map: " << map
1167                           << "\n");
1168 
1169   result = TensorUse(tensorId, map);
1170   return success();
1171 }
1172 
1173 /// Parses a tensor expression of the form:
1174 ///
1175 ///   op-spec ::= bare-id `<` reduction-dims-list `>`
1176 ///             | bare-id
1177 ///   op-arg ::= tensor-expr
1178 ///            | tensor-use
1179 ///   op-arg-list ::= op-arg (`,` op-arg)*
1180 ///   tensor-expr ::= op-spec `(` op-arg-list `)`
parseExpression(TensorUse currentDefinition,std::unique_ptr<Expression> & result,ComprehensionParsingState & state)1181 LogicalResult TCParser::parseExpression(TensorUse currentDefinition,
1182                                         std::unique_ptr<Expression> &result,
1183                                         ComprehensionParsingState &state) {
1184   StringRef opOrTensor = parser.curToken.getSpelling();
1185   if (registeredTensors.count(opOrTensor) > 0) {
1186     TensorUse use;
1187     auto res = parseTensorUse(use, state);
1188     if (failed(res))
1189       return res;
1190     result = std::make_unique<TensorUse>(use);
1191     return success();
1192   }
1193 
1194   if (failed(parser.parseToken(Token::Kind::id, "expected an operation")))
1195     return failure();
1196 
1197   // This is an op.
1198   SmallVector<unsigned, 4> reductionDims;
1199   SmallVector<std::unique_ptr<Expression>, 4> expressions;
1200 
1201   // Check if it has a reduction set, discover dimensions eagerly.
1202   if (parser.curToken.is(Token::Kind::lt)) {
1203     auto iters = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims,
1204                                   Token::Kind::lt, Token::Kind::gt);
1205     for (auto iter : iters)
1206       reductionDims.push_back(iter.cast<AffineDimExpr>().getPosition());
1207   }
1208 
1209   // If this op is a reduction, it's first argument is the `currentDefinition`
1210   // tensor use.
1211   if (!reductionDims.empty())
1212     expressions.push_back(std::make_unique<TensorUse>(currentDefinition));
1213   LLVM_DEBUG(llvm::dbgs() << "op: " << opOrTensor << "\n");
1214 
1215   auto parseExpr = [&]() -> LogicalResult {
1216     std::unique_ptr<Expression> e;
1217     if (failed(parseExpression(currentDefinition, e, state)))
1218       return failure();
1219     expressions.push_back(std::move(e));
1220     return success();
1221   };
1222   if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")) ||
1223       failed(parser.parseCommaSeparatedListUntil(
1224           Token::Kind::r_paren, parseExpr, /*allowEmptyList=*/true)))
1225     return failure();
1226 
1227   result = std::make_unique<TensorExpr>(opOrTensor, std::move(expressions),
1228                                         reductionDims);
1229 
1230   return success();
1231 }
1232 
1233 //===----------------------------------------------------------------------===//
1234 // Parse and Emit functions.
1235 //===----------------------------------------------------------------------===//
1236 
1237 /// Parse the information for a single comprehension.
1238 ///
1239 ///   tensor-def-list ::= tensor-def (`,` tensor-def)*
1240 ///   tensor-expr-list ::= tensor-expr (`,` tensor-expr)*
1241 ///   comprehension ::= tensor-def-list `=` tensor-expr-list `;`
1242 LogicalResult
parseOneComprehension(StringRef cppOpName,StringRef linalgOpName,ComprehensionParsingState & state)1243 TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
1244                                 ComprehensionParsingState &state) {
1245   // 1. Parse LHS of `=`, these become the definitions that appear as the output
1246   // tensors or read/write buffers.
1247   SmallVector<TensorUse, 4> definitions;
1248   auto parseUse = [&]() -> LogicalResult {
1249     TensorUse use;
1250     if (failed(parseTensorUse(use, state)))
1251       return failure();
1252     definitions.push_back(use);
1253     return success();
1254   };
1255   if (failed(parser.parseCommaSeparatedListUntil(Token::Kind::equal, parseUse,
1256                                                  /*allowEmptyList=*/true)))
1257     return failure();
1258 
1259   // 2. Parse RHS of `=`, this becomes the expressions from which we emit
1260   // computations.
1261   unsigned idx = 0;
1262   auto parseExpr = [&]() -> LogicalResult {
1263     std::unique_ptr<Expression> expr;
1264     if (idx >= definitions.size()) {
1265       parser.emitError("Fewer LHS definitions than RHS expressions");
1266       return failure();
1267     }
1268     if (failed(parseExpression(definitions[idx++], expr, state)))
1269       return failure();
1270     state.expressions.push_back(std::move(expr));
1271     return success();
1272   };
1273   if (failed(parser.parseCommaSeparatedListUntil(
1274           Token::Kind::semicolon, parseExpr, /*allowEmptyList=*/true)))
1275     return failure();
1276   if (idx != definitions.size()) {
1277     parser.emitError("Fewer RHS expressions than LHS definitions");
1278     return failure();
1279   }
1280 
1281   // 3. Postprocess.
1282   // 3.a. Normalize all maps to the proper state.dims and symbols counts.
1283   SmallVector<TensorUse, 4> allUses;
1284   allUses.reserve(registeredTensors.size());
1285   for (auto &def : definitions)
1286     allUses.push_back(def);
1287   for (auto &pExpr : state.expressions)
1288     visitPostorder(*pExpr, [&](const Expression &e) {
1289       if (auto *use = dyn_cast<TensorUse>(&e))
1290         allUses.push_back(*use);
1291     });
1292   for (auto &use : allUses)
1293     use.indexingMap =
1294         AffineMap::get(state.dims.size(), symbols.size(),
1295                        use.indexingMap.getResults(), parser.context);
1296 
1297   // 3.b. Traverse definitions
1298   llvm::DenseSet<StringRef> seenDefs;
1299   for (auto &def : definitions) {
1300     if (seenDefs.count(def.tensorId) > 0) {
1301       parser.emitError("Unexpected multi-write to a single tensor");
1302       return failure();
1303     }
1304     seenDefs.insert(def.tensorId);
1305     auto tensorIter = registeredTensors.find(def.tensorId);
1306     assert(tensorIter != registeredTensors.end() && "unregistered tensor");
1307     auto &tensor = tensorIter->getValue();
1308     tensor.indexingMap = def.indexingMap;
1309     state.orderedTensorArgs[def] = tensor.index;
1310   }
1311 
1312   bool failed = false;
1313   for (auto &pExpr : state.expressions)
1314     visitPostorder(*pExpr, [&](const Expression &e) {
1315       auto *pUse = dyn_cast<TensorUse>(&e);
1316       if (failed || !pUse)
1317         return;
1318       auto &use = *pUse;
1319       LLVM_DEBUG(llvm::dbgs()
1320                  << "\nuse: " << use.tensorId << " map: " << use.indexingMap);
1321       auto tensorIter = registeredTensors.find(use.tensorId);
1322       assert(tensorIter != registeredTensors.end() && "unregistered tensor");
1323       auto &tensor = tensorIter->getValue();
1324       if (tensor.indexingMap && state.orderedTensorArgs.count(use) == 0) {
1325         LLVM_DEBUG(llvm::dbgs() << "\nexisting: " << tensor.indexingMap);
1326         parser.emitError(
1327             "Unexpected multi-read of a tensor with different accesses");
1328         failed = true;
1329         return;
1330       }
1331       seenDefs.insert(use.tensorId);
1332       tensor.indexingMap = use.indexingMap;
1333       state.orderedTensorArgs[use] = tensor.index;
1334     });
1335   if (failed)
1336     return failure();
1337 
1338   return success();
1339 }
1340 
1341 /// Parse and print the information for a ODS def.
1342 ///
1343 ///   tensor-def-list ::= tensor-def (`,` tensor-def )*
1344 ///
1345 ///   comprehension-list ::= comprehension comprehension*
1346 ///
1347 ///   tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)`
1348 ///     `{` comprehension-list `}`
1349 ///
1350 ///   ods-def ::= `ods_def` `<` bare-id `>` `:` tc-def
1351 ///
1352 /// All the affine-expr in a `tensor-typedef` must be dimensionless (i.e.
1353 /// contain only expressions involving symbols and constants), but can
1354 /// otherwise contain arbitrary affine expressions.
parseAndEmitODSDef(llvm::raw_ostream & os)1355 LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
1356   if (failed(parser.parseToken(Token::Kind::kw_ods_def,
1357                                "expected 'ods_def' to define a TC ODS")) ||
1358       failed(parser.parseToken(Token::Kind::lt, "expected '<'")))
1359     return failure();
1360   StringRef cppOpName = parser.curToken.getSpelling();
1361   LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing ODS: " << cppOpName << "\n");
1362 
1363   if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
1364       failed(parser.parseToken(Token::Kind::gt, "expected '>'")) ||
1365       failed(parser.parseToken(Token::Kind::colon, "expected ':'")))
1366     return failure();
1367   if (failed(parser.parseToken(Token::Kind::kw_def,
1368                                "expected 'def' to define a TC")))
1369     return failure();
1370 
1371   StringRef tcName = parser.curToken.getSpelling();
1372   LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing TC: " << tcName << "\n");
1373   if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
1374       failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
1375     return failure();
1376 
1377   auto parseInputDef = [&]() -> LogicalResult {
1378     return parseTensorDef(/*isOutput=*/false);
1379   };
1380   if (failed(parser.parseCommaSeparatedListUntil(
1381           Token::Kind::r_paren, parseInputDef, /*allowEmptyList=*/false)))
1382     return failure();
1383 
1384   if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")) ||
1385       failed(parser.parseToken(Token::Kind::gt, "expected '>'")) ||
1386       failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
1387     return failure();
1388   auto parseOutputDef = [&]() -> LogicalResult {
1389     return parseTensorDef(/*isOutput=*/true);
1390   };
1391   if (failed(parser.parseCommaSeparatedListUntil(
1392           Token::Kind::r_paren, parseOutputDef, /*allowEmptyList=*/false)))
1393     return failure();
1394 
1395   // Since we don't declare symbols separately, we discover them eagerly: each
1396   // newly encountered id in a tensor shape expression is treated as a new
1397   // symbolic. At this point, all tensors have been parsed and all the symbols
1398   // that could be discovered eagerly are now known. Resize all AffineMaps to
1399   // normalize the number of eagerly discovered symbols.
1400   for (auto &tensor : registeredTensors) {
1401     auto &map = tensor.getValue().shape;
1402     map = AffineMap::get(/*dimCount=*/0, symbols.size(), map.getResults(),
1403                          parser.context);
1404   }
1405 
1406   if (failed(parser.parseToken(Token::Kind::l_brace, "expected '{'")))
1407     return failure();
1408 
1409   SmallVector<ComprehensionParsingState, 4> perComprehensionStates;
1410   while (parser.curToken.isNot(Token::Kind::r_brace)) {
1411     perComprehensionStates.push_back(ComprehensionParsingState());
1412     if (failed(parseOneComprehension(cppOpName, tcName,
1413                                      perComprehensionStates.back())))
1414       return failure();
1415   };
1416   parser.parseToken(Token::Kind::r_brace, "expected '}'");
1417 
1418   // Print.
1419   auto nComprehensions = perComprehensionStates.size();
1420   if (nComprehensions != 1) {
1421     parser.emitError("only 1 comprehension supported for now, got: " +
1422                      llvm::Twine(nComprehensions));
1423     return failure();
1424   }
1425   if (genODSDecl) {
1426     auto &state = perComprehensionStates.back();
1427     printODS(os, cppOpName, tcName, state);
1428     os << "\n";
1429   }
1430   if (genODSImpl) {
1431     auto &state = perComprehensionStates.back();
1432     std::string extraMethods;
1433     llvm::raw_string_ostream ss(extraMethods);
1434     printReferenceIterators(ss, cppOpName, state);
1435     printReferenceIndexingMaps(ss, cppOpName, state);
1436     printRegionBuilder(ss, cppOpName, state);
1437     printCanonicalizersAndFolders(ss, cppOpName);
1438     ss.flush();
1439     os << extraMethods << "\n";
1440   }
1441 
1442   return success();
1443 }
1444 
1445 //===----------------------------------------------------------------------===//
1446 // Printing functions
1447 //===----------------------------------------------------------------------===//
1448 
1449 /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
printODS(llvm::raw_ostream & os,StringRef cppOpName,StringRef linalgOpName,ComprehensionParsingState & state)1450 void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
1451                         StringRef linalgOpName,
1452                         ComprehensionParsingState &state) {
1453   const char *header = R"FMT(  def {0} : LinalgStructuredBase_Op<"{1}", [
1454     AttrSizedOperandSegments,
1455     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1456     NamedStructuredOpTrait,
1457     SingleBlockImplicitTerminator<"YieldOp">]> {
1458       let arguments = (ins Variadic<AnyShaped>:$inputs,
1459                            Variadic<AnyMemRef>:$output_buffers,
1460                            Variadic<AnyRankedTensor>:$init_tensors);
1461       let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
1462       let regions = (region AnyRegion:$region);
1463 
1464       let skipDefaultBuilders = 1;
1465       let builders = [ OpBuilderDAG<
1466         (ins "ValueRange":$inputs, "ValueRange":$outputBuffers),
1467         [{{
1468           $_state.addOperands(inputs);
1469           $_state.addOperands(outputBuffers);
1470           $_state.addAttribute(
1471             "operand_segment_sizes",
1472             $_builder.getI32VectorAttr({{
1473               static_cast<int32_t>(inputs.size()),
1474               static_cast<int32_t>(outputBuffers.size()),
1475               static_cast<int32_t>(0)}));
1476           buildNamedStructuredOpRegionAndAttributes<{0}>(
1477             $_builder,
1478             $_state,
1479             TypeRange(inputs),
1480             TypeRange(outputBuffers),
1481             TypeRange(),
1482             TypeRange());
1483         }]>, OpBuilderDAG<
1484         (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
1485              "ValueRange":$outputBuffers, "ValueRange":$initTensors),
1486         [{{
1487           $_state.addOperands(inputs);
1488           $_state.addOperands(outputBuffers);
1489           $_state.addOperands(initTensors);
1490           $_state.addTypes(resultTensorTypes);
1491           $_state.addAttribute(
1492             "operand_segment_sizes",
1493             $_builder.getI32VectorAttr({{
1494               static_cast<int32_t>(inputs.size()),
1495               static_cast<int32_t>(outputBuffers.size()),
1496               static_cast<int32_t>(initTensors.size())}));
1497           buildNamedStructuredOpRegionAndAttributes<{0}>(
1498             $_builder,
1499             $_state,
1500             TypeRange(inputs),
1501             TypeRange(outputBuffers),
1502             TypeRange(initTensors),
1503             resultTensorTypes);
1504         }]>, OpBuilderDAG<
1505         (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
1506              CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
1507         [{{
1508           $_state.addOperands(operands);
1509           $_state.addAttributes(attributes);
1510           $_state.addTypes(resultTensorTypes);
1511           (void)$_state.addRegion();
1512         }]>
1513       ];
1514       let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
1515       let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
1516       let verifier = [{{ return ::verifyNamedStructuredOp(*this); }];
1517       let hasFolder = 1;
1518       let hasCanonicalizer = 1;
1519 
1520       let extraClassDeclaration = [{{
1521         // Auto-generated.
1522         ArrayAttr iterator_types();
1523         ArrayAttr indexing_maps();
1524         static void regionBuilder(Block &block);
1525         static std::function<void(Block &)> getRegionBuilder() {{ return regionBuilder; }
1526 
1527         // Generic methods.
1528         static unsigned getNumRegionArgs() {{ return {4}; }
1529         std::string getLibraryCallName() {{
1530           return generateLibraryCallName(getOperation());
1531         }
1532       }];
1533   })FMT";
1534 
1535   unsigned nInputs = 0, nOutputs = 0;
1536   for (auto &t : registeredTensors) {
1537     if (t.getValue().isOutput)
1538       nOutputs++;
1539     else
1540       nInputs++;
1541   }
1542 
1543   os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs,
1544                       state.orderedTensorArgs.size());
1545 }
1546 
1547 /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
1548 void TCParser::printReferenceIterators(llvm::raw_ostream &os,
1549                                        StringRef cppOpName,
1550                                        ComprehensionParsingState &state) {
1551   const char *referenceReferenceIteratorsFmt =
1552       R"FMT(
1553     ArrayAttr {0}::iterator_types() {
1554       return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef, 8>{{ {1} });
1555     })FMT";
1556 
1557   std::string iteratorsStr;
1558   llvm::raw_string_ostream ss(iteratorsStr);
1559   unsigned pos = 0;
1560   llvm::interleaveComma(
1561       state.dims, ss, [&](std::pair<StringRef, AffineExpr> p) {
1562         bool reduction = false;
1563         for (auto &expr : state.expressions) {
1564           visitPostorder(*expr, [&](const Expression &e) {
1565             if (auto *pTensorExpr = dyn_cast<TensorExpr>(&e)) {
1566               if (pTensorExpr->reductionDimensions.count(pos) > 0)
1567                 reduction = true;
1568             }
1569           });
1570           if (reduction)
1571             break;
1572         }
1573         ss << (reduction ? "getReductionIteratorTypeName()"
1574                          : "getParallelIteratorTypeName()");
1575         pos++;
1576       });
1577   ss.flush();
1578 
1579   os << llvm::formatv(referenceReferenceIteratorsFmt, cppOpName, iteratorsStr);
1580 }
1581 
1582 void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
1583                                              StringRef cppOpName) {
1584   const char *canonicalizersAndFoldersFmt = R"FMT(
1585     void {0}::getCanonicalizationPatterns(
1586         OwningRewritePatternList &results,
1587         MLIRContext *context) {{
1588       results.insert<EraseDeadLinalgOp>();
1589       results.insert<FoldTensorCastOp>();
1590     }
1591     LogicalResult {0}::fold(ArrayRef<Attribute>,
1592                             SmallVectorImpl<OpFoldResult> &) {{
1593       return foldMemRefCast(*this);
1594     }
1595     void {0}::getEffects(SmallVectorImpl<
1596         SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
1597       getGenericEffectsImpl(effects,
1598         getOperation()->getResults(), getInputBuffers(), getOutputBuffers());
1599     })FMT";
1600   os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName);
1601 }
1602 
1603 /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
1604 void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
1605                                           StringRef cppOpName,
1606                                           ComprehensionParsingState &state) {
1607   // 1. Generic string template for specifying reference indexing maps.
1608   const char *referenceIndexingMapsFmt =
1609       R"FMT(
1610   // This is temporary until we transition out of manually specified ops that
1611   // should be auto-generated with linalg-ods-gen.
1612   ArrayAttr {0}::indexing_maps() {
1613     MLIRContext *context = getContext();
1614     AffineExpr {1};
1615     bindDims(context, {1});
1616     return Builder(context).getAffineMapArrayAttr({ {2} });
1617   })FMT";
1618 
1619   // 2. Print a comma-separated list of identifiers for the AffineExpr in
1620   // `state.dims`. These will replace the `{1}` placeholder in both
1621   // `AffineExpr {1}` and `bindDims(context, {1})` ensuring the AffineExpr
1622   // identifiers are bound in the right order to the proper AffineDimExpr.
1623   std::string dimsStr;
1624   llvm::raw_string_ostream ss(dimsStr);
1625   llvm::interleaveComma(
1626       state.dims, ss,
1627       [&](std::pair<StringRef, AffineExpr> p) { ss << p.second; });
1628   ss.flush();
1629 
1630   // 3. Print a comma-separated list of AffineMap constructors that use the
1631   // identifiers from 1. The AffineExpr use the common arithmetic operators on
1632   // AffineExpr. These AffineMap constructors will replace the `{2}` placeholder
1633   // in return `SmallVector<AffineMap, 8>{{ {2} };`.
1634   std::string mapsStr;
1635   llvm::raw_string_ostream mapsStringStream(mapsStr);
1636   SmallVector<TensorUse, 4> orderedUses(state.orderedTensorArgs.size());
1637   for (const auto &it : state.orderedTensorArgs)
1638     orderedUses[it.second] = it.first;
1639   llvm::interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) {
1640     assert(u.indexingMap);
1641     const char *mapFmt = "\n\tAffineMap::get({0}, 0, {1}, context)";
1642     if (u.indexingMap.isEmpty()) {
1643       mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), "context");
1644       return;
1645     }
1646 
1647     std::string exprsStr;
1648     llvm::raw_string_ostream exprsStringStream(exprsStr);
1649     exprsStringStream << "{";
1650     llvm::interleaveComma(u.indexingMap.getResults(), exprsStringStream);
1651     exprsStringStream << "}";
1652     exprsStringStream.flush();
1653 
1654     mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), exprsStr);
1655   });
1656   mapsStringStream.flush();
1657 
1658   // 4. Apply format to 1. using 2. and 3.
1659   os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr);
1660 }
1661 
1662 /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
1663 void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
1664                                   ComprehensionParsingState &state) {
1665   unsigned count = state.orderedTensorArgs.size();
1666   llvm::DenseMap<const TensorExpr *, unsigned> subExprsMap;
1667   std::function<void(llvm::raw_ostream & os, const Expression &)> printExpr;
1668   printExpr = [&](llvm::raw_ostream &os, const Expression &e) -> void {
1669     if (auto *pUse = dyn_cast<TensorUse>(&e)) {
1670       os << "_" << state.orderedTensorArgs.find(*pUse)->second;
1671       return;
1672     }
1673     auto *pTensorExpr = cast<TensorExpr>(&e);
1674     if (subExprsMap.count(pTensorExpr) > 0) {
1675       os << "_" << subExprsMap[pTensorExpr];
1676     } else {
1677       std::string subExprs;
1678       llvm::raw_string_ostream subExprsStringStream(subExprs);
1679       llvm::interleaveComma(pTensorExpr->expressions, subExprsStringStream,
1680                             [&](const std::unique_ptr<Expression> &e) {
1681                               printExpr(subExprsStringStream, *e);
1682                             });
1683       subExprsStringStream.flush();
1684       const char *tensorExprFmt = "\n    Value _{0} = {1}({2});";
1685       os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->operationName,
1686                           subExprs);
1687       subExprsMap[pTensorExpr] = count;
1688     }
1689   };
1690 
1691   const char *regionBuilderFmt = R"FMT(
1692   void {0}::regionBuilder(Block &block) {
1693     using namespace edsc;
1694     using namespace intrinsics;
1695     auto args = block.getArguments();
1696     Value {1};
1697     {2}
1698     (linalg_yield(ValueRange{ {3} }));
1699   })FMT";
1700 
1701   unsigned idx = 0;
1702   std::string valueHandleStr;
1703   llvm::raw_string_ostream valueHandleStringStream(valueHandleStr);
1704   llvm::interleaveComma(
1705       state.orderedTensorArgs, valueHandleStringStream, [&](auto) {
1706         valueHandleStringStream << "_" << idx << "(args[" << idx << "])";
1707         idx++;
1708       });
1709 
1710   std::string expressionsStr;
1711   llvm::raw_string_ostream expressionStringStream(expressionsStr);
1712   for (auto &expr : state.expressions)
1713     visitPostorder(*expr, [&](const Expression &e) {
1714       if (e.kind == Expression::Kind::TensorExpr)
1715         printExpr(expressionStringStream, e);
1716     });
1717 
1718   std::string yieldStr;
1719   llvm::raw_string_ostream yieldStringStream(yieldStr);
1720   llvm::interleaveComma(state.expressions, yieldStringStream,
1721                         [&](const std::unique_ptr<Expression> &e) {
1722                           printExpr(yieldStringStream, *e);
1723                         });
1724 
1725   valueHandleStringStream.flush();
1726   expressionStringStream.flush();
1727   yieldStringStream.flush();
1728 
1729   os << llvm::formatv(regionBuilderFmt, cppOpName, valueHandleStr,
1730                       expressionsStr, yieldStr);
1731 }
1732 
1733 /// Iterate over each Tensor Comprehension def.
1734 LogicalResult parseAndEmitAllTensorComprehensions(llvm::raw_ostream &os,
1735                                                   Parser &parser) {
1736   while (parser.curToken.getKind() != Token::Kind::eof) {
1737     TCParser tcParser(parser);
1738     if (failed(tcParser.parseAndEmitODSDef(os)))
1739       return failure();
1740   }
1741   return success();
1742 }
1743 
1744 int main(int argc, char **argv) {
1745   llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen");
1746 
1747   // Set up the input file.
1748   std::string errorMessage;
1749   std::unique_ptr<llvm::MemoryBuffer> file =
1750       mlir::openInputFile(inputFilename, &errorMessage);
1751   if (!file) {
1752     llvm::errs() << errorMessage << "\n";
1753     return 1;
1754   }
1755 
1756   std::unique_ptr<llvm::ToolOutputFile> output =
1757       openOutputFile(outputFilename, &errorMessage);
1758   if (!output) {
1759     llvm::errs() << errorMessage << "\n";
1760     exit(1);
1761   }
1762 
1763   // Include the proper Linalg header for end-to-end tblgen testing without
1764   // resorting to non-portable shell manipulations.
1765   if (testEmitIncludeTdHeader)
1766     output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\"";
1767 
1768   MLIRContext context;
1769   llvm::SourceMgr mgr;
1770   mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
1771   Parser parser(mgr, &context);
1772   parseAndEmitAllTensorComprehensions(output->os(), parser);
1773   output->keep();
1774 
1775   return 0;
1776 }
1777