1 //===- mlir-linalg-ods-gen.cpp - Linalg ODS generation from math form -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation for the Tensor Comprehension-inspired
10 // parser and ODS pretty-printer for specifying Linalg "named ops" from a
11 // mathematical form.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/MLIRContext.h"
18 #include "mlir/IR/OpImplementation.h"
19 #include "mlir/Support/FileUtilities.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Support/LogicalResult.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/ToolOutputFile.h"
27
28 #define DEBUG_TYPE "linalg-ods-gen"
29
30 static llvm::cl::OptionCategory ODSGenCat("Linalg ODS Gen");
31
32 // Commandline options
33 static llvm::cl::opt<std::string>
34 inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"),
35 llvm::cl::init("-"), llvm::cl::value_desc("filename"));
36
37 static llvm::cl::opt<std::string>
38 outputFilename("o", llvm::cl::desc("Output filename"),
39 llvm::cl::value_desc("filename"), llvm::cl::init("-"));
40
41 static llvm::cl::opt<bool>
42 genODSDecl("gen-ods-decl", llvm::cl::desc("Emit the ODS ops declarations."),
43 llvm::cl::cat(ODSGenCat));
44
45 static llvm::cl::opt<bool>
46 genODSImpl("gen-impl", llvm::cl::desc("Emit the ops implementations"),
47 llvm::cl::init(false), llvm::cl::cat(ODSGenCat));
48
49 static llvm::cl::opt<bool> testEmitIncludeTdHeader(
50 "test-emit-include-td-header",
51 llvm::cl::desc("Include LinalgStructuredOps.td for end-to-end "
52 "tblgen testing."),
53 llvm::cl::init(false), llvm::cl::cat(ODSGenCat));
54
55 using llvm::SetVector;
56 using llvm::SMLoc;
57 using llvm::StringRef;
58 using llvm::Twine;
59
60 using namespace mlir;
61
62 //===----------------------------------------------------------------------===//
63 // Lexer
64 //===----------------------------------------------------------------------===//
65
66 namespace {
67 /// This class represents a specific token in the input format.
68 class Token {
69 public:
70 enum class Kind {
71 // Markers.
72 eof,
73 error,
74
75 // Tokens with no info.
76 colon,
77 comma,
78 equal,
79 gt,
80 l_brace,
81 l_paren,
82 lt,
83 minus,
84 plus,
85 r_brace,
86 r_paren,
87 semicolon,
88 star,
89
90 // Keywords.
91 kw_def,
92 FIRST_KEYWORD = kw_def,
93 kw_ods_def,
94 kw_floordiv,
95 kw_ceildiv,
96 kw_mod,
97 LAST_KEYWORD = kw_mod,
98
99 // String valued tokens.
100 id,
101 integer,
102 };
103
Token(Kind kind,StringRef spelling)104 Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
105
106 /// Return the bytes that make up this token.
getSpelling() const107 StringRef getSpelling() const { return spelling; }
108
109 /// Return the kind of this token.
getKind() const110 Kind getKind() const { return kind; }
111
112 /// Return a location for this token.
getLoc() const113 llvm::SMLoc getLoc() const {
114 return llvm::SMLoc::getFromPointer(spelling.data());
115 }
116
117 /// Return if this token is a keyword.
isKeyword() const118 bool isKeyword() const {
119 return kind >= Kind::FIRST_KEYWORD && kind <= Kind::LAST_KEYWORD;
120 }
is(Kind k) const121 bool is(Kind k) const { return kind == k; }
isNot(Kind k) const122 bool isNot(Kind k) const { return kind != k; }
123
getUInt64IntegerValue() const124 Optional<uint64_t> getUInt64IntegerValue() const {
125 bool isHex = spelling.size() > 1 && spelling[1] == 'x';
126
127 uint64_t result = 0;
128 if (spelling.getAsInteger(isHex ? 0 : 10, result))
129 return None;
130 return result;
131 }
132
133 private:
134 /// Discriminator that indicates the kind of token this is.
135 Kind kind;
136
137 /// A reference to the entire token contents; this is always a pointer into
138 /// a memory buffer owned by the source manager.
139 StringRef spelling;
140 };
141
142 /// This class implements a simple lexer.
143 class Lexer {
144 public:
145 Lexer(llvm::SourceMgr &mgr);
146
147 /// Lex the next token and return it.
148 Token lexToken();
149
150 /// Emit an error to the lexer with the given location and message.
151 Token emitError(llvm::SMLoc loc, const Twine &msg);
152 Token emitError(const char *loc, const Twine &msg);
153
154 private:
formToken(Token::Kind kind,const char * tokStart)155 Token formToken(Token::Kind kind, const char *tokStart) {
156 return Token(kind, StringRef(tokStart, curPtr - tokStart));
157 }
158
159 /// Return the next character in the stream.
160 int getNextChar();
161
162 /// Lex an identifier.
163 Token lexIdentifier(const char *tokStart);
164
165 // Lex an integer.
166 Token lexInteger(const char *tokStart);
167
168 // Skip a comment line, starting with a '//'.
169 void skipComment();
170
171 llvm::SourceMgr &srcMgr;
172 StringRef curBuffer;
173 const char *curPtr;
174 };
175 } // end anonymous namespace
176
Lexer(llvm::SourceMgr & mgr)177 Lexer::Lexer(llvm::SourceMgr &mgr) : srcMgr(mgr) {
178 curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
179 curPtr = curBuffer.begin();
180 }
181
emitError(llvm::SMLoc loc,const Twine & msg)182 Token Lexer::emitError(llvm::SMLoc loc, const Twine &msg) {
183 srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
184 return formToken(Token::Kind::error, loc.getPointer());
185 }
emitError(const char * loc,const Twine & msg)186 Token Lexer::emitError(const char *loc, const Twine &msg) {
187 return emitError(llvm::SMLoc::getFromPointer(loc), msg);
188 }
189
getNextChar()190 int Lexer::getNextChar() {
191 char curChar = *curPtr++;
192 switch (curChar) {
193 default:
194 return (unsigned char)curChar;
195 case 0: {
196 // A nul character in the stream is either the end of the current buffer
197 // or a random nul in the file. Disambiguate that here.
198 if (curPtr - 1 != curBuffer.end())
199 return 0;
200
201 // Otherwise, return end of file.
202 --curPtr;
203 return EOF;
204 }
205 case '\n':
206 case '\r':
207 // Handle the newline character by ignoring it and incrementing the line
208 // count. However, be careful about 'dos style' files with \n\r in them.
209 // Only treat a \n\r or \r\n as a single line.
210 if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
211 ++curPtr;
212 return '\n';
213 }
214 }
215
lexToken()216 Token Lexer::lexToken() {
217 while (true) {
218 const char *tokStart = curPtr;
219
220 // This always consumes at least one character.
221 int curChar = getNextChar();
222 switch (curChar) {
223 default:
224 // Handle identifiers: [a-zA-Z_]
225 if (isalpha(curChar) || curChar == '_')
226 return lexIdentifier(tokStart);
227
228 // Handle integers: [0-9]
229 if (isdigit(curChar))
230 return lexInteger(tokStart);
231
232 // Unknown character, emit an error.
233 return emitError(tokStart, "unexpected character");
234
235 case EOF:
236 // Return EOF denoting the end of lexing.
237 return formToken(Token::Kind::eof, tokStart);
238
239 // Lex punctuation.
240 case ':':
241 return formToken(Token::Kind::colon, tokStart);
242 case ',':
243 return formToken(Token::Kind::comma, tokStart);
244 case '=':
245 return formToken(Token::Kind::equal, tokStart);
246 case '{':
247 return formToken(Token::Kind::l_brace, tokStart);
248 case '(':
249 return formToken(Token::Kind::l_paren, tokStart);
250 case '}':
251 return formToken(Token::Kind::r_brace, tokStart);
252 case ')':
253 return formToken(Token::Kind::r_paren, tokStart);
254 case '<':
255 return formToken(Token::Kind::lt, tokStart);
256 case '>':
257 return formToken(Token::Kind::gt, tokStart);
258 case '+':
259 return formToken(Token::Kind::plus, tokStart);
260 case '-':
261 return formToken(Token::Kind::minus, tokStart);
262 case ';':
263 return formToken(Token::Kind::semicolon, tokStart);
264 case '*':
265 return formToken(Token::Kind::star, tokStart);
266 case '/':
267 if (*curPtr == '/') {
268 skipComment();
269 continue;
270 }
271 // Unknown character, emit an error.
272 return emitError(tokStart, "unexpected character: not a comment");
273
274 // Ignore whitespace characters.
275 case 0:
276 case ' ':
277 case '\t':
278 case '\n':
279 return lexToken();
280 }
281 }
282 }
283
lexIdentifier(const char * tokStart)284 Token Lexer::lexIdentifier(const char *tokStart) {
285 // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
286 while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
287 ++curPtr;
288
289 // Check to see if this identifier is a keyword.
290 StringRef str(tokStart, curPtr - tokStart);
291 Token::Kind kind = StringSwitch<Token::Kind>(str)
292 .Case("def", Token::Kind::kw_def)
293 .Case("ods_def", Token::Kind::kw_ods_def)
294 .Case("floordiv", Token::Kind::kw_floordiv)
295 .Case("ceildiv", Token::Kind::kw_ceildiv)
296 .Case("mod", Token::Kind::kw_mod)
297 .Default(Token::Kind::id);
298
299 return Token(kind, str);
300 }
301
lexInteger(const char * tokStart)302 Token Lexer::lexInteger(const char *tokStart) {
303 // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
304 while (isdigit(*curPtr))
305 ++curPtr;
306
307 StringRef str(tokStart, curPtr - tokStart);
308 return Token(Token::Kind::integer, str);
309 }
310
311 /// Skip a comment line, starting with a '//'.
skipComment()312 void Lexer::skipComment() {
313 // Advance over the second '/' in a '//' comment.
314 assert(*curPtr == '/');
315 ++curPtr;
316
317 while (true) {
318 switch (*curPtr++) {
319 case '\n':
320 case '\r':
321 // Newline is end of comment.
322 return;
323 case 0:
324 // If this is the end of the buffer, end the comment.
325 if (curPtr - 1 == curBuffer.end()) {
326 --curPtr;
327 return;
328 }
329 LLVM_FALLTHROUGH;
330 default:
331 // Skip over other characters.
332 break;
333 }
334 }
335 }
336
337 namespace {
338
339 class Parser {
340 public:
Parser(llvm::SourceMgr & mgr,MLIRContext * ctx)341 Parser(llvm::SourceMgr &mgr, MLIRContext *ctx)
342 : lexer(mgr), curToken(lexer.lexToken()), context(ctx) {}
343
344 //===--------------------------------------------------------------------===//
345 // Lexer Utilities
346 //===--------------------------------------------------------------------===//
347
348 /// Advance the current lexer onto the next token.
consumeToken()349 void consumeToken() {
350 assert(curToken.getKind() != Token::Kind::eof &&
351 curToken.getKind() != Token::Kind::error &&
352 "shouldn't advance past EOF or errors");
353 curToken = lexer.lexToken();
354 }
consumeToken(Token::Kind kind)355 void consumeToken(Token::Kind kind) {
356 assert(curToken.getKind() == kind && "unexpected token");
357 curToken = lexer.lexToken();
358 }
parseToken(Token::Kind kind,const Twine & msg)359 LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
360 if (curToken.getKind() != kind)
361 return emitError(curToken.getLoc(), msg);
362 consumeToken();
363 return success();
364 }
emitError(llvm::SMLoc loc,const Twine & msg)365 LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
366 lexer.emitError(loc, msg);
367 return failure();
368 }
emitError(const Twine & msg)369 LogicalResult emitError(const Twine &msg) {
370 return emitError(curToken.getLoc(), msg);
371 }
consumeIf(Token::Kind kind)372 bool consumeIf(Token::Kind kind) {
373 if (curToken.isNot(kind))
374 return false;
375 consumeToken(kind);
376 return true;
377 }
378 LogicalResult
parseCommaSeparatedList(llvm::function_ref<ParseResult ()> parseElement)379 parseCommaSeparatedList(llvm::function_ref<ParseResult()> parseElement) {
380 // Non-empty case starts with an element.
381 if (parseElement())
382 return failure();
383
384 // Otherwise we have a list of comma separated elements.
385 while (consumeIf(Token::Kind::comma)) {
386 if (parseElement())
387 return failure();
388 }
389 return success();
390 }
391 LogicalResult
parseCommaSeparatedListUntil(Token::Kind rightToken,llvm::function_ref<ParseResult ()> parseElement,bool allowEmptyList)392 parseCommaSeparatedListUntil(Token::Kind rightToken,
393 llvm::function_ref<ParseResult()> parseElement,
394 bool allowEmptyList) {
395 // Handle the empty case.
396 if (curToken.is(rightToken)) {
397 if (!allowEmptyList)
398 return emitError("expected list element");
399 consumeToken(rightToken);
400 return success();
401 }
402
403 if (failed(parseCommaSeparatedList(parseElement)) ||
404 failed(
405 parseToken(rightToken, "expected ',' or right-terminating token")))
406 return failure();
407
408 return success();
409 }
410
411 Lexer lexer;
412 Token curToken;
413 MLIRContext *context;
414 };
415 } // namespace
416
417 //===----------------------------------------------------------------------===//
418 // Affine parsing.
419 //===----------------------------------------------------------------------===//
420
421 namespace {
422
423 /// Lower precedence ops (all at the same precedence level). LNoOp is false in
424 /// the boolean sense.
425 enum AffineLowPrecOp {
426 /// Null value.
427 LNoOp,
428 Add,
429 Sub
430 };
431
432 /// Higher precedence ops - all at the same precedence level. HNoOp is false
433 /// in the boolean sense.
434 enum AffineHighPrecOp {
435 /// Null value.
436 HNoOp,
437 Mul,
438 FloorDiv,
439 CeilDiv,
440 Mod
441 };
442
443 using AffineDimList = SmallVector<std::pair<StringRef, AffineExpr>, 4>;
444 using AffineSymbolList = SmallVector<std::pair<StringRef, AffineExpr>, 4>;
445
446 /// This is a specialized parser for affine expressions.
447 class AffineParser {
448 public:
AffineParser(Parser & p,std::function<AffineExpr (StringRef)> bareIdParsingHook,AffineDimList & dimList,AffineSymbolList & symbolList)449 explicit AffineParser(Parser &p,
450 std::function<AffineExpr(StringRef)> bareIdParsingHook,
451 AffineDimList &dimList, AffineSymbolList &symbolList)
452 : parser(p), bareIdFallback(bareIdParsingHook), dims(dimList),
453 symbols(symbolList) {}
454
455 /// Parse a comma-separated list of affine exprs.
456 SmallVector<AffineExpr, 4>
457 parseAffineExprs(Token::Kind lDelim = Token::Kind::l_paren,
458 Token::Kind rDelim = Token::Kind::r_paren);
459
460 /// Parse a single affine expr.`.
461 AffineExpr parseAffineExpr();
462
463 private:
464 // Binary affine op parsing.
465 AffineLowPrecOp consumeIfLowPrecOp();
466 AffineHighPrecOp consumeIfHighPrecOp();
467
468 // AffineExpr parsing.
469 AffineExpr parseParentheticalExpr();
470 AffineExpr parseNegateExpression(AffineExpr lhs);
471 AffineExpr parseIntegerExpr();
472 AffineExpr parseBareIdExpr();
473
474 AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
475 AffineExpr rhs, SMLoc opLoc);
476 AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs,
477 AffineExpr rhs);
478 AffineExpr parseAffineOperandExpr(AffineExpr lhs);
479 AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp);
480 AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp,
481 SMLoc llhsOpLoc);
482
483 Parser &parser;
484 std::function<AffineExpr(StringRef)> bareIdFallback;
485 AffineDimList &dims;
486 AffineSymbolList &symbols;
487 };
488 } // end anonymous namespace
489
490 /// Create an affine binary high precedence op expression (mul's, div's, mod).
491 /// opLoc is the location of the op token to be used to report errors
492 /// for non-conforming expressions.
getAffineBinaryOpExpr(AffineHighPrecOp op,AffineExpr lhs,AffineExpr rhs,SMLoc opLoc)493 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op,
494 AffineExpr lhs, AffineExpr rhs,
495 SMLoc opLoc) {
496 switch (op) {
497 case Mul:
498 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) {
499 parser.emitError(opLoc,
500 "non-affine expression: at least one of the multiply "
501 "operands has to be either a constant or symbolic");
502 return nullptr;
503 }
504 return lhs * rhs;
505 case FloorDiv:
506 if (!rhs.isSymbolicOrConstant()) {
507 parser.emitError(opLoc,
508 "non-affine expression: right operand of floordiv "
509 "has to be either a constant or symbolic");
510 return nullptr;
511 }
512 return lhs.floorDiv(rhs);
513 case CeilDiv:
514 if (!rhs.isSymbolicOrConstant()) {
515 parser.emitError(opLoc, "non-affine expression: right operand of ceildiv "
516 "has to be either a constant or symbolic");
517 return nullptr;
518 }
519 return lhs.ceilDiv(rhs);
520 case Mod:
521 if (!rhs.isSymbolicOrConstant()) {
522 parser.emitError(opLoc, "non-affine expression: right operand of mod "
523 "has to be either a constant or symbolic");
524 return nullptr;
525 }
526 return lhs % rhs;
527 case HNoOp:
528 llvm_unreachable("can't create affine expression for null high prec op");
529 return nullptr;
530 }
531 llvm_unreachable("Unknown AffineHighPrecOp");
532 }
533
534 /// Create an affine binary low precedence op expression (add, sub).
getAffineBinaryOpExpr(AffineLowPrecOp op,AffineExpr lhs,AffineExpr rhs)535 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op,
536 AffineExpr lhs, AffineExpr rhs) {
537 switch (op) {
538 case AffineLowPrecOp::Add:
539 return lhs + rhs;
540 case AffineLowPrecOp::Sub:
541 return lhs - rhs;
542 case AffineLowPrecOp::LNoOp:
543 llvm_unreachable("can't create affine expression for null low prec op");
544 return nullptr;
545 }
546 llvm_unreachable("Unknown AffineLowPrecOp");
547 }
548
549 /// Consume this token if it is a lower precedence affine op (there are only
550 /// two precedence levels).
consumeIfLowPrecOp()551 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
552 switch (parser.curToken.getKind()) {
553 case Token::Kind::plus:
554 parser.consumeToken();
555 return AffineLowPrecOp::Add;
556 case Token::Kind::minus:
557 parser.consumeToken();
558 return AffineLowPrecOp::Sub;
559 default:
560 return AffineLowPrecOp::LNoOp;
561 }
562 }
563
564 /// Consume this token if it is a higher precedence affine op (there are only
565 /// two precedence levels)
consumeIfHighPrecOp()566 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
567 switch (parser.curToken.getKind()) {
568 case Token::Kind::star:
569 parser.consumeToken(Token::Kind::star);
570 return Mul;
571 case Token::Kind::kw_floordiv:
572 parser.consumeToken(Token::Kind::kw_floordiv);
573 return FloorDiv;
574 case Token::Kind::kw_ceildiv:
575 parser.consumeToken(Token::Kind::kw_ceildiv);
576 return CeilDiv;
577 case Token::Kind::kw_mod:
578 parser.consumeToken(Token::Kind::kw_mod);
579 return Mod;
580 default:
581 return HNoOp;
582 }
583 }
584
585 /// Parse a high precedence op expression list: mul, div, and mod are high
586 /// precedence binary ops, i.e., parse a
587 /// expr_1 op_1 expr_2 op_2 ... expr_n
588 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod).
589 /// All affine binary ops are left associative.
590 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is
591 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
592 /// null. llhsOpLoc is the location of the llhsOp token that will be used to
593 /// report an error for non-conforming expressions.
parseAffineHighPrecOpExpr(AffineExpr llhs,AffineHighPrecOp llhsOp,SMLoc llhsOpLoc)594 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs,
595 AffineHighPrecOp llhsOp,
596 SMLoc llhsOpLoc) {
597 AffineExpr lhs = parseAffineOperandExpr(llhs);
598 if (!lhs)
599 return nullptr;
600
601 // Found an LHS. Parse the remaining expression.
602 auto opLoc = parser.curToken.getLoc();
603 if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
604 if (llhs) {
605 AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc);
606 if (!expr)
607 return nullptr;
608 return parseAffineHighPrecOpExpr(expr, op, opLoc);
609 }
610 // No LLHS, get RHS
611 return parseAffineHighPrecOpExpr(lhs, op, opLoc);
612 }
613
614 // This is the last operand in this expression.
615 if (llhs)
616 return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
617
618 // No llhs, 'lhs' itself is the expression.
619 return lhs;
620 }
621
622 /// Parse an affine expression inside parentheses.
623 ///
624 /// affine-expr ::= `(` affine-expr `)`
parseParentheticalExpr()625 AffineExpr AffineParser::parseParentheticalExpr() {
626 if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
627 return nullptr;
628 if (parser.curToken.is(Token::Kind::r_paren))
629 return (parser.emitError("no expression inside parentheses"), nullptr);
630
631 auto expr = parseAffineExpr();
632 if (!expr)
633 return nullptr;
634 if (failed(parser.parseToken(Token::Kind::r_paren, "expected ')'")))
635 return nullptr;
636
637 return expr;
638 }
639
640 /// Parse the negation expression.
641 ///
642 /// affine-expr ::= `-` affine-expr
parseNegateExpression(AffineExpr lhs)643 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) {
644 if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")))
645 return nullptr;
646
647 AffineExpr operand = parseAffineOperandExpr(lhs);
648 // Since negation has the highest precedence of all ops (including high
649 // precedence ops) but lower than parentheses, we are only going to use
650 // parseAffineOperandExpr instead of parseAffineExpr here.
651 if (!operand)
652 // Extra error message although parseAffineOperandExpr would have
653 // complained. Leads to a better diagnostic.
654 return (parser.emitError("missing operand of negation"), nullptr);
655 return (-1) * operand;
656 }
657
658 /// Parse a bare id that may appear in an affine expression.
659 ///
660 /// affine-expr ::= bare-id
parseBareIdExpr()661 AffineExpr AffineParser::parseBareIdExpr() {
662 if (parser.curToken.isNot(Token::Kind::id))
663 return (parser.emitError("expected id"), nullptr);
664
665 StringRef sRef = parser.curToken.getSpelling();
666 for (auto &list : {dims, symbols}) {
667 for (auto entry : list) {
668 if (entry.first == sRef) {
669 parser.consumeToken(Token::Kind::id);
670 return entry.second;
671 }
672 }
673 }
674
675 // Not found, check fallback path.
676 AffineExpr expr = bareIdFallback(sRef);
677 if (expr) {
678 parser.consumeToken(Token::Kind::id);
679 return expr;
680 }
681
682 return (parser.emitError("use of undeclared id"), nullptr);
683 }
684
685 /// Parse a positive integral constant appearing in an affine expression.
686 ///
687 /// affine-expr ::= integer-literal
parseIntegerExpr()688 AffineExpr AffineParser::parseIntegerExpr() {
689 auto val = parser.curToken.getUInt64IntegerValue();
690 if (!val.hasValue() || (int64_t)val.getValue() < 0)
691 return (parser.emitError("constant too large for index"), nullptr);
692
693 parser.consumeToken(Token::Kind::integer);
694 return getAffineConstantExpr((int64_t)val.getValue(), parser.context);
695 }
696
697 /// Parses an expression that can be a valid operand of an affine expression.
698 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary
699 /// operator, the rhs of which is being parsed. This is used to determine
700 /// whether an error should be emitted for a missing right operand.
701 // Eg: for an expression without parentheses (like i + j + k + l), each
702 // of the four identifiers is an operand. For i + j*k + l, j*k is not an
703 // operand expression, it's an op expression and will be parsed via
704 // parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and
705 // -l are valid operands that will be parsed by this function.
parseAffineOperandExpr(AffineExpr lhs)706 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
707 switch (parser.curToken.getKind()) {
708 case Token::Kind::id:
709 return parseBareIdExpr();
710 case Token::Kind::integer:
711 return parseIntegerExpr();
712 case Token::Kind::l_paren:
713 return parseParentheticalExpr();
714 case Token::Kind::minus:
715 return parseNegateExpression(lhs);
716 case Token::Kind::kw_ceildiv:
717 case Token::Kind::kw_floordiv:
718 case Token::Kind::kw_mod:
719 case Token::Kind::plus:
720 case Token::Kind::star:
721 if (lhs)
722 parser.emitError("missing right operand of binary operator");
723 else
724 parser.emitError("missing left operand of binary operator");
725 return nullptr;
726 default:
727 if (lhs)
728 parser.emitError("missing right operand of binary operator");
729 else
730 parser.emitError("expected affine expression");
731 return nullptr;
732 }
733 }
734
735 /// Parse affine expressions that are bare-id's, integer constants,
736 /// parenthetical affine expressions, and affine op expressions that are a
737 /// composition of those.
738 ///
739 /// All binary op's associate from left to right.
740 ///
741 /// {add, sub} have lower precedence than {mul, div, and mod}.
742 ///
743 /// Add, sub'are themselves at the same precedence level. Mul, floordiv,
744 /// ceildiv, and mod are at the same higher precedence level. Negation has
745 /// higher precedence than any binary op.
746 ///
747 /// llhs: the affine expression appearing on the left of the one being parsed.
748 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null,
749 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned
750 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left
751 /// associativity.
752 ///
753 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
754 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where
755 /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr().
parseAffineLowPrecOpExpr(AffineExpr llhs,AffineLowPrecOp llhsOp)756 AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
757 AffineLowPrecOp llhsOp) {
758 AffineExpr lhs;
759 if (!(lhs = parseAffineOperandExpr(llhs)))
760 return nullptr;
761
762 // Found an LHS. Deal with the ops.
763 if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
764 if (llhs) {
765 AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs);
766 return parseAffineLowPrecOpExpr(sum, lOp);
767 }
768 // No LLHS, get RHS and form the expression.
769 return parseAffineLowPrecOpExpr(lhs, lOp);
770 }
771 auto opLoc = parser.curToken.getLoc();
772 if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
773 // We have a higher precedence op here. Get the rhs operand for the llhs
774 // through parseAffineHighPrecOpExpr.
775 AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
776 if (!highRes)
777 return nullptr;
778
779 // If llhs is null, the product forms the first operand of the yet to be
780 // found expression. If non-null, the op to associate with llhs is llhsOp.
781 AffineExpr expr =
782 llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes;
783
784 // Recurse for subsequent low prec op's after the affine high prec op
785 // expression.
786 if (AffineLowPrecOp nextOp = consumeIfLowPrecOp())
787 return parseAffineLowPrecOpExpr(expr, nextOp);
788 return expr;
789 }
790 // Last operand in the expression list.
791 if (llhs)
792 return getAffineBinaryOpExpr(llhsOp, llhs, lhs);
793 // No llhs, 'lhs' itself is the expression.
794 return lhs;
795 }
796
797 /// Parse an affine expression.
798 /// affine-expr ::= `(` affine-expr `)`
799 /// | `-` affine-expr
800 /// | affine-expr `+` affine-expr
801 /// | affine-expr `-` affine-expr
802 /// | affine-expr `*` affine-expr
803 /// | affine-expr `floordiv` affine-expr
804 /// | affine-expr `ceildiv` affine-expr
805 /// | affine-expr `mod` affine-expr
806 /// | bare-id
807 /// | integer-literal
808 ///
809 /// Additional conditions are checked depending on the production. For eg.,
810 /// one of the operands for `*` has to be either constant/symbolic; the second
811 /// operand for floordiv, ceildiv, and mod has to be a positive integer.
parseAffineExpr()812 AffineExpr AffineParser::parseAffineExpr() {
813 return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
814 }
815
parseAffineExprs(Token::Kind lDelim,Token::Kind rDelim)816 SmallVector<AffineExpr, 4> AffineParser::parseAffineExprs(Token::Kind lDelim,
817 Token::Kind rDelim) {
818 parser.parseToken(lDelim, "expected lDelim at start of affine expr list");
819
820 SmallVector<AffineExpr, 4> exprs;
821 auto parseElt = [&]() -> LogicalResult {
822 auto elt = parseAffineExpr();
823 exprs.push_back(elt);
824 return elt ? success() : failure();
825 };
826
827 if (failed(parser.parseCommaSeparatedListUntil(rDelim, parseElt,
828 /*allowEmptyList=*/true)))
829 llvm_unreachable("Failed AffineExpr parsing");
830
831 return exprs;
832 }
833
834 //===----------------------------------------------------------------------===//
835 // TC parsing.
836 //===----------------------------------------------------------------------===//
837
838 namespace {
839
/// Common base of the TC expression tree nodes. `kind` is the discriminator
/// inspected by the subclasses' LLVM-style `classof` hooks.
struct Expression {
  enum class Kind {
    Uninitialized = 0,
    TensorExpr = 1,
    TensorUse = 2,
  };

  explicit Expression(Kind k = Kind::Uninitialized) : kind(k) {}
  virtual ~Expression() = default;

  /// True for every expression except the default-constructed placeholder.
  operator bool() const { return !(kind == Kind::Uninitialized); }

  Kind kind;
};
855
856 /// Encodes a tensor use of the form:
857 ///
858 /// affine-expr-list ::= affine-expr (`,` affine-expr)*
859 /// tensor-use ::= bare-id `(` `)`
860 /// | bare-id `(` affine-expr-list `)`
861 ///
862 /// The affine-expr-list is stored as an AffineMap.
863 struct TensorUse : public Expression {
TensorUse__anon084c004d0511::TensorUse864 TensorUse() : TensorUse("", AffineMap()) {}
TensorUse__anon084c004d0511::TensorUse865 TensorUse(StringRef name, AffineMap map)
866 : Expression(Kind::TensorUse), tensorId(name), indexingMap(map) {}
867 TensorUse(const TensorUse &use) = default;
868
classof__anon084c004d0511::TensorUse869 static bool classof(const Expression *e) {
870 return e->kind == Kind::TensorUse;
871 }
872
operator ==__anon084c004d0511::TensorUse873 bool operator==(const TensorUse &other) const {
874 return tensorId == other.tensorId && indexingMap == other.indexingMap;
875 }
876
877 /// Visitation function. Performs preorder or postorder traversal depending on
878 /// `PreOrder` and applies `callback` on each node.
879 template <typename Lambda, bool PreOrder>
880 void visit(Lambda callback) const;
881
882 StringRef tensorId;
883 AffineMap indexingMap;
884 };
885
886 /// Encodes a tensor expression of the form:
887 ///
888 /// op-spec ::= bare-id `<` reduction-dims-list `>`
889 /// | bare-id
890 /// op-arg ::= tensor-expr
891 /// | tensor-use
892 /// op-arg-list ::= op-arg (`,` op-arg)*
893 /// tensor-expr ::= op-spec `(` op-arg-list `)`
894 ///
895 /// Underlying op-arg are stored by unique_ptr to base class.
896 struct TensorExpr : public Expression {
TensorExpr__anon084c004d0511::TensorExpr897 TensorExpr(StringRef name,
898 SmallVectorImpl<std::unique_ptr<Expression>> &&exprs,
899 ArrayRef<unsigned> reductionDims)
900 : Expression(Kind::TensorExpr), operationName(name),
901 expressions(std::move(exprs)),
902 reductionDimensions(reductionDims.begin(), reductionDims.end()) {}
903
classof__anon084c004d0511::TensorExpr904 static bool classof(const Expression *e) {
905 return e->kind == Kind::TensorExpr;
906 }
907
operator ==__anon084c004d0511::TensorExpr908 bool operator==(const TensorExpr &other) const {
909 if (operationName != other.operationName)
910 return false;
911 if (expressions.size() != other.expressions.size())
912 return false;
913 for (unsigned i = 0, e = expressions.size(); i < e; ++i)
914 if (*expressions[i] != *other.expressions[i])
915 return false;
916 for (unsigned i = 0, e = reductionDimensions.size(); i < e; ++i)
917 if (reductionDimensions[i] != other.reductionDimensions[i])
918 return false;
919 return true;
920 }
921
922 /// Visitation function. Performs preorder or postorder traversal depending on
923 /// `PreOrder` and applies `callback` on each node.
924 template <typename Lambda, bool PreOrder>
925 void visit(Lambda callback) const;
926
927 StringRef operationName;
928 SmallVector<std::unique_ptr<Expression>, 4> expressions;
929 SetVector<unsigned> reductionDimensions;
930 };
931
932 /// This is a specialized parser for a TCDef.
933 /// This maintains the dims it finds in an eager fashion.
934 class TCParser {
935 enum class EagerDiscoveryMode { None = 0, Symbols, Dimensions };
936
937 public:
938 explicit TCParser(Parser &p);
939
940 /// Uses the AffineParser to parse the affine exprs used in a tensor
941 /// definition. If `discoveryMode` is set to Symbols (resp. Dimensions), new
942 /// symbols (resp. dimensions) are added eagerly. Otherwise, an error is
943 /// emitted on new identifiers.
944 SmallVector<AffineExpr, 4>
945 parseAffineExprs(EagerDiscoveryMode discoveryMode, AffineDimList &dims,
946 Token::Kind lDelim = Token::Kind::l_paren,
947 Token::Kind rDelim = Token::Kind::r_paren);
948
949 /// Parse the information for a tensor def.
950 /// All the affine-expr must be dimensionless (i.e. contain only expressions
951 /// involving symbols and constants), but can otherwise contain arbitrary
952 /// affine expressions.
953 LogicalResult parseTensorDef(bool isOutput);
954
955 /// Parses a tensor use.
956 struct ComprehensionParsingState {
957 AffineDimList dims;
958 SmallVector<std::unique_ptr<Expression>, 4> expressions;
959 llvm::DenseMap<TensorUse, unsigned> orderedTensorArgs;
960 };
961 LogicalResult parseTensorUse(TensorUse &result,
962 ComprehensionParsingState &state);
963
964 /// Parses a tensor expression.
965 LogicalResult parseExpression(TensorUse currentDefinition,
966 std::unique_ptr<Expression> &result,
967 ComprehensionParsingState &state);
968
969 /// Parse a single comprehension.
970 LogicalResult parseOneComprehension(StringRef cppOpName,
971 StringRef linalgOpName,
972 ComprehensionParsingState &state);
973
974 /// Parse and print the information for a TC def.
975 /// When `gen-ods-decl` is used, this prints the ODS declaration for the TC.
976 /// When `gen-impl` is used, this prints the C++ implementation for the extra
977 /// methods defined in ODS (`iterator_types`, `indexing_maps` and
978 /// `regionBuilder`).
979 LogicalResult parseAndEmitODSDef(llvm::raw_ostream &os);
980
981 /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
982 void printODS(llvm::raw_ostream &os, StringRef cppOpName,
983 StringRef linalgOpName, ComprehensionParsingState &state);
984
985 /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
986 void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName,
987 ComprehensionParsingState &state);
988
989 /// Print the C++ StructuredOpsInterface impl of `indexing_maps`.
990 void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName,
991 ComprehensionParsingState &state);
992
993 /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
994 void printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
995 ComprehensionParsingState &state);
996
997 /// Print the C++ impl for named ops canonicalizers and fodlers.
998 void printCanonicalizersAndFolders(llvm::raw_ostream &os,
999 StringRef cppOpName);
1000
1001 private:
1002 //===--------------------------------------------------------------------===//
1003 // Internal bookkeeping of tensors.
1004 //===--------------------------------------------------------------------===//
1005 struct RegisteredTensor {
1006 StringRef type;
1007 AffineMap shape;
1008 bool isOutput;
1009 AffineMap indexingMap;
1010 unsigned index;
1011 };
1012
1013 //===--------------------------------------------------------------------===//
1014 // Per-TC def state.
1015 //===--------------------------------------------------------------------===//
1016 /// Symbols are per TC def.
1017 AffineSymbolList symbols;
1018 /// Tensors are per TC def.
1019 llvm::StringMap<RegisteredTensor> registeredTensors;
1020 unsigned nextRegisteredTensorIndex;
1021
1022 Parser &parser;
1023 };
1024 } // namespace
1025
1026 namespace llvm {
1027
1028 template <>
1029 struct DenseMapInfo<TensorUse> {
getEmptyKeyllvm::DenseMapInfo1030 static TensorUse getEmptyKey() { return TensorUse("", AffineMap()); }
getTombstoneKeyllvm::DenseMapInfo1031 static TensorUse getTombstoneKey() {
1032 return TensorUse(DenseMapInfo<StringRef>::getTombstoneKey(),
1033 DenseMapInfo<AffineMap>::getTombstoneKey());
1034 }
getHashValuellvm::DenseMapInfo1035 static unsigned getHashValue(const TensorUse &val) {
1036 return ::llvm::hash_value(val.tensorId); // don't care about collisions.
1037 }
isEqualllvm::DenseMapInfo1038 static bool isEqual(const TensorUse &LHS, const TensorUse &RHS) {
1039 return LHS == RHS;
1040 }
1041 };
1042
1043 } // namespace llvm
1044
1045 //===----------------------------------------------------------------------===//
1046 // Visitation functions.
1047 //===----------------------------------------------------------------------===//
1048
1049 template <typename Lambda, bool PreOrder>
visit(const Expression & expr,Lambda callback)1050 void visit(const Expression &expr, Lambda callback) {
1051 switch (expr.kind) {
1052 default:
1053 llvm_unreachable("Unexpected kind");
1054 case Expression::Kind::TensorExpr:
1055 static_cast<const TensorExpr &>(expr).visit<Lambda, PreOrder>(callback);
1056 break;
1057 case Expression::Kind::TensorUse:
1058 static_cast<const TensorUse &>(expr).visit<Lambda, PreOrder>(callback);
1059 break;
1060 }
1061 }
1062
1063 template <typename Lambda>
visitPreorder(const Expression & expr,Lambda callback)1064 void visitPreorder(const Expression &expr, Lambda callback) {
1065 visit<Lambda, false>(expr, callback);
1066 }
1067
1068 template <typename Lambda>
visitPostorder(Expression & expr,Lambda callback)1069 void visitPostorder(Expression &expr, Lambda callback) {
1070 visit<Lambda, true>(expr, callback);
1071 }
1072
1073 template <typename Lambda, bool PreOrder>
visit(Lambda callback) const1074 void TensorExpr::visit(Lambda callback) const {
1075 if (!PreOrder)
1076 callback(*this);
1077 for (auto &e : expressions)
1078 ::visit<Lambda, PreOrder>(*e, callback);
1079 if (PreOrder)
1080 callback(*this);
1081 }
1082
1083 template <typename Lambda, bool PreOrder>
visit(Lambda callback) const1084 void TensorUse::visit(Lambda callback) const {
1085 callback(*this);
1086 }
1087
1088 //===----------------------------------------------------------------------===//
1089 // TC parsing functions.
1090 //===----------------------------------------------------------------------===//
TCParser(Parser & p)1091 TCParser::TCParser(Parser &p)
1092 : symbols(), registeredTensors(), nextRegisteredTensorIndex(0), parser(p) {}
1093
1094 /// Uses the AffineParser to parse the affine exprs used in a tensor
1095 /// definition. All identifiers are interpreted as symbols, new symbols are
1096 /// added eagerly.
1097 SmallVector<AffineExpr, 4>
parseAffineExprs(EagerDiscoveryMode discoveryMode,AffineDimList & dims,Token::Kind lDelim,Token::Kind rDelim)1098 TCParser::parseAffineExprs(EagerDiscoveryMode discoveryMode,
1099 AffineDimList &dims, Token::Kind lDelim,
1100 Token::Kind rDelim) {
1101 AffineParser affineParser(
1102 parser,
1103 [&](StringRef sRef) {
1104 AffineExpr expr;
1105 if (discoveryMode == EagerDiscoveryMode::Symbols) {
1106 expr = getAffineSymbolExpr(symbols.size(), parser.context);
1107 symbols.emplace_back(sRef, expr);
1108 } else if (discoveryMode == EagerDiscoveryMode::Dimensions) {
1109 expr = getAffineDimExpr(dims.size(), parser.context);
1110 dims.emplace_back(sRef, expr);
1111 }
1112 return expr;
1113 },
1114 dims, symbols);
1115 return affineParser.parseAffineExprs(lDelim, rDelim);
1116 }
1117
1118 /// Parse the information for a tensor def of the form:
1119 ///
1120 /// affine-expr-list ::= affine-expr (`,` affine-expr )*
1121 /// tensor-typedef ::= type `(` `)`
1122 /// | type `(` affine-expr-list `)`
1123 /// tensor-def ::= bare-id `:` tensor-typedef
parseTensorDef(bool isOutput)1124 LogicalResult TCParser::parseTensorDef(bool isOutput) {
1125 StringRef tensorId = parser.curToken.getSpelling();
1126 if (failed(parser.parseToken(Token::Kind::id, "expected an id")) ||
1127 failed(parser.parseToken(Token::Kind::colon, "expected colon")))
1128 return failure();
1129
1130 StringRef tensorType = parser.curToken.getSpelling();
1131 if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
1132 return failure();
1133
1134 AffineDimList emptyDims;
1135 auto exprs = parseAffineExprs(EagerDiscoveryMode::Symbols, emptyDims);
1136 assert(emptyDims.empty() && "Unexpected dimension in tensor def");
1137 AffineMap map =
1138 AffineMap::get(/*dimCount=*/0, symbols.size(), exprs, parser.context);
1139
1140 auto iterBoolPair = registeredTensors.try_emplace(
1141 tensorId, RegisteredTensor{tensorType, map, isOutput, AffineMap(),
1142 nextRegisteredTensorIndex++});
1143 (void)iterBoolPair;
1144 assert(iterBoolPair.second && "Could not emplace tensor registration");
1145 LLVM_DEBUG(llvm::dbgs() << "Recorded: " << tensorId << " "
1146 << "with typeString: " << tensorType << " "
1147 << "and shape: " << map << "\n");
1148
1149 return success();
1150 }
1151
1152 /// Parses a tensor use of the form:
1153 ///
1154 /// affine-expr-list ::= affine-expr (`,` affine-expr)*
1155 /// tensor-use ::= bare-id `(` `)`
1156 /// | bare-id `(` affine-expr-list `)`
parseTensorUse(TensorUse & result,ComprehensionParsingState & state)1157 LogicalResult TCParser::parseTensorUse(TensorUse &result,
1158 ComprehensionParsingState &state) {
1159 StringRef tensorId = parser.curToken.getSpelling();
1160 if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
1161 return failure();
1162
1163 auto exprs = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims);
1164 AffineMap map =
1165 AffineMap::get(state.dims.size(), symbols.size(), exprs, parser.context);
1166 LLVM_DEBUG(llvm::dbgs() << "Use of tensor: " << tensorId << " map: " << map
1167 << "\n");
1168
1169 result = TensorUse(tensorId, map);
1170 return success();
1171 }
1172
1173 /// Parses a tensor expression of the form:
1174 ///
1175 /// op-spec ::= bare-id `<` reduction-dims-list `>`
1176 /// | bare-id
1177 /// op-arg ::= tensor-expr
1178 /// | tensor-use
1179 /// op-arg-list ::= op-arg (`,` op-arg)*
1180 /// tensor-expr ::= op-spec `(` op-arg-list `)`
parseExpression(TensorUse currentDefinition,std::unique_ptr<Expression> & result,ComprehensionParsingState & state)1181 LogicalResult TCParser::parseExpression(TensorUse currentDefinition,
1182 std::unique_ptr<Expression> &result,
1183 ComprehensionParsingState &state) {
1184 StringRef opOrTensor = parser.curToken.getSpelling();
1185 if (registeredTensors.count(opOrTensor) > 0) {
1186 TensorUse use;
1187 auto res = parseTensorUse(use, state);
1188 if (failed(res))
1189 return res;
1190 result = std::make_unique<TensorUse>(use);
1191 return success();
1192 }
1193
1194 if (failed(parser.parseToken(Token::Kind::id, "expected an operation")))
1195 return failure();
1196
1197 // This is an op.
1198 SmallVector<unsigned, 4> reductionDims;
1199 SmallVector<std::unique_ptr<Expression>, 4> expressions;
1200
1201 // Check if it has a reduction set, discover dimensions eagerly.
1202 if (parser.curToken.is(Token::Kind::lt)) {
1203 auto iters = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims,
1204 Token::Kind::lt, Token::Kind::gt);
1205 for (auto iter : iters)
1206 reductionDims.push_back(iter.cast<AffineDimExpr>().getPosition());
1207 }
1208
1209 // If this op is a reduction, it's first argument is the `currentDefinition`
1210 // tensor use.
1211 if (!reductionDims.empty())
1212 expressions.push_back(std::make_unique<TensorUse>(currentDefinition));
1213 LLVM_DEBUG(llvm::dbgs() << "op: " << opOrTensor << "\n");
1214
1215 auto parseExpr = [&]() -> LogicalResult {
1216 std::unique_ptr<Expression> e;
1217 if (failed(parseExpression(currentDefinition, e, state)))
1218 return failure();
1219 expressions.push_back(std::move(e));
1220 return success();
1221 };
1222 if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")) ||
1223 failed(parser.parseCommaSeparatedListUntil(
1224 Token::Kind::r_paren, parseExpr, /*allowEmptyList=*/true)))
1225 return failure();
1226
1227 result = std::make_unique<TensorExpr>(opOrTensor, std::move(expressions),
1228 reductionDims);
1229
1230 return success();
1231 }
1232
1233 //===----------------------------------------------------------------------===//
1234 // Parse and Emit functions.
1235 //===----------------------------------------------------------------------===//
1236
1237 /// Parse the information for a single comprehension.
1238 ///
1239 /// tensor-def-list ::= tensor-def (`,` tensor-def)*
1240 /// tensor-expr-list ::= tensor-expr (`,` tensor-expr)*
1241 /// comprehension ::= tensor-def-list `=` tensor-expr-list `;`
1242 LogicalResult
parseOneComprehension(StringRef cppOpName,StringRef linalgOpName,ComprehensionParsingState & state)1243 TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
1244 ComprehensionParsingState &state) {
1245 // 1. Parse LHS of `=`, these become the definitions that appear as the output
1246 // tensors or read/write buffers.
1247 SmallVector<TensorUse, 4> definitions;
1248 auto parseUse = [&]() -> LogicalResult {
1249 TensorUse use;
1250 if (failed(parseTensorUse(use, state)))
1251 return failure();
1252 definitions.push_back(use);
1253 return success();
1254 };
1255 if (failed(parser.parseCommaSeparatedListUntil(Token::Kind::equal, parseUse,
1256 /*allowEmptyList=*/true)))
1257 return failure();
1258
1259 // 2. Parse RHS of `=`, this becomes the expressions from which we emit
1260 // computations.
1261 unsigned idx = 0;
1262 auto parseExpr = [&]() -> LogicalResult {
1263 std::unique_ptr<Expression> expr;
1264 if (idx >= definitions.size()) {
1265 parser.emitError("Fewer LHS definitions than RHS expressions");
1266 return failure();
1267 }
1268 if (failed(parseExpression(definitions[idx++], expr, state)))
1269 return failure();
1270 state.expressions.push_back(std::move(expr));
1271 return success();
1272 };
1273 if (failed(parser.parseCommaSeparatedListUntil(
1274 Token::Kind::semicolon, parseExpr, /*allowEmptyList=*/true)))
1275 return failure();
1276 if (idx != definitions.size()) {
1277 parser.emitError("Fewer RHS expressions than LHS definitions");
1278 return failure();
1279 }
1280
1281 // 3. Postprocess.
1282 // 3.a. Normalize all maps to the proper state.dims and symbols counts.
1283 SmallVector<TensorUse, 4> allUses;
1284 allUses.reserve(registeredTensors.size());
1285 for (auto &def : definitions)
1286 allUses.push_back(def);
1287 for (auto &pExpr : state.expressions)
1288 visitPostorder(*pExpr, [&](const Expression &e) {
1289 if (auto *use = dyn_cast<TensorUse>(&e))
1290 allUses.push_back(*use);
1291 });
1292 for (auto &use : allUses)
1293 use.indexingMap =
1294 AffineMap::get(state.dims.size(), symbols.size(),
1295 use.indexingMap.getResults(), parser.context);
1296
1297 // 3.b. Traverse definitions
1298 llvm::DenseSet<StringRef> seenDefs;
1299 for (auto &def : definitions) {
1300 if (seenDefs.count(def.tensorId) > 0) {
1301 parser.emitError("Unexpected multi-write to a single tensor");
1302 return failure();
1303 }
1304 seenDefs.insert(def.tensorId);
1305 auto tensorIter = registeredTensors.find(def.tensorId);
1306 assert(tensorIter != registeredTensors.end() && "unregistered tensor");
1307 auto &tensor = tensorIter->getValue();
1308 tensor.indexingMap = def.indexingMap;
1309 state.orderedTensorArgs[def] = tensor.index;
1310 }
1311
1312 bool failed = false;
1313 for (auto &pExpr : state.expressions)
1314 visitPostorder(*pExpr, [&](const Expression &e) {
1315 auto *pUse = dyn_cast<TensorUse>(&e);
1316 if (failed || !pUse)
1317 return;
1318 auto &use = *pUse;
1319 LLVM_DEBUG(llvm::dbgs()
1320 << "\nuse: " << use.tensorId << " map: " << use.indexingMap);
1321 auto tensorIter = registeredTensors.find(use.tensorId);
1322 assert(tensorIter != registeredTensors.end() && "unregistered tensor");
1323 auto &tensor = tensorIter->getValue();
1324 if (tensor.indexingMap && state.orderedTensorArgs.count(use) == 0) {
1325 LLVM_DEBUG(llvm::dbgs() << "\nexisting: " << tensor.indexingMap);
1326 parser.emitError(
1327 "Unexpected multi-read of a tensor with different accesses");
1328 failed = true;
1329 return;
1330 }
1331 seenDefs.insert(use.tensorId);
1332 tensor.indexingMap = use.indexingMap;
1333 state.orderedTensorArgs[use] = tensor.index;
1334 });
1335 if (failed)
1336 return failure();
1337
1338 return success();
1339 }
1340
1341 /// Parse and print the information for a ODS def.
1342 ///
1343 /// tensor-def-list ::= tensor-def (`,` tensor-def )*
1344 ///
1345 /// comprehension-list ::= comprehension comprehension*
1346 ///
1347 /// tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)`
1348 /// `{` comprehension-list `}`
1349 ///
1350 /// ods-def ::= `ods_def` `<` bare-id `>` `:` tc-def
1351 ///
1352 /// All the affine-expr in a `tensor-typedef` must be dimensionless (i.e.
1353 /// contain only expressions involving symbols and constants), but can
1354 /// otherwise contain arbitrary affine expressions.
parseAndEmitODSDef(llvm::raw_ostream & os)1355 LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
1356 if (failed(parser.parseToken(Token::Kind::kw_ods_def,
1357 "expected 'ods_def' to define a TC ODS")) ||
1358 failed(parser.parseToken(Token::Kind::lt, "expected '<'")))
1359 return failure();
1360 StringRef cppOpName = parser.curToken.getSpelling();
1361 LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing ODS: " << cppOpName << "\n");
1362
1363 if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
1364 failed(parser.parseToken(Token::Kind::gt, "expected '>'")) ||
1365 failed(parser.parseToken(Token::Kind::colon, "expected ':'")))
1366 return failure();
1367 if (failed(parser.parseToken(Token::Kind::kw_def,
1368 "expected 'def' to define a TC")))
1369 return failure();
1370
1371 StringRef tcName = parser.curToken.getSpelling();
1372 LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing TC: " << tcName << "\n");
1373 if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
1374 failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
1375 return failure();
1376
1377 auto parseInputDef = [&]() -> LogicalResult {
1378 return parseTensorDef(/*isOutput=*/false);
1379 };
1380 if (failed(parser.parseCommaSeparatedListUntil(
1381 Token::Kind::r_paren, parseInputDef, /*allowEmptyList=*/false)))
1382 return failure();
1383
1384 if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")) ||
1385 failed(parser.parseToken(Token::Kind::gt, "expected '>'")) ||
1386 failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
1387 return failure();
1388 auto parseOutputDef = [&]() -> LogicalResult {
1389 return parseTensorDef(/*isOutput=*/true);
1390 };
1391 if (failed(parser.parseCommaSeparatedListUntil(
1392 Token::Kind::r_paren, parseOutputDef, /*allowEmptyList=*/false)))
1393 return failure();
1394
1395 // Since we don't declare symbols separately, we discover them eagerly: each
1396 // newly encountered id in a tensor shape expression is treated as a new
1397 // symbolic. At this point, all tensors have been parsed and all the symbols
1398 // that could be discovered eagerly are now known. Resize all AffineMaps to
1399 // normalize the number of eagerly discovered symbols.
1400 for (auto &tensor : registeredTensors) {
1401 auto &map = tensor.getValue().shape;
1402 map = AffineMap::get(/*dimCount=*/0, symbols.size(), map.getResults(),
1403 parser.context);
1404 }
1405
1406 if (failed(parser.parseToken(Token::Kind::l_brace, "expected '{'")))
1407 return failure();
1408
1409 SmallVector<ComprehensionParsingState, 4> perComprehensionStates;
1410 while (parser.curToken.isNot(Token::Kind::r_brace)) {
1411 perComprehensionStates.push_back(ComprehensionParsingState());
1412 if (failed(parseOneComprehension(cppOpName, tcName,
1413 perComprehensionStates.back())))
1414 return failure();
1415 };
1416 parser.parseToken(Token::Kind::r_brace, "expected '}'");
1417
1418 // Print.
1419 auto nComprehensions = perComprehensionStates.size();
1420 if (nComprehensions != 1) {
1421 parser.emitError("only 1 comprehension supported for now, got: " +
1422 llvm::Twine(nComprehensions));
1423 return failure();
1424 }
1425 if (genODSDecl) {
1426 auto &state = perComprehensionStates.back();
1427 printODS(os, cppOpName, tcName, state);
1428 os << "\n";
1429 }
1430 if (genODSImpl) {
1431 auto &state = perComprehensionStates.back();
1432 std::string extraMethods;
1433 llvm::raw_string_ostream ss(extraMethods);
1434 printReferenceIterators(ss, cppOpName, state);
1435 printReferenceIndexingMaps(ss, cppOpName, state);
1436 printRegionBuilder(ss, cppOpName, state);
1437 printCanonicalizersAndFolders(ss, cppOpName);
1438 ss.flush();
1439 os << extraMethods << "\n";
1440 }
1441
1442 return success();
1443 }
1444
1445 //===----------------------------------------------------------------------===//
1446 // Printing functions
1447 //===----------------------------------------------------------------------===//
1448
1449 /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
///
/// The `header` template below is instantiated with llvm::formatv; `{{` in it
/// is formatv's escape for a literal `{`. The template references these
/// placeholders:
///   - {0} = `cppOpName`;
///   - {1} = `linalgOpName`;
///   - {4} = the number of distinct tensor uses
///     (`state.orderedTensorArgs.size()`), returned by `getNumRegionArgs()`.
printODS(llvm::raw_ostream & os,StringRef cppOpName,StringRef linalgOpName,ComprehensionParsingState & state)1450 void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
1451 StringRef linalgOpName,
1452 ComprehensionParsingState &state) {
// The raw string below is emitted verbatim after placeholder substitution,
// so no comments can be placed inside it.
1453 const char *header = R"FMT( def {0} : LinalgStructuredBase_Op<"{1}", [
1454 AttrSizedOperandSegments,
1455 DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1456 NamedStructuredOpTrait,
1457 SingleBlockImplicitTerminator<"YieldOp">]> {
1458 let arguments = (ins Variadic<AnyShaped>:$inputs,
1459 Variadic<AnyMemRef>:$output_buffers,
1460 Variadic<AnyRankedTensor>:$init_tensors);
1461 let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
1462 let regions = (region AnyRegion:$region);
1463
1464 let skipDefaultBuilders = 1;
1465 let builders = [ OpBuilderDAG<
1466 (ins "ValueRange":$inputs, "ValueRange":$outputBuffers),
1467 [{{
1468 $_state.addOperands(inputs);
1469 $_state.addOperands(outputBuffers);
1470 $_state.addAttribute(
1471 "operand_segment_sizes",
1472 $_builder.getI32VectorAttr({{
1473 static_cast<int32_t>(inputs.size()),
1474 static_cast<int32_t>(outputBuffers.size()),
1475 static_cast<int32_t>(0)}));
1476 buildNamedStructuredOpRegionAndAttributes<{0}>(
1477 $_builder,
1478 $_state,
1479 TypeRange(inputs),
1480 TypeRange(outputBuffers),
1481 TypeRange(),
1482 TypeRange());
1483 }]>, OpBuilderDAG<
1484 (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
1485 "ValueRange":$outputBuffers, "ValueRange":$initTensors),
1486 [{{
1487 $_state.addOperands(inputs);
1488 $_state.addOperands(outputBuffers);
1489 $_state.addOperands(initTensors);
1490 $_state.addTypes(resultTensorTypes);
1491 $_state.addAttribute(
1492 "operand_segment_sizes",
1493 $_builder.getI32VectorAttr({{
1494 static_cast<int32_t>(inputs.size()),
1495 static_cast<int32_t>(outputBuffers.size()),
1496 static_cast<int32_t>(initTensors.size())}));
1497 buildNamedStructuredOpRegionAndAttributes<{0}>(
1498 $_builder,
1499 $_state,
1500 TypeRange(inputs),
1501 TypeRange(outputBuffers),
1502 TypeRange(initTensors),
1503 resultTensorTypes);
1504 }]>, OpBuilderDAG<
1505 (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
1506 CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
1507 [{{
1508 $_state.addOperands(operands);
1509 $_state.addAttributes(attributes);
1510 $_state.addTypes(resultTensorTypes);
1511 (void)$_state.addRegion();
1512 }]>
1513 ];
1514 let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
1515 let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
1516 let verifier = [{{ return ::verifyNamedStructuredOp(*this); }];
1517 let hasFolder = 1;
1518 let hasCanonicalizer = 1;
1519
1520 let extraClassDeclaration = [{{
1521 // Auto-generated.
1522 ArrayAttr iterator_types();
1523 ArrayAttr indexing_maps();
1524 static void regionBuilder(Block &block);
1525 static std::function<void(Block &)> getRegionBuilder() {{ return regionBuilder; }
1526
1527 // Generic methods.
1528 static unsigned getNumRegionArgs() {{ return {4}; }
1529 std::string getLibraryCallName() {{
1530 return generateLibraryCallName(getOperation());
1531 }
1532 }];
1533 })FMT";
1534
// Count input and output tensors declared in the TC signature.
// NOTE(review): `nInputs` ({2}) and `nOutputs` ({3}) are passed to formatv,
// but the `header` template never references {2} or {3}. Confirm whether
// they are still needed.
1535 unsigned nInputs = 0, nOutputs = 0;
1536 for (auto &t : registeredTensors) {
1537 if (t.getValue().isOutput)
1538 nOutputs++;
1539 else
1540 nInputs++;
1541 }
1542
1543 os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs,
1544 state.orderedTensorArgs.size());
1545 }
1546
1547 /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
///
/// Emits `{0}::iterator_types()`, returning one iterator type name per
/// dimension in `state.dims`. The type is the reduction iterator when some
/// TensorExpr of the comprehension lists the dimension's position in its
/// `reductionDimensions`, and the parallel iterator otherwise.
1548 void TCParser::printReferenceIterators(llvm::raw_ostream &os,
1549 StringRef cppOpName,
1550 ComprehensionParsingState &state) {
1551 const char *referenceReferenceIteratorsFmt =
1552 R"FMT(
1553 ArrayAttr {0}::iterator_types() {
1554 return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef, 8>{{ {1} });
1555 })FMT";
1556
1557 std::string iteratorsStr;
1558 llvm::raw_string_ostream ss(iteratorsStr);
// `pos` is the position of the dimension being printed. parseAffineExprs
// appends each newly discovered dimension with position `dims.size()`, so
// the iteration index over `state.dims` equals the AffineDimExpr position.
1559 unsigned pos = 0;
1560 llvm::interleaveComma(
1561 state.dims, ss, [&](std::pair<StringRef, AffineExpr> p) {
// Walk the expression trees until one of them marks `pos` as reduced.
1562 bool reduction = false;
1563 for (auto &expr : state.expressions) {
1564 visitPostorder(*expr, [&](const Expression &e) {
1565 if (auto *pTensorExpr = dyn_cast<TensorExpr>(&e)) {
1566 if (pTensorExpr->reductionDimensions.count(pos) > 0)
1567 reduction = true;
1568 }
1569 });
1570 if (reduction)
1571 break;
1572 }
1573 ss << (reduction ? "getReductionIteratorTypeName()"
1574 : "getParallelIteratorTypeName()");
1575 pos++;
1576 });
1577 ss.flush();
1578
1579 os << llvm::formatv(referenceReferenceIteratorsFmt, cppOpName, iteratorsStr);
1580 }
1581
/// Print the C++ impl of the canonicalization patterns, folder and memory
/// effects of the named op `cppOpName`. The ODS emitted by printODS sets
/// `hasFolder`/`hasCanonicalizer` and declares the MemoryEffectsOpInterface
/// methods. The generated code:
///   - registers the EraseDeadLinalgOp and FoldTensorCastOp patterns;
///   - folds through foldMemRefCast;
///   - reports effects through getGenericEffectsImpl.
1582 void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
1583 StringRef cppOpName) {
1584 const char *canonicalizersAndFoldersFmt = R"FMT(
1585 void {0}::getCanonicalizationPatterns(
1586 OwningRewritePatternList &results,
1587 MLIRContext *context) {{
1588 results.insert<EraseDeadLinalgOp>();
1589 results.insert<FoldTensorCastOp>();
1590 }
1591 LogicalResult {0}::fold(ArrayRef<Attribute>,
1592 SmallVectorImpl<OpFoldResult> &) {{
1593 return foldMemRefCast(*this);
1594 }
1595 void {0}::getEffects(SmallVectorImpl<
1596 SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
1597 getGenericEffectsImpl(effects,
1598 getOperation()->getResults(), getInputBuffers(), getOutputBuffers());
1599 })FMT";
1600 os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName);
1601 }
1602
1603 /// Print the C++ StructuredOpsInterface impl of `indexing_maps`.
///
/// Emits `{0}::indexing_maps()`, which returns one AffineMap per distinct
/// tensor use, ordered by tensor registration index.
1604 void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
1605 StringRef cppOpName,
1606 ComprehensionParsingState &state) {
1607 // 1. Generic string template for specifying reference indexing maps.
1608 const char *referenceIndexingMapsFmt =
1609 R"FMT(
1610 // This is temporary until we transition out of manually specified ops that
1611 // should be auto-generated with linalg-ods-gen.
1612 ArrayAttr {0}::indexing_maps() {
1613 MLIRContext *context = getContext();
1614 AffineExpr {1};
1615 bindDims(context, {1});
1616 return Builder(context).getAffineMapArrayAttr({ {2} });
1617 })FMT";
1618
1619 // 2. Print a comma-separated list of identifiers for the AffineExpr in
1620 // `state.dims`. These will replace the `{1}` placeholder in both
1621 // `AffineExpr {1}` and `bindDims(context, {1})` ensuring the AffineExpr
1622 // identifiers are bound in the right order to the proper AffineDimExpr.
1623 std::string dimsStr;
1624 llvm::raw_string_ostream ss(dimsStr);
1625 llvm::interleaveComma(
1626 state.dims, ss,
1627 [&](std::pair<StringRef, AffineExpr> p) { ss << p.second; });
1628 ss.flush();
1629
1630 // 3. Print a comma-separated list of AffineMap constructors that use the
1631 // identifiers from 2. The AffineExpr use the common arithmetic operators on
1632 // AffineExpr. These AffineMap constructors will replace the `{2}` placeholder
1633 // in `getAffineMapArrayAttr({ {2} })`.
1634 std::string mapsStr;
1635 llvm::raw_string_ostream mapsStringStream(mapsStr);
// NOTE(review): `orderedUses` is sized by the number of distinct uses but
// indexed by tensor registration index (`it.second`). This is only safe
// when every registered tensor has exactly one entry in `orderedTensorArgs`.
// A tensor without any use could push an index past the end.
1636 SmallVector<TensorUse, 4> orderedUses(state.orderedTensorArgs.size());
1637 for (const auto &it : state.orderedTensorArgs)
1638 orderedUses[it.second] = it.first;
1639 llvm::interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) {
1640 assert(u.indexingMap);
1641 const char *mapFmt = "\n\tAffineMap::get({0}, 0, {1}, context)";
// NOTE(review): in this branch `{1}` is substituted with the text "context",
// producing `AffineMap::get(<n>, 0, context, context)`. Presumably
// `AffineMap::get(<n>, 0, context)` was intended; verify before relying on
// this path.
1642 if (u.indexingMap.isEmpty()) {
1643 mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), "context");
1644 return;
1645 }
1646
// Print the map results as a braced list of AffineExpr arithmetic over the
// identifiers bound in step 2.
1647 std::string exprsStr;
1648 llvm::raw_string_ostream exprsStringStream(exprsStr);
1649 exprsStringStream << "{";
1650 llvm::interleaveComma(u.indexingMap.getResults(), exprsStringStream);
1651 exprsStringStream << "}";
1652 exprsStringStream.flush();
1653
1654 mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), exprsStr);
1655 });
1656 mapsStringStream.flush();
1657
1658 // 4. Apply format to 1. using 2. and 3.
1659 os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr);
1660 }
1661
1662 /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
///
/// Emits `{0}::regionBuilder(Block &block)`, which:
///   - binds block arguments to `_0`, `_1`, ... (one per entry of
///     `state.orderedTensorArgs`);
///   - materializes every TensorExpr once, in post-order, as a
///     `Value _<n> = <op>(<operands>);` statement;
///   - yields the top-level expressions.
1663 void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
1664 ComprehensionParsingState &state) {
// Temporaries are numbered after the tensor arguments. `count` is
// pre-incremented, so with n arguments the first temporary is `_<n+1>`.
1665 unsigned count = state.orderedTensorArgs.size();
// Maps each already-emitted TensorExpr node to the number of its `_<n>`
// value, so later references print the name instead of re-emitting it.
1666 llvm::DenseMap<const TensorExpr *, unsigned> subExprsMap;
// printExpr writes a *reference* (`_<index>`) for a TensorUse or an
// already-emitted TensorExpr. For a not-yet-emitted TensorExpr it writes
// the *definition* statement to `os` and records it. Callers rely on
// post-order traversal so operands are emitted before their users.
1667 std::function<void(llvm::raw_ostream & os, const Expression &)> printExpr;
1668 printExpr = [&](llvm::raw_ostream &os, const Expression &e) -> void {
1669 if (auto *pUse = dyn_cast<TensorUse>(&e)) {
1670 os << "_" << state.orderedTensorArgs.find(*pUse)->second;
1671 return;
1672 }
1673 auto *pTensorExpr = cast<TensorExpr>(&e);
1674 if (subExprsMap.count(pTensorExpr) > 0) {
1675 os << "_" << subExprsMap[pTensorExpr];
1676 } else {
1677 std::string subExprs;
1678 llvm::raw_string_ostream subExprsStringStream(subExprs);
1679 llvm::interleaveComma(pTensorExpr->expressions, subExprsStringStream,
1680 [&](const std::unique_ptr<Expression> &e) {
1681 printExpr(subExprsStringStream, *e);
1682 });
1683 subExprsStringStream.flush();
1684 const char *tensorExprFmt = "\n Value _{0} = {1}({2});";
1685 os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->operationName,
1686 subExprs);
1687 subExprsMap[pTensorExpr] = count;
1688 }
1689 };
1690
1691 const char *regionBuilderFmt = R"FMT(
1692 void {0}::regionBuilder(Block &block) {
1693 using namespace edsc;
1694 using namespace intrinsics;
1695 auto args = block.getArguments();
1696 Value {1};
1697 {2}
1698 (linalg_yield(ValueRange{ {3} }));
1699 })FMT";
1700
// Bind `_i(args[i])` for every tensor argument. Only the number of entries
// in `orderedTensorArgs` matters here; the visited element is ignored, so
// the DenseMap iteration order does not affect the output.
1701 unsigned idx = 0;
1702 std::string valueHandleStr;
1703 llvm::raw_string_ostream valueHandleStringStream(valueHandleStr);
1704 llvm::interleaveComma(
1705 state.orderedTensorArgs, valueHandleStringStream, [&](auto) {
1706 valueHandleStringStream << "_" << idx << "(args[" << idx << "])";
1707 idx++;
1708 });
1709
// Emit each TensorExpr definition once, operands before users.
1710 std::string expressionsStr;
1711 llvm::raw_string_ostream expressionStringStream(expressionsStr);
1712 for (auto &expr : state.expressions)
1713 visitPostorder(*expr, [&](const Expression &e) {
1714 if (e.kind == Expression::Kind::TensorExpr)
1715 printExpr(expressionStringStream, e);
1716 });
1717
// Every top-level TensorExpr was emitted above, so only names are printed.
1718 std::string yieldStr;
1719 llvm::raw_string_ostream yieldStringStream(yieldStr);
1720 llvm::interleaveComma(state.expressions, yieldStringStream,
1721 [&](const std::unique_ptr<Expression> &e) {
1722 printExpr(yieldStringStream, *e);
1723 });
1724
1725 valueHandleStringStream.flush();
1726 expressionStringStream.flush();
1727 yieldStringStream.flush();
1728
1729 os << llvm::formatv(regionBuilderFmt, cppOpName, valueHandleStr,
1730 expressionsStr, yieldStr);
1731 }
1732
1733 /// Iterate over each Tensor Comprehension def.
1734 LogicalResult parseAndEmitAllTensorComprehensions(llvm::raw_ostream &os,
1735 Parser &parser) {
1736 while (parser.curToken.getKind() != Token::Kind::eof) {
1737 TCParser tcParser(parser);
1738 if (failed(tcParser.parseAndEmitODSDef(os)))
1739 return failure();
1740 }
1741 return success();
1742 }
1743
1744 int main(int argc, char **argv) {
1745 llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen");
1746
1747 // Set up the input file.
1748 std::string errorMessage;
1749 std::unique_ptr<llvm::MemoryBuffer> file =
1750 mlir::openInputFile(inputFilename, &errorMessage);
1751 if (!file) {
1752 llvm::errs() << errorMessage << "\n";
1753 return 1;
1754 }
1755
1756 std::unique_ptr<llvm::ToolOutputFile> output =
1757 openOutputFile(outputFilename, &errorMessage);
1758 if (!output) {
1759 llvm::errs() << errorMessage << "\n";
1760 exit(1);
1761 }
1762
1763 // Include the proper Linalg header for end-to-end tblgen testing without
1764 // resorting to non-portable shell manipulations.
1765 if (testEmitIncludeTdHeader)
1766 output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\"";
1767
1768 MLIRContext context;
1769 llvm::SourceMgr mgr;
1770 mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
1771 Parser parser(mgr, &context);
1772 parseAndEmitAllTensorComprehensions(output->os(), parser);
1773 output->keep();
1774
1775 return 0;
1776 }
1777