1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Support/IndentedOstream.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "mlir/TableGen/Pattern.h"
19 #include "mlir/TableGen/Predicate.h"
20 #include "mlir/TableGen/Type.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringSet.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatAdapters.h"
26 #include "llvm/Support/PrettyStackTrace.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Main.h"
30 #include "llvm/TableGen/Record.h"
31 #include "llvm/TableGen/TableGenBackend.h"
32
33 using namespace mlir;
34 using namespace mlir::tblgen;
35
36 using llvm::formatv;
37 using llvm::Record;
38 using llvm::RecordKeeper;
39
40 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
41
42 namespace llvm {
43 template <>
44 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
formatllvm::format_provider45 static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
46 raw_ostream &os, StringRef style) {
47 os << v.first << ":" << v.second;
48 }
49 };
50 } // end namespace llvm
51
52 //===----------------------------------------------------------------------===//
53 // PatternEmitter
54 //===----------------------------------------------------------------------===//
55
56 namespace {
57 class PatternEmitter {
58 public:
59 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
60
61 // Emits the mlir::RewritePattern struct named `rewriteName`.
62 void emit(StringRef rewriteName);
63
64 private:
65 // Emits the code for matching ops.
66 void emitMatchLogic(DagNode tree, StringRef opName);
67
68 // Emits the code for rewriting ops.
69 void emitRewriteLogic();
70
71 //===--------------------------------------------------------------------===//
72 // Match utilities
73 //===--------------------------------------------------------------------===//
74
75 // Emits C++ statements for matching the DAG structure.
76 void emitMatch(DagNode tree, StringRef name, int depth);
77
78 // Emits C++ statements for matching using a native code call.
79 void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
80
81 // Emits C++ statements for matching the op constrained by the given DAG
82 // `tree` returning the op's variable name.
83 void emitOpMatch(DagNode tree, StringRef opName, int depth);
84
85 // Emits C++ statements for matching the `argIndex`-th argument of the given
86 // DAG `tree` as an operand.
87 void emitOperandMatch(DagNode tree, StringRef opName, int argIndex,
88 int depth);
89
90 // Emits C++ statements for matching the `argIndex`-th argument of the given
91 // DAG `tree` as an attribute.
92 void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
93 int depth);
94
95 // Emits C++ for checking a match with a corresponding match failure
96 // diagnostic.
97 void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
98 const llvm::formatv_object_base &failureFmt);
99
100 // Emits C++ for checking a match with a corresponding match failure
101 // diagnostics.
102 void emitMatchCheck(StringRef opName, const std::string &matchStr,
103 const std::string &failureStr);
104
105 //===--------------------------------------------------------------------===//
106 // Rewrite utilities
107 //===--------------------------------------------------------------------===//
108
109 // The entry point for handling a result pattern rooted at `resultTree`. This
110 // method dispatches to concrete handlers according to `resultTree`'s kind and
111 // returns a symbol representing the whole value pack. Callers are expected to
112 // further resolve the symbol according to the specific use case.
113 //
114 // `depth` is the nesting level of `resultTree`; 0 means top-level result
115 // pattern. For top-level result pattern, `resultIndex` indicates which result
116 // of the matched root op this pattern is intended to replace, which can be
117 // used to deduce the result type of the op generated from this result
118 // pattern.
119 std::string handleResultPattern(DagNode resultTree, int resultIndex,
120 int depth);
121
122 // Emits the C++ statement to replace the matched DAG with a value built via
123 // calling native C++ code.
124 std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
125
126 // Returns the symbol of the old value serving as the replacement.
127 StringRef handleReplaceWithValue(DagNode tree);
128
129 // Returns the location value to use.
130 std::pair<bool, std::string> getLocation(DagNode tree);
131
132 // Returns the location value to use.
133 std::string handleLocationDirective(DagNode tree);
134
135 // Emits the C++ statement to build a new op out of the given DAG `tree` and
136 // returns the variable name that this op is assigned to. If the root op in
137 // DAG `tree` has a specified name, the created op will be assigned to a
138 // variable of the given name. Otherwise, a unique name will be used as the
139 // result value name.
140 std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
141
142 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
143
144 // Emits a local variable for each value and attribute to be used for creating
145 // an op.
146 void createSeparateLocalVarsForOpArgs(DagNode node,
147 ChildNodeIndexNameMap &childNodeNames);
148
149 // Emits the concrete arguments used to call an op's builder.
150 void supplyValuesForOpArgs(DagNode node,
151 const ChildNodeIndexNameMap &childNodeNames,
152 int depth);
153
154 // Emits the local variables for holding all values as a whole and all named
155 // attributes as a whole to be used for creating an op.
156 void createAggregateLocalVarsForOpArgs(
157 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
158
159 // Returns the C++ expression to construct a constant attribute of the given
160 // `value` for the given attribute kind `attr`.
161 std::string handleConstantAttr(Attribute attr, StringRef value);
162
163 // Returns the C++ expression to build an argument from the given DAG `leaf`.
164 // `patArgName` is used to bound the argument to the source pattern.
165 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
166
167 //===--------------------------------------------------------------------===//
168 // General utilities
169 //===--------------------------------------------------------------------===//
170
171 // Collects all of the operations within the given dag tree.
172 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
173
174 // Returns a unique symbol for a local variable of the given `op`.
175 std::string getUniqueSymbol(const Operator *op);
176
177 //===--------------------------------------------------------------------===//
178 // Symbol utilities
179 //===--------------------------------------------------------------------===//
180
181 // Returns how many static values the given DAG `node` correspond to.
182 int getNodeValueCount(DagNode node);
183
184 private:
185 // Pattern instantiation location followed by the location of multiclass
186 // prototypes used. This is intended to be used as a whole to
187 // PrintFatalError() on errors.
188 ArrayRef<llvm::SMLoc> loc;
189
190 // Op's TableGen Record to wrapper object.
191 RecordOperatorMap *opMap;
192
193 // Handy wrapper for pattern being emitted.
194 Pattern pattern;
195
196 // Map for all bound symbols' info.
197 SymbolInfoMap symbolInfoMap;
198
199 // The next unused ID for newly created values.
200 unsigned nextValueId;
201
202 raw_indented_ostream os;
203
204 // Format contexts containing placeholder substitutions.
205 FmtContext fmtCtx;
206
207 // Number of op processed.
208 int opCounter = 0;
209 };
210 } // end anonymous namespace
211
PatternEmitter(Record * pat,RecordOperatorMap * mapper,raw_ostream & os)212 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
213 raw_ostream &os)
214 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
215 symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
216 fmtCtx.withBuilder("rewriter");
217 }
218
handleConstantAttr(Attribute attr,StringRef value)219 std::string PatternEmitter::handleConstantAttr(Attribute attr,
220 StringRef value) {
221 if (!attr.isConstBuildable())
222 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
223 " does not have the 'constBuilderCall' field");
224
225 // TODO: Verify the constants here
226 return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
227 }
228
229 // Helper function to match patterns.
emitMatch(DagNode tree,StringRef name,int depth)230 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
231 if (tree.isNativeCodeCall()) {
232 emitNativeCodeMatch(tree, name, depth);
233 return;
234 }
235
236 if (tree.isOperation()) {
237 emitOpMatch(tree, name, depth);
238 return;
239 }
240
241 PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
242 }
243
244 // Helper function to match patterns.
emitNativeCodeMatch(DagNode tree,StringRef opName,int depth)245 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
246 int depth) {
247 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
248 LLVM_DEBUG(tree.print(llvm::dbgs()));
249 LLVM_DEBUG(llvm::dbgs() << '\n');
250
251 // TODO(suderman): iterate through arguments, determine their types, output
252 // names.
253 SmallVector<std::string, 8> capture(8);
254 if (tree.getNumArgs() > 8) {
255 PrintFatalError(loc,
256 "unsupported NativeCodeCall matcher argument numbers: " +
257 Twine(tree.getNumArgs()));
258 }
259
260 raw_indented_ostream::DelimitedScope scope(os);
261
262 os << "if(!" << opName << ") return failure();\n";
263 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
264 std::string argName = formatv("arg{0}_{1}", depth, i);
265 if (DagNode argTree = tree.getArgAsNestedDag(i)) {
266 os << "Value " << argName << ";\n";
267 } else {
268 auto leaf = tree.getArgAsLeaf(i);
269 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
270 os << "Attribute " << argName << ";\n";
271 } else if (leaf.isOperandMatcher()) {
272 os << "Operation " << argName << ";\n";
273 }
274 }
275
276 capture[i] = std::move(argName);
277 }
278
279 bool hasLocationDirective;
280 std::string locToUse;
281 std::tie(hasLocationDirective, locToUse) = getLocation(tree);
282
283 auto fmt = tree.getNativeCodeTemplate();
284 auto nativeCodeCall = std::string(tgfmt(
285 fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1],
286 capture[2], capture[3], capture[4], capture[5], capture[6], capture[7]));
287
288 os << "if (failed(" << nativeCodeCall << ")) return failure();\n";
289
290 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
291 auto name = tree.getArgName(i);
292 if (!name.empty() && name != "_") {
293 os << formatv("{0} = {1};\n", name, capture[i]);
294 }
295 }
296
297 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
298 std::string argName = capture[i];
299
300 // Handle nested DAG construct first
301 if (DagNode argTree = tree.getArgAsNestedDag(i)) {
302 PrintFatalError(
303 loc, formatv("Matching nested tree in NativeCodecall not support for "
304 "{0} as arg {1}",
305 argName, i));
306 }
307
308 DagLeaf leaf = tree.getArgAsLeaf(i);
309 auto constraint = leaf.getAsConstraint();
310
311 auto self = formatv("{0}", argName);
312 emitMatchCheck(
313 opName,
314 tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
315 formatv("\"operand {0} of native code call '{1}' failed to satisfy "
316 "constraint: "
317 "'{2}'\"",
318 i, tree.getNativeCodeTemplate(), constraint.getDescription()));
319 }
320
321 LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
322 }
323
324 // Helper function to match patterns.
emitOpMatch(DagNode tree,StringRef opName,int depth)325 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
326 Operator &op = tree.getDialectOp(opMap);
327 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
328 << op.getOperationName() << "' at depth " << depth
329 << '\n');
330
331 std::string castedName = formatv("castedOp{0}", depth);
332 os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); "
333 "(void){0};\n",
334 castedName, opName, op.getQualCppClassName());
335 // Skip the operand matching at depth 0 as the pattern rewriter already does.
336 if (depth != 0) {
337 // Skip if there is no defining operation (e.g., arguments to function).
338 os << formatv("if (!{0}) return failure();\n", castedName);
339 }
340 if (tree.getNumArgs() != op.getNumArgs()) {
341 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
342 "pattern vs. {2} in definition",
343 op.getOperationName(), tree.getNumArgs(),
344 op.getNumArgs()));
345 }
346
347 // If the operand's name is set, set to that variable.
348 auto name = tree.getSymbol();
349 if (!name.empty())
350 os << formatv("{0} = {1};\n", name, castedName);
351
352 for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) {
353 auto opArg = op.getArg(i);
354 std::string argName = formatv("op{0}", depth + 1);
355
356 // Handle nested DAG construct first
357 if (DagNode argTree = tree.getArgAsNestedDag(i)) {
358 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
359 if (operand->isVariableLength()) {
360 auto error = formatv("use nested DAG construct to match op {0}'s "
361 "variadic operand #{1} unsupported now",
362 op.getOperationName(), i);
363 PrintFatalError(loc, error);
364 }
365 }
366 os << "{\n";
367
368 // Attributes don't count for getODSOperands.
369 os.indent() << formatv(
370 "auto *{0} = "
371 "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
372 argName, castedName, nextOperand++);
373 emitMatch(argTree, argName, depth + 1);
374 os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
375 os.unindent() << "}\n";
376 continue;
377 }
378
379 // Next handle DAG leaf: operand or attribute
380 if (opArg.is<NamedTypeConstraint *>()) {
381 // emitOperandMatch's argument indexing counts attributes.
382 emitOperandMatch(tree, castedName, i, depth);
383 ++nextOperand;
384 } else if (opArg.is<NamedAttribute *>()) {
385 emitAttributeMatch(tree, opName, i, depth);
386 } else {
387 PrintFatalError(loc, "unhandled case when matching op");
388 }
389 }
390 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
391 << op.getOperationName() << "' at depth " << depth
392 << '\n');
393 }
394
emitOperandMatch(DagNode tree,StringRef opName,int argIndex,int depth)395 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
396 int argIndex, int depth) {
397 Operator &op = tree.getDialectOp(opMap);
398 auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
399 auto matcher = tree.getArgAsLeaf(argIndex);
400
401 // If a constraint is specified, we need to generate C++ statements to
402 // check the constraint.
403 if (!matcher.isUnspecified()) {
404 if (!matcher.isOperandMatcher()) {
405 PrintFatalError(
406 loc, formatv("the {1}-th argument of op '{0}' should be an operand",
407 op.getOperationName(), argIndex + 1));
408 }
409
410 // Only need to verify if the matcher's type is different from the one
411 // of op definition.
412 Constraint constraint = matcher.getAsConstraint();
413 if (operand->constraint != constraint) {
414 if (operand->isVariableLength()) {
415 auto error = formatv(
416 "further constrain op {0}'s variadic operand #{1} unsupported now",
417 op.getOperationName(), argIndex);
418 PrintFatalError(loc, error);
419 }
420 auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
421 opName, argIndex);
422 emitMatchCheck(
423 opName,
424 tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
425 formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
426 "'{2}'\"",
427 operand - op.operand_begin(), op.getOperationName(),
428 constraint.getDescription()));
429 }
430 }
431
432 // Capture the value
433 auto name = tree.getArgName(argIndex);
434 // `$_` is a special symbol to ignore op argument matching.
435 if (!name.empty() && name != "_") {
436 // We need to subtract the number of attributes before this operand to get
437 // the index in the operand list.
438 auto numPrevAttrs = std::count_if(
439 op.arg_begin(), op.arg_begin() + argIndex,
440 [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
441
442 auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex);
443 os << formatv("{0} = {1}.getODSOperands({2});\n",
444 res->second.getVarName(name), opName,
445 argIndex - numPrevAttrs);
446 }
447 }
448
emitAttributeMatch(DagNode tree,StringRef opName,int argIndex,int depth)449 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
450 int argIndex, int depth) {
451 Operator &op = tree.getDialectOp(opMap);
452 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
453 const auto &attr = namedAttr->attr;
454
455 os << "{\n";
456 os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
457 "(void)tblgen_attr;\n",
458 opName, attr.getStorageType(), namedAttr->name);
459
460 // TODO: This should use getter method to avoid duplication.
461 if (attr.hasDefaultValue()) {
462 os << "if (!tblgen_attr) tblgen_attr = "
463 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
464 attr.getDefaultValue()))
465 << ";\n";
466 } else if (attr.isOptional()) {
467 // For a missing attribute that is optional according to definition, we
468 // should just capture a mlir::Attribute() to signal the missing state.
469 // That is precisely what getAttr() returns on missing attributes.
470 } else {
471 emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
472 formatv("\"expected op '{0}' to have attribute '{1}' "
473 "of type '{2}'\"",
474 op.getOperationName(), namedAttr->name,
475 attr.getStorageType()));
476 }
477
478 auto matcher = tree.getArgAsLeaf(argIndex);
479 if (!matcher.isUnspecified()) {
480 if (!matcher.isAttrMatcher()) {
481 PrintFatalError(
482 loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
483 op.getOperationName(), argIndex + 1));
484 }
485
486 // If a constraint is specified, we need to generate C++ statements to
487 // check the constraint.
488 emitMatchCheck(
489 opName,
490 tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
491 formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
492 "{2}\"",
493 op.getOperationName(), namedAttr->name,
494 matcher.getAsConstraint().getDescription()));
495 }
496
497 // Capture the value
498 auto name = tree.getArgName(argIndex);
499 // `$_` is a special symbol to ignore op argument matching.
500 if (!name.empty() && name != "_") {
501 os << formatv("{0} = tblgen_attr;\n", name);
502 }
503
504 os.unindent() << "}\n";
505 }
506
emitMatchCheck(StringRef opName,const FmtObjectBase & matchFmt,const llvm::formatv_object_base & failureFmt)507 void PatternEmitter::emitMatchCheck(
508 StringRef opName, const FmtObjectBase &matchFmt,
509 const llvm::formatv_object_base &failureFmt) {
510 emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
511 }
512
emitMatchCheck(StringRef opName,const std::string & matchStr,const std::string & failureStr)513 void PatternEmitter::emitMatchCheck(StringRef opName,
514 const std::string &matchStr,
515 const std::string &failureStr) {
516
517 os << "if (!(" << matchStr << "))";
518 os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
519 << ", [&](::mlir::Diagnostic &diag) {\n diag << "
520 << failureStr << ";\n});";
521 }
522
emitMatchLogic(DagNode tree,StringRef opName)523 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
524 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
525 int depth = 0;
526 emitMatch(tree, opName, depth);
527
528 for (auto &appliedConstraint : pattern.getConstraints()) {
529 auto &constraint = appliedConstraint.constraint;
530 auto &entities = appliedConstraint.entities;
531
532 auto condition = constraint.getConditionTemplate();
533 if (isa<TypeConstraint>(constraint)) {
534 auto self = formatv("({0}.getType())",
535 symbolInfoMap.getValueAndRangeUse(entities.front()));
536 emitMatchCheck(
537 opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
538 formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
539 entities.front(), constraint.getDescription()));
540
541 } else if (isa<AttrConstraint>(constraint)) {
542 PrintFatalError(
543 loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
544 } else {
545 // TODO: replace formatv arguments with the exact specified
546 // args.
547 if (entities.size() > 4) {
548 PrintFatalError(loc, "only support up to 4-entity constraints now");
549 }
550 SmallVector<std::string, 4> names;
551 int i = 0;
552 for (int e = entities.size(); i < e; ++i)
553 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
554 std::string self = appliedConstraint.self;
555 if (!self.empty())
556 self = symbolInfoMap.getValueAndRangeUse(self);
557 for (; i < 4; ++i)
558 names.push_back("<unused>");
559 emitMatchCheck(opName,
560 tgfmt(condition, &fmtCtx.withSelf(self), names[0],
561 names[1], names[2], names[3]),
562 formatv("\"entities '{0}' failed to satisfy constraint: "
563 "{1}\"",
564 llvm::join(entities, ", "),
565 constraint.getDescription()));
566 }
567 }
568
569 // Some of the operands could be bound to the same symbol name, we need
570 // to enforce equality constraint on those.
571 // TODO: we should be able to emit equality checks early
572 // and short circuit unnecessary work if vars are not equal.
573 for (auto symbolInfoIt = symbolInfoMap.begin();
574 symbolInfoIt != symbolInfoMap.end();) {
575 auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
576 auto startRange = range.first;
577 auto endRange = range.second;
578
579 auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
580 for (++startRange; startRange != endRange; ++startRange) {
581 auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
582 emitMatchCheck(
583 opName,
584 formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
585 formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
586 secondOperand));
587 }
588
589 symbolInfoIt = endRange;
590 }
591
592 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
593 }
594
collectOps(DagNode tree,llvm::SmallPtrSetImpl<const Operator * > & ops)595 void PatternEmitter::collectOps(DagNode tree,
596 llvm::SmallPtrSetImpl<const Operator *> &ops) {
597 // Check if this tree is an operation.
598 if (tree.isOperation()) {
599 const Operator &op = tree.getDialectOp(opMap);
600 LLVM_DEBUG(llvm::dbgs()
601 << "found operation " << op.getOperationName() << '\n');
602 ops.insert(&op);
603 }
604
605 // Recurse the arguments of the tree.
606 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
607 if (auto child = tree.getArgAsNestedDag(i))
608 collectOps(child, ops);
609 }
610
// Emits one `::mlir::RewritePattern` subclass named `rewriteName` for the
// current pattern. The output consists of:
//  - a C++ comment listing where the pattern was defined;
//  - a constructor that passes the root op name, the sorted names of the ops
//    the result patterns can create, and the pattern benefit to the base
//    class;
//  - a matchAndRewrite() override holding the capture variables and the
//    match logic, followed by the rewrite logic.
emit(StringRef rewriteName)611 void PatternEmitter::emit(StringRef rewriteName) {
612 // Get the DAG tree for the source pattern.
613 DagNode sourceTree = pattern.getSourcePattern();
614
615 const Operator &rootOp = pattern.getSourceRootOp();
616 auto rootName = rootOp.getOperationName();
617
618 // Collect the set of result operations.
619 llvm::SmallPtrSet<const Operator *, 4> resultOps;
620 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
621 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
622 collectOps(pattern.getResultPattern(i), resultOps);
623 }
624 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
625
626 // Emit RewritePattern for Pattern.
// The definition locations are printed in the reverse of the order returned
// by pattern.getLocation().
627 auto locs = pattern.getLocation();
628 os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n",
629 make_range(locs.rbegin(), locs.rend()));
630 os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
631 {0}(::mlir::MLIRContext *context)
632 : ::mlir::RewritePattern("{1}", {{)",
633 rewriteName, rootName);
634 // Sort result operators by name.
635 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
636 resultOps.end());
637 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
638 return lhs->getOperationName() < rhs->getOperationName();
639 });
640 llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
641 os << '"' << op->getOperationName() << '"';
642 });
// Close the list of generated op names, then pass the pattern benefit and the
// context to the RewritePattern constructor.
643 os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
644
645 // Emit matchAndRewrite() function.
646 {
647 auto classScope = os.scope();
648 os.reindent(R"(
649 ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
650 ::mlir::PatternRewriter &rewriter) const override {)")
651 << '\n';
652 {
653 auto functionScope = os.scope();
654
655 // Register all symbols bound in the source pattern.
656 pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
657
658 LLVM_DEBUG(llvm::dbgs()
659 << "start creating local variables for capturing matches\n");
660 os << "// Variables for capturing values and attributes used while "
661 "creating ops\n";
662 // Create local variables for storing the arguments and results bound
663 // to symbols.
664 for (const auto &symbolInfoPair : symbolInfoMap) {
665 const auto &symbol = symbolInfoPair.first;
666 const auto &info = symbolInfoPair.second;
667
668 os << info.getVarDecl(symbol);
669 }
670 // TODO: capture ops with consistent numbering so that it can be
671 // reused for fused loc.
// One slot per op in the source DAG; the root op is stored in slot 0 below.
672 os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
673 pattern.getSourcePattern().getNumOps());
674 LLVM_DEBUG(llvm::dbgs()
675 << "done creating local variables for capturing matches\n");
676
677 os << "// Match\n";
678 os << "tblgen_ops[0] = op0;\n";
679 emitMatchLogic(sourceTree, "op0");
680
681 os << "\n// Rewrite\n";
682 emitRewriteLogic();
683
684 os << "return ::mlir::success();\n";
685 }
// Close the generated matchAndRewrite() body.
686 os << "};\n";
687 }
// Close the generated pattern struct.
688 os << "};\n\n";
689 }
690
emitRewriteLogic()691 void PatternEmitter::emitRewriteLogic() {
692 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
693 const Operator &rootOp = pattern.getSourceRootOp();
694 int numExpectedResults = rootOp.getNumResults();
695 int numResultPatterns = pattern.getNumResultPatterns();
696
697 // First register all symbols bound to ops generated in result patterns.
698 pattern.collectResultPatternBoundSymbols(symbolInfoMap);
699
700 // Only the last N static values generated are used to replace the matched
701 // root N-result op. We need to calculate the starting index (of the results
702 // of the matched op) each result pattern is to replace.
703 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
704 // If we don't need to replace any value at all, set the replacement starting
705 // index as the number of result patterns so we skip all of them when trying
706 // to replace the matched op's results.
707 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
708 for (int i = numResultPatterns - 1; i >= 0; --i) {
709 auto numValues = getNodeValueCount(pattern.getResultPattern(i));
710 offsets[i] = offsets[i + 1] - numValues;
711 if (offsets[i] == 0) {
712 if (replStartIndex == -1)
713 replStartIndex = i;
714 } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
715 auto error = formatv(
716 "cannot use the same multi-result op '{0}' to generate both "
717 "auxiliary values and values to be used for replacing the matched op",
718 pattern.getResultPattern(i).getSymbol());
719 PrintFatalError(loc, error);
720 }
721 }
722
723 if (offsets.front() > 0) {
724 const char error[] = "no enough values generated to replace the matched op";
725 PrintFatalError(loc, error);
726 }
727
728 os << "auto odsLoc = rewriter.getFusedLoc({";
729 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
730 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
731 }
732 os << "}); (void)odsLoc;\n";
733
734 // Process auxiliary result patterns.
735 for (int i = 0; i < replStartIndex; ++i) {
736 DagNode resultTree = pattern.getResultPattern(i);
737 auto val = handleResultPattern(resultTree, offsets[i], 0);
738 // Normal op creation will be streamed to `os` by the above call; but
739 // NativeCodeCall will only be materialized to `os` if it is used. Here
740 // we are handling auxiliary patterns so we want the side effect even if
741 // NativeCodeCall is not replacing matched root op's results.
742 if (resultTree.isNativeCodeCall())
743 os << val << ";\n";
744 }
745
746 if (numExpectedResults == 0) {
747 assert(replStartIndex >= numResultPatterns &&
748 "invalid auxiliary vs. replacement pattern division!");
749 // No result to replace. Just erase the op.
750 os << "rewriter.eraseOp(op0);\n";
751 } else {
752 // Process replacement result patterns.
753 os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
754 for (int i = replStartIndex; i < numResultPatterns; ++i) {
755 DagNode resultTree = pattern.getResultPattern(i);
756 auto val = handleResultPattern(resultTree, offsets[i], 0);
757 os << "\n";
758 // Resolve each symbol for all range use so that we can loop over them.
759 // We need an explicit cast to `SmallVector` to capture the cases where
760 // `{0}` resolves to an `Operation::result_range` as well as cases that
761 // are not iterable (e.g. vector that gets wrapped in additional braces by
762 // RewriterGen).
763 // TODO: Revisit the need for materializing a vector.
764 os << symbolInfoMap.getAllRangeUse(
765 val,
766 "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
767 " tblgen_repl_values.push_back(v);\n}\n",
768 "\n");
769 }
770 os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
771 }
772
773 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
774 }
775
getUniqueSymbol(const Operator * op)776 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
777 return std::string(
778 formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
779 }
780
handleResultPattern(DagNode resultTree,int resultIndex,int depth)781 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
782 int resultIndex, int depth) {
783 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
784 LLVM_DEBUG(resultTree.print(llvm::dbgs()));
785 LLVM_DEBUG(llvm::dbgs() << '\n');
786
787 if (resultTree.isLocationDirective()) {
788 PrintFatalError(loc,
789 "location directive can only be used with op creation");
790 }
791
792 if (resultTree.isNativeCodeCall()) {
793 auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth);
794 symbolInfoMap.bindValue(symbol);
795 return symbol;
796 }
797
798 if (resultTree.isReplaceWithValue())
799 return handleReplaceWithValue(resultTree).str();
800
801 // Normal op creation.
802 auto symbol = handleOpCreation(resultTree, resultIndex, depth);
803 if (resultTree.getSymbol().empty()) {
804 // This is an op not explicitly bound to a symbol in the rewrite rule.
805 // Register the auto-generated symbol for it.
806 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
807 }
808 return symbol;
809 }
810
handleReplaceWithValue(DagNode tree)811 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
812 assert(tree.isReplaceWithValue());
813
814 if (tree.getNumArgs() != 1) {
815 PrintFatalError(
816 loc, "replaceWithValue directive must take exactly one argument");
817 }
818
819 if (!tree.getSymbol().empty()) {
820 PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
821 }
822
823 return tree.getArgName(0);
824 }
825
handleLocationDirective(DagNode tree)826 std::string PatternEmitter::handleLocationDirective(DagNode tree) {
827 assert(tree.isLocationDirective());
828 auto lookUpArgLoc = [this, &tree](int idx) {
829 const auto *const lookupFmt = "(*{0}.begin()).getLoc()";
830 return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt);
831 };
832
833 if (tree.getNumArgs() == 0)
834 llvm::PrintFatalError(
835 "At least one argument to location directive required");
836
837 if (!tree.getSymbol().empty())
838 PrintFatalError(loc, "cannot bind symbol to location");
839
840 if (tree.getNumArgs() == 1) {
841 DagLeaf leaf = tree.getArgAsLeaf(0);
842 if (leaf.isStringAttr())
843 return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"), "
844 "rewriter.getContext())",
845 leaf.getStringAttr())
846 .str();
847 return lookUpArgLoc(0);
848 }
849
850 std::string ret;
851 llvm::raw_string_ostream os(ret);
852 std::string strAttr;
853 os << "rewriter.getFusedLoc({";
854 bool first = true;
855 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
856 DagLeaf leaf = tree.getArgAsLeaf(i);
857 // Handle the optional string value.
858 if (leaf.isStringAttr()) {
859 if (!strAttr.empty())
860 llvm::PrintFatalError("Only one string attribute may be specified");
861 strAttr = leaf.getStringAttr();
862 continue;
863 }
864 os << (first ? "" : ", ") << lookUpArgLoc(i);
865 first = false;
866 }
867 os << "}";
868 if (!strAttr.empty()) {
869 os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
870 }
871 os << ")";
872 return os.str();
873 }
874
handleOpArgument(DagLeaf leaf,StringRef patArgName)875 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
876 StringRef patArgName) {
877 if (leaf.isStringAttr())
878 PrintFatalError(loc, "raw string not supported as argument");
879 if (leaf.isConstantAttr()) {
880 auto constAttr = leaf.getAsConstantAttr();
881 return handleConstantAttr(constAttr.getAttribute(),
882 constAttr.getConstantValue());
883 }
884 if (leaf.isEnumAttrCase()) {
885 auto enumCase = leaf.getAsEnumAttrCase();
886 if (enumCase.isStrCase())
887 return handleConstantAttr(enumCase, enumCase.getSymbol());
888 // This is an enum case backed by an IntegerAttr. We need to get its value
889 // to build the constant.
890 std::string val = std::to_string(enumCase.getValue());
891 return handleConstantAttr(enumCase, val);
892 }
893
894 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
895 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
896 if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
897 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
898 << "' (via symbol ref)\n");
899 return argName;
900 }
901 if (leaf.isNativeCodeCall()) {
902 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
903 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
904 << "' (via NativeCodeCall)\n");
905 return std::string(repl);
906 }
907 PrintFatalError(loc, "unhandled case when rewriting op");
908 }
909
handleReplaceWithNativeCodeCall(DagNode tree,int depth)910 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
911 int depth) {
912 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
913 LLVM_DEBUG(tree.print(llvm::dbgs()));
914 LLVM_DEBUG(llvm::dbgs() << '\n');
915
916 auto fmt = tree.getNativeCodeTemplate();
917 // TODO: replace formatv arguments with the exact specified args.
918 SmallVector<std::string, 8> attrs(8);
919 if (tree.getNumArgs() > 8) {
920 PrintFatalError(loc,
921 "unsupported NativeCodeCall replace argument numbers: " +
922 Twine(tree.getNumArgs()));
923 }
924 bool hasLocationDirective;
925 std::string locToUse;
926 std::tie(hasLocationDirective, locToUse) = getLocation(tree);
927
928 for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
929 if (tree.isNestedDagArg(i)) {
930 attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1);
931 } else {
932 attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i));
933 }
934 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
935 << " replacement: " << attrs[i] << "\n");
936 }
937 return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0],
938 attrs[1], attrs[2], attrs[3], attrs[4], attrs[5],
939 attrs[6], attrs[7]));
940 }
941
getNodeValueCount(DagNode node)942 int PatternEmitter::getNodeValueCount(DagNode node) {
943 if (node.isOperation()) {
944 // If the op is bound to a symbol in the rewrite rule, query its result
945 // count from the symbol info map.
946 auto symbol = node.getSymbol();
947 if (!symbol.empty()) {
948 return symbolInfoMap.getStaticValueCount(symbol);
949 }
950 // Otherwise this is an unbound op; we will use all its results.
951 return pattern.getDialectOp(node).getNumResults();
952 }
953 // TODO: This considers all NativeCodeCall as returning one
954 // value. Enhance if multi-value ones are needed.
955 return 1;
956 }
957
getLocation(DagNode tree)958 std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) {
959 auto numPatArgs = tree.getNumArgs();
960
961 if (numPatArgs != 0) {
962 if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
963 if (lastArg.isLocationDirective()) {
964 return std::make_pair(true, handleLocationDirective(lastArg));
965 }
966 }
967
968 // If no explicit location is given, use the default, all fused, location.
969 return std::make_pair(false, "odsLoc");
970 }
971
handleOpCreation(DagNode tree,int resultIndex,int depth)972 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
973 int depth) {
974 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
975 LLVM_DEBUG(tree.print(llvm::dbgs()));
976 LLVM_DEBUG(llvm::dbgs() << '\n');
977
978 Operator &resultOp = tree.getDialectOp(opMap);
979 auto numOpArgs = resultOp.getNumArgs();
980 auto numPatArgs = tree.getNumArgs();
981
982 bool hasLocationDirective;
983 std::string locToUse;
984 std::tie(hasLocationDirective, locToUse) = getLocation(tree);
985
986 auto inPattern = numPatArgs - hasLocationDirective;
987 if (numOpArgs != inPattern) {
988 PrintFatalError(loc,
989 formatv("resultant op '{0}' argument number mismatch: "
990 "{1} in pattern vs. {2} in definition",
991 resultOp.getOperationName(), inPattern, numOpArgs));
992 }
993
994 // A map to collect all nested DAG child nodes' names, with operand index as
995 // the key. This includes both bound and unbound child nodes.
996 ChildNodeIndexNameMap childNodeNames;
997
998 // First go through all the child nodes who are nested DAG constructs to
999 // create ops for them and remember the symbol names for them, so that we can
1000 // use the results in the current node. This happens in a recursive manner.
1001 for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) {
1002 if (auto child = tree.getArgAsNestedDag(i))
1003 childNodeNames[i] = handleResultPattern(child, i, depth + 1);
1004 }
1005
1006 // The name of the local variable holding this op.
1007 std::string valuePackName;
1008 // The symbol for holding the result of this pattern. Note that the result of
1009 // this pattern is not necessarily the same as the variable created by this
1010 // pattern because we can use `__N` suffix to refer only a specific result if
1011 // the generated op is a multi-result op.
1012 std::string resultValue;
1013 if (tree.getSymbol().empty()) {
1014 // No symbol is explicitly bound to this op in the pattern. Generate a
1015 // unique name.
1016 valuePackName = resultValue = getUniqueSymbol(&resultOp);
1017 } else {
1018 resultValue = std::string(tree.getSymbol());
1019 // Strip the index to get the name for the value pack and use it to name the
1020 // local variable for the op.
1021 valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
1022 }
1023
1024 // Create the local variable for this op.
1025 os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
1026 valuePackName);
1027
1028 // Right now ODS don't have general type inference support. Except a few
1029 // special cases listed below, DRR needs to supply types for all results
1030 // when building an op.
1031 bool isSameOperandsAndResultType =
1032 resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
1033 bool useFirstAttr =
1034 resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");
1035
1036 if (isSameOperandsAndResultType || useFirstAttr) {
1037 // We know how to deduce the result type for ops with these traits and we've
1038 // generated builders taking aggregate parameters. Use those builders to
1039 // create the ops.
1040
1041 // First prepare local variables for op arguments used in builder call.
1042 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1043
1044 // Then create the op.
1045 os.scope("", "\n}\n").os << formatv(
1046 "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
1047 valuePackName, resultOp.getQualCppClassName(), locToUse);
1048 return resultValue;
1049 }
1050
1051 bool usePartialResults = valuePackName != resultValue;
1052
1053 if (usePartialResults || depth > 0 || resultIndex < 0) {
1054 // For these cases (broadcastable ops, op results used both as auxiliary
1055 // values and replacement values, ops in nested patterns, auxiliary ops), we
1056 // still need to supply the result types when building the op. But because
1057 // we don't generate a builder automatically with ODS for them, it's the
1058 // developer's responsibility to make sure such a builder (with result type
1059 // deduction ability) exists. We go through the separate-parameter builder
1060 // here given that it's easier for developers to write compared to
1061 // aggregate-parameter builders.
1062 createSeparateLocalVarsForOpArgs(tree, childNodeNames);
1063
1064 os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
1065 resultOp.getQualCppClassName(), locToUse);
1066 supplyValuesForOpArgs(tree, childNodeNames, depth);
1067 os << "\n );\n}\n";
1068 return resultValue;
1069 }
1070
1071 // If depth == 0 and resultIndex >= 0, it means we are replacing the values
1072 // generated from the source pattern root op. Then we can use the source
1073 // pattern's value types to determine the value type of the generated op
1074 // here.
1075
1076 // First prepare local variables for op arguments used in builder call.
1077 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1078
1079 // Then prepare the result types. We need to specify the types for all
1080 // results.
1081 os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
1082 "(void)tblgen_types;\n");
1083 int numResults = resultOp.getNumResults();
1084 if (numResults != 0) {
1085 for (int i = 0; i < numResults; ++i)
1086 os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
1087 " tblgen_types.push_back(v.getType());\n}\n",
1088 resultIndex + i);
1089 }
1090 os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
1091 "tblgen_values, tblgen_attrs);\n",
1092 valuePackName, resultOp.getQualCppClassName(), locToUse);
1093 os.unindent() << "}\n";
1094 return resultValue;
1095 }
1096
createSeparateLocalVarsForOpArgs(DagNode node,ChildNodeIndexNameMap & childNodeNames)1097 void PatternEmitter::createSeparateLocalVarsForOpArgs(
1098 DagNode node, ChildNodeIndexNameMap &childNodeNames) {
1099 Operator &resultOp = node.getDialectOp(opMap);
1100
1101 // Now prepare operands used for building this op:
1102 // * If the operand is non-variadic, we create a `Value` local variable.
1103 // * If the operand is variadic, we create a `SmallVector<Value>` local
1104 // variable.
1105
1106 int valueIndex = 0; // An index for uniquing local variable names.
1107 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1108 const auto *operand =
1109 resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
1110 // We do not need special handling for attributes.
1111 if (!operand)
1112 continue;
1113
1114 raw_indented_ostream::DelimitedScope scope(os);
1115 std::string varName;
1116 if (operand->isVariadic()) {
1117 varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
1118 os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName);
1119 std::string range;
1120 if (node.isNestedDagArg(argIndex)) {
1121 range = childNodeNames[argIndex];
1122 } else {
1123 range = std::string(node.getArgName(argIndex));
1124 }
1125 // Resolve the symbol for all range use so that we have a uniform way of
1126 // capturing the values.
1127 range = symbolInfoMap.getValueAndRangeUse(range);
1128 os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range,
1129 varName);
1130 } else {
1131 varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
1132 os << formatv("::mlir::Value {0} = ", varName);
1133 if (node.isNestedDagArg(argIndex)) {
1134 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
1135 } else {
1136 DagLeaf leaf = node.getArgAsLeaf(argIndex);
1137 auto symbol =
1138 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1139 if (leaf.isNativeCodeCall()) {
1140 os << std::string(
1141 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1142 } else {
1143 os << symbol;
1144 }
1145 }
1146 os << ";\n";
1147 }
1148
1149 // Update to use the newly created local variable for building the op later.
1150 childNodeNames[argIndex] = varName;
1151 }
1152 }
1153
supplyValuesForOpArgs(DagNode node,const ChildNodeIndexNameMap & childNodeNames,int depth)1154 void PatternEmitter::supplyValuesForOpArgs(
1155 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1156 Operator &resultOp = node.getDialectOp(opMap);
1157 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1158 argIndex != numOpArgs; ++argIndex) {
1159 // Start each argument on its own line.
1160 os << ",\n ";
1161
1162 Argument opArg = resultOp.getArg(argIndex);
1163 // Handle the case of operand first.
1164 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
1165 if (!operand->name.empty())
1166 os << "/*" << operand->name << "=*/";
1167 os << childNodeNames.lookup(argIndex);
1168 continue;
1169 }
1170
1171 // The argument in the op definition.
1172 auto opArgName = resultOp.getArgName(argIndex);
1173 if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1174 if (!subTree.isNativeCodeCall())
1175 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1176 "for creating attribute");
1177 os << formatv("/*{0}=*/{1}", opArgName,
1178 handleReplaceWithNativeCodeCall(subTree, depth));
1179 } else {
1180 auto leaf = node.getArgAsLeaf(argIndex);
1181 // The argument in the result DAG pattern.
1182 auto patArgName = node.getArgName(argIndex);
1183 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
1184 // TODO: Refactor out into map to avoid recomputing these.
1185 if (!opArg.is<NamedAttribute *>())
1186 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
1187 if (!patArgName.empty())
1188 os << "/*" << patArgName << "=*/";
1189 } else {
1190 os << "/*" << opArgName << "=*/";
1191 }
1192 os << handleOpArgument(leaf, patArgName);
1193 }
1194 }
1195 }
1196
createAggregateLocalVarsForOpArgs(DagNode node,const ChildNodeIndexNameMap & childNodeNames,int depth)1197 void PatternEmitter::createAggregateLocalVarsForOpArgs(
1198 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1199 Operator &resultOp = node.getDialectOp(opMap);
1200
1201 auto scope = os.scope();
1202 os << formatv("::mlir::SmallVector<::mlir::Value, 4> "
1203 "tblgen_values; (void)tblgen_values;\n");
1204 os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
1205 "tblgen_attrs; (void)tblgen_attrs;\n");
1206
1207 const char *addAttrCmd =
1208 "if (auto tmpAttr = {1}) {\n"
1209 " tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
1210 "tmpAttr);\n}\n";
1211 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1212 if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
1213 // The argument in the op definition.
1214 auto opArgName = resultOp.getArgName(argIndex);
1215 if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1216 if (!subTree.isNativeCodeCall())
1217 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1218 "for creating attribute");
1219 os << formatv(addAttrCmd, opArgName,
1220 handleReplaceWithNativeCodeCall(subTree, depth + 1));
1221 } else {
1222 auto leaf = node.getArgAsLeaf(argIndex);
1223 // The argument in the result DAG pattern.
1224 auto patArgName = node.getArgName(argIndex);
1225 os << formatv(addAttrCmd, opArgName,
1226 handleOpArgument(leaf, patArgName));
1227 }
1228 continue;
1229 }
1230
1231 const auto *operand =
1232 resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
1233 std::string varName;
1234 if (operand->isVariadic()) {
1235 std::string range;
1236 if (node.isNestedDagArg(argIndex)) {
1237 range = childNodeNames.lookup(argIndex);
1238 } else {
1239 range = std::string(node.getArgName(argIndex));
1240 }
1241 // Resolve the symbol for all range use so that we have a uniform way of
1242 // capturing the values.
1243 range = symbolInfoMap.getValueAndRangeUse(range);
1244 os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
1245 range);
1246 } else {
1247 os << formatv("tblgen_values.push_back(");
1248 if (node.isNestedDagArg(argIndex)) {
1249 os << symbolInfoMap.getValueAndRangeUse(
1250 childNodeNames.lookup(argIndex));
1251 } else {
1252 DagLeaf leaf = node.getArgAsLeaf(argIndex);
1253 auto symbol =
1254 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1255 if (leaf.isNativeCodeCall()) {
1256 os << std::string(
1257 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1258 } else {
1259 os << symbol;
1260 }
1261 }
1262 os << ");\n";
1263 }
1264 }
1265 }
1266
emitRewriters(const RecordKeeper & recordKeeper,raw_ostream & os)1267 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1268 emitSourceFileHeader("Rewriters", os);
1269
1270 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1271 auto numPatterns = patterns.size();
1272
1273 // We put the map here because it can be shared among multiple patterns.
1274 RecordOperatorMap recordOpMap;
1275
1276 std::vector<std::string> rewriterNames;
1277 rewriterNames.reserve(numPatterns);
1278
1279 std::string baseRewriterName = "GeneratedConvert";
1280 int rewriterIndex = 0;
1281
1282 for (Record *p : patterns) {
1283 std::string name;
1284 if (p->isAnonymous()) {
1285 // If no name is provided, ensure unique rewriter names simply by
1286 // appending unique suffix.
1287 name = baseRewriterName + llvm::utostr(rewriterIndex++);
1288 } else {
1289 name = std::string(p->getName());
1290 }
1291 LLVM_DEBUG(llvm::dbgs()
1292 << "=== start generating pattern '" << name << "' ===\n");
1293 PatternEmitter(p, &recordOpMap, os).emit(name);
1294 LLVM_DEBUG(llvm::dbgs()
1295 << "=== done generating pattern '" << name << "' ===\n");
1296 rewriterNames.push_back(std::move(name));
1297 }
1298
1299 // Emit function to add the generated matchers to the pattern list.
1300 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::MLIRContext "
1301 "*context, ::mlir::OwningRewritePatternList &patterns) {\n";
1302 for (const auto &name : rewriterNames) {
1303 os << " patterns.insert<" << name << ">(context);\n";
1304 }
1305 os << "}\n";
1306 }
1307
1308 static mlir::GenRegistration
1309 genRewriters("gen-rewriters", "Generate pattern rewriters",
__anon08e8fbd40602(const RecordKeeper &records, raw_ostream &os) 1310 [](const RecordKeeper &records, raw_ostream &os) {
1311 emitRewriters(records, os);
1312 return false;
1313 });
1314