1 //===- Pattern.cpp - Pattern wrapper class --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Pattern wrapper class to simplify using a TableGen Record that defines an
// MLIR Pattern.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/TableGen/Pattern.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Support/Debug.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21
22 #define DEBUG_TYPE "mlir-tblgen-pattern"
23
24 using namespace mlir;
25 using namespace tblgen;
26
27 using llvm::formatv;
28
29 //===----------------------------------------------------------------------===//
30 // DagLeaf
31 //===----------------------------------------------------------------------===//
32
isUnspecified() const33 bool DagLeaf::isUnspecified() const {
34 return dyn_cast_or_null<llvm::UnsetInit>(def);
35 }
36
isOperandMatcher() const37 bool DagLeaf::isOperandMatcher() const {
38 // Operand matchers specify a type constraint.
39 return isSubClassOf("TypeConstraint");
40 }
41
isAttrMatcher() const42 bool DagLeaf::isAttrMatcher() const {
43 // Attribute matchers specify an attribute constraint.
44 return isSubClassOf("AttrConstraint");
45 }
46
isNativeCodeCall() const47 bool DagLeaf::isNativeCodeCall() const {
48 return isSubClassOf("NativeCodeCall");
49 }
50
isConstantAttr() const51 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
52
isEnumAttrCase() const53 bool DagLeaf::isEnumAttrCase() const {
54 return isSubClassOf("EnumAttrCaseInfo");
55 }
56
isStringAttr() const57 bool DagLeaf::isStringAttr() const {
58 return isa<llvm::StringInit>(def);
59 }
60
getAsConstraint() const61 Constraint DagLeaf::getAsConstraint() const {
62 assert((isOperandMatcher() || isAttrMatcher()) &&
63 "the DAG leaf must be operand or attribute");
64 return Constraint(cast<llvm::DefInit>(def)->getDef());
65 }
66
getAsConstantAttr() const67 ConstantAttr DagLeaf::getAsConstantAttr() const {
68 assert(isConstantAttr() && "the DAG leaf must be constant attribute");
69 return ConstantAttr(cast<llvm::DefInit>(def));
70 }
71
getAsEnumAttrCase() const72 EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
73 assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
74 return EnumAttrCase(cast<llvm::DefInit>(def));
75 }
76
getConditionTemplate() const77 std::string DagLeaf::getConditionTemplate() const {
78 return getAsConstraint().getConditionTemplate();
79 }
80
getNativeCodeTemplate() const81 llvm::StringRef DagLeaf::getNativeCodeTemplate() const {
82 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
83 return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
84 }
85
getStringAttr() const86 std::string DagLeaf::getStringAttr() const {
87 assert(isStringAttr() && "the DAG leaf must be string attribute");
88 return def->getAsUnquotedString();
89 }
isSubClassOf(StringRef superclass) const90 bool DagLeaf::isSubClassOf(StringRef superclass) const {
91 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
92 return defInit->getDef()->isSubClassOf(superclass);
93 return false;
94 }
95
print(raw_ostream & os) const96 void DagLeaf::print(raw_ostream &os) const {
97 if (def)
98 def->print(os);
99 }
100
101 //===----------------------------------------------------------------------===//
102 // DagNode
103 //===----------------------------------------------------------------------===//
104
isNativeCodeCall() const105 bool DagNode::isNativeCodeCall() const {
106 if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
107 return defInit->getDef()->isSubClassOf("NativeCodeCall");
108 return false;
109 }
110
isOperation() const111 bool DagNode::isOperation() const {
112 return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective();
113 }
114
getNativeCodeTemplate() const115 llvm::StringRef DagNode::getNativeCodeTemplate() const {
116 assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall");
117 return cast<llvm::DefInit>(node->getOperator())
118 ->getDef()
119 ->getValueAsString("expression");
120 }
121
getSymbol() const122 llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); }
123
getDialectOp(RecordOperatorMap * mapper) const124 Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
125 llvm::Record *opDef = cast<llvm::DefInit>(node->getOperator())->getDef();
126 auto it = mapper->find(opDef);
127 if (it != mapper->end())
128 return *it->second;
129 return *mapper->try_emplace(opDef, std::make_unique<Operator>(opDef))
130 .first->second;
131 }
132
getNumOps() const133 int DagNode::getNumOps() const {
134 int count = isReplaceWithValue() ? 0 : 1;
135 for (int i = 0, e = getNumArgs(); i != e; ++i) {
136 if (auto child = getArgAsNestedDag(i))
137 count += child.getNumOps();
138 }
139 return count;
140 }
141
getNumArgs() const142 int DagNode::getNumArgs() const { return node->getNumArgs(); }
143
isNestedDagArg(unsigned index) const144 bool DagNode::isNestedDagArg(unsigned index) const {
145 return isa<llvm::DagInit>(node->getArg(index));
146 }
147
getArgAsNestedDag(unsigned index) const148 DagNode DagNode::getArgAsNestedDag(unsigned index) const {
149 return DagNode(dyn_cast_or_null<llvm::DagInit>(node->getArg(index)));
150 }
151
getArgAsLeaf(unsigned index) const152 DagLeaf DagNode::getArgAsLeaf(unsigned index) const {
153 assert(!isNestedDagArg(index));
154 return DagLeaf(node->getArg(index));
155 }
156
getArgName(unsigned index) const157 StringRef DagNode::getArgName(unsigned index) const {
158 return node->getArgNameStr(index);
159 }
160
isReplaceWithValue() const161 bool DagNode::isReplaceWithValue() const {
162 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
163 return dagOpDef->getName() == "replaceWithValue";
164 }
165
isLocationDirective() const166 bool DagNode::isLocationDirective() const {
167 auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
168 return dagOpDef->getName() == "location";
169 }
170
print(raw_ostream & os) const171 void DagNode::print(raw_ostream &os) const {
172 if (node)
173 node->print(os);
174 }
175
176 //===----------------------------------------------------------------------===//
177 // SymbolInfoMap
178 //===----------------------------------------------------------------------===//
179
getValuePackName(StringRef symbol,int * index)180 StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) {
181 StringRef name, indexStr;
182 int idx = -1;
183 std::tie(name, indexStr) = symbol.rsplit("__");
184
185 if (indexStr.consumeInteger(10, idx)) {
186 // The second part is not an index; we return the whole symbol as-is.
187 return symbol;
188 }
189 if (index) {
190 *index = idx;
191 }
192 return name;
193 }
194
SymbolInfo(const Operator * op,SymbolInfo::Kind kind,Optional<int> index)195 SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind,
196 Optional<int> index)
197 : op(op), kind(kind), argIndex(index) {}
198
getStaticValueCount() const199 int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
200 switch (kind) {
201 case Kind::Attr:
202 case Kind::Operand:
203 case Kind::Value:
204 return 1;
205 case Kind::Result:
206 return op->getNumResults();
207 }
208 llvm_unreachable("unknown kind");
209 }
210
getVarName(StringRef name) const211 std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
212 return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
213 }
214
getVarDecl(StringRef name) const215 std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
216 LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
217 switch (kind) {
218 case Kind::Attr: {
219 if (op) {
220 auto type =
221 op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
222 return std::string(formatv("{0} {1};\n", type, name));
223 }
224 // TODO(suderman): Use a more exact type when available.
225 return std::string(formatv("Attribute {0};\n", name));
226 }
227 case Kind::Operand: {
228 // Use operand range for captured operands (to support potential variadic
229 // operands).
230 return std::string(
231 formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
232 getVarName(name)));
233 }
234 case Kind::Value: {
235 return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name));
236 }
237 case Kind::Result: {
238 // Use the op itself for captured results.
239 return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name));
240 }
241 }
242 llvm_unreachable("unknown kind");
243 }
244
getValueAndRangeUse(StringRef name,int index,const char * fmt,const char * separator) const245 std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse(
246 StringRef name, int index, const char *fmt, const char *separator) const {
247 LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': ");
248 switch (kind) {
249 case Kind::Attr: {
250 assert(index < 0);
251 auto repl = formatv(fmt, name);
252 LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n");
253 return std::string(repl);
254 }
255 case Kind::Operand: {
256 assert(index < 0);
257 auto *operand = op->getArg(*argIndex).get<NamedTypeConstraint *>();
258 // If this operand is variadic, then return a range. Otherwise, return the
259 // value itself.
260 if (operand->isVariableLength()) {
261 auto repl = formatv(fmt, name);
262 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n");
263 return std::string(repl);
264 }
265 auto repl = formatv(fmt, formatv("(*{0}.begin())", name));
266 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n");
267 return std::string(repl);
268 }
269 case Kind::Result: {
270 // If `index` is greater than zero, then we are referencing a specific
271 // result of a multi-result op. The result can still be variadic.
272 if (index >= 0) {
273 std::string v =
274 std::string(formatv("{0}.getODSResults({1})", name, index));
275 if (!op->getResult(index).isVariadic())
276 v = std::string(formatv("(*{0}.begin())", v));
277 auto repl = formatv(fmt, v);
278 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
279 return std::string(repl);
280 }
281
282 // If this op has no result at all but still we bind a symbol to it, it
283 // means we want to capture the op itself.
284 if (op->getNumResults() == 0) {
285 LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n");
286 return std::string(name);
287 }
288
289 // We are referencing all results of the multi-result op. A specific result
290 // can either be a value or a range. Then join them with `separator`.
291 SmallVector<std::string, 4> values;
292 values.reserve(op->getNumResults());
293
294 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
295 std::string v = std::string(formatv("{0}.getODSResults({1})", name, i));
296 if (!op->getResult(i).isVariadic()) {
297 v = std::string(formatv("(*{0}.begin())", v));
298 }
299 values.push_back(std::string(formatv(fmt, v)));
300 }
301 auto repl = llvm::join(values, separator);
302 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
303 return repl;
304 }
305 case Kind::Value: {
306 assert(index < 0);
307 assert(op == nullptr);
308 auto repl = formatv(fmt, name);
309 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
310 return std::string(repl);
311 }
312 }
313 llvm_unreachable("unknown kind");
314 }
315
getAllRangeUse(StringRef name,int index,const char * fmt,const char * separator) const316 std::string SymbolInfoMap::SymbolInfo::getAllRangeUse(
317 StringRef name, int index, const char *fmt, const char *separator) const {
318 LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': ");
319 switch (kind) {
320 case Kind::Attr:
321 case Kind::Operand: {
322 assert(index < 0 && "only allowed for symbol bound to result");
323 auto repl = formatv(fmt, name);
324 LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n");
325 return std::string(repl);
326 }
327 case Kind::Result: {
328 if (index >= 0) {
329 auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index));
330 LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n");
331 return std::string(repl);
332 }
333
334 // We are referencing all results of the multi-result op. Each result should
335 // have a value range, and then join them with `separator`.
336 SmallVector<std::string, 4> values;
337 values.reserve(op->getNumResults());
338
339 for (int i = 0, e = op->getNumResults(); i < e; ++i) {
340 values.push_back(std::string(
341 formatv(fmt, formatv("{0}.getODSResults({1})", name, i))));
342 }
343 auto repl = llvm::join(values, separator);
344 LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n");
345 return repl;
346 }
347 case Kind::Value: {
348 assert(index < 0 && "only allowed for symbol bound to result");
349 assert(op == nullptr);
350 auto repl = formatv(fmt, formatv("{{{0}}", name));
351 LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n");
352 return std::string(repl);
353 }
354 }
355 llvm_unreachable("unknown kind");
356 }
357
bindOpArgument(StringRef symbol,const Operator & op,int argIndex)358 bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
359 int argIndex) {
360 StringRef name = getValuePackName(symbol);
361 if (name != symbol) {
362 auto error = formatv(
363 "symbol '{0}' with trailing index cannot bind to op argument", symbol);
364 PrintFatalError(loc, error);
365 }
366
367 auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
368 ? SymbolInfo::getAttr(&op, argIndex)
369 : SymbolInfo::getOperand(&op, argIndex);
370
371 std::string key = symbol.str();
372 if (symbolInfoMap.count(key)) {
373 // Only non unique name for the operand is supported.
374 if (symInfo.kind != SymbolInfo::Kind::Operand) {
375 return false;
376 }
377
378 // Cannot add new operand if there is already non operand with the same
379 // name.
380 if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
381 return false;
382 }
383 }
384
385 symbolInfoMap.emplace(key, symInfo);
386 return true;
387 }
388
bindOpResult(StringRef symbol,const Operator & op)389 bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
390 std::string name = getValuePackName(symbol).str();
391 auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));
392
393 return symbolInfoMap.count(inserted->first) == 1;
394 }
395
bindValue(StringRef symbol)396 bool SymbolInfoMap::bindValue(StringRef symbol) {
397 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue());
398 return symbolInfoMap.count(inserted->first) == 1;
399 }
400
bindAttr(StringRef symbol)401 bool SymbolInfoMap::bindAttr(StringRef symbol) {
402 auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr());
403 return symbolInfoMap.count(inserted->first) == 1;
404 }
405
contains(StringRef symbol) const406 bool SymbolInfoMap::contains(StringRef symbol) const {
407 return find(symbol) != symbolInfoMap.end();
408 }
409
find(StringRef key) const410 SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
411 std::string name = getValuePackName(key).str();
412
413 return symbolInfoMap.find(name);
414 }
415
416 SymbolInfoMap::const_iterator
findBoundSymbol(StringRef key,const Operator & op,int argIndex) const417 SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
418 int argIndex) const {
419 std::string name = getValuePackName(key).str();
420 auto range = symbolInfoMap.equal_range(name);
421
422 for (auto it = range.first; it != range.second; ++it) {
423 if (it->second.op == &op && it->second.argIndex == argIndex) {
424 return it;
425 }
426 }
427
428 return symbolInfoMap.end();
429 }
430
431 std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
getRangeOfEqualElements(StringRef key)432 SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
433 std::string name = getValuePackName(key).str();
434
435 return symbolInfoMap.equal_range(name);
436 }
437
count(StringRef key) const438 int SymbolInfoMap::count(StringRef key) const {
439 std::string name = getValuePackName(key).str();
440 return symbolInfoMap.count(name);
441 }
442
getStaticValueCount(StringRef symbol) const443 int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
444 StringRef name = getValuePackName(symbol);
445 if (name != symbol) {
446 // If there is a trailing index inside symbol, it references just one
447 // static value.
448 return 1;
449 }
450 // Otherwise, find how many it represents by querying the symbol's info.
451 return find(name)->second.getStaticValueCount();
452 }
453
getValueAndRangeUse(StringRef symbol,const char * fmt,const char * separator) const454 std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
455 const char *fmt,
456 const char *separator) const {
457 int index = -1;
458 StringRef name = getValuePackName(symbol, &index);
459
460 auto it = symbolInfoMap.find(name.str());
461 if (it == symbolInfoMap.end()) {
462 auto error = formatv("referencing unbound symbol '{0}'", symbol);
463 PrintFatalError(loc, error);
464 }
465
466 return it->second.getValueAndRangeUse(name, index, fmt, separator);
467 }
468
getAllRangeUse(StringRef symbol,const char * fmt,const char * separator) const469 std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
470 const char *separator) const {
471 int index = -1;
472 StringRef name = getValuePackName(symbol, &index);
473
474 auto it = symbolInfoMap.find(name.str());
475 if (it == symbolInfoMap.end()) {
476 auto error = formatv("referencing unbound symbol '{0}'", symbol);
477 PrintFatalError(loc, error);
478 }
479
480 return it->second.getAllRangeUse(name, index, fmt, separator);
481 }
482
assignUniqueAlternativeNames()483 void SymbolInfoMap::assignUniqueAlternativeNames() {
484 llvm::StringSet<> usedNames;
485
486 for (auto symbolInfoIt = symbolInfoMap.begin();
487 symbolInfoIt != symbolInfoMap.end();) {
488 auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
489 auto startRange = range.first;
490 auto endRange = range.second;
491
492 auto operandName = symbolInfoIt->first;
493 int startSearchIndex = 0;
494 for (++startRange; startRange != endRange; ++startRange) {
495 // Current operand name is not unique, find a unique one
496 // and set the alternative name.
497 for (int i = startSearchIndex;; ++i) {
498 std::string alternativeName = operandName + std::to_string(i);
499 if (!usedNames.contains(alternativeName) &&
500 symbolInfoMap.count(alternativeName) == 0) {
501 usedNames.insert(alternativeName);
502 startRange->second.alternativeName = alternativeName;
503 startSearchIndex = i + 1;
504
505 break;
506 }
507 }
508 }
509
510 symbolInfoIt = endRange;
511 }
512 }
513
514 //===----------------------------------------------------------------------===//
515 // Pattern
//===----------------------------------------------------------------------===//
517
Pattern(const llvm::Record * def,RecordOperatorMap * mapper)518 Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
519 : def(*def), recordOpMap(mapper) {}
520
getSourcePattern() const521 DagNode Pattern::getSourcePattern() const {
522 return DagNode(def.getValueAsDag("sourcePattern"));
523 }
524
getNumResultPatterns() const525 int Pattern::getNumResultPatterns() const {
526 auto *results = def.getValueAsListInit("resultPatterns");
527 return results->size();
528 }
529
getResultPattern(unsigned index) const530 DagNode Pattern::getResultPattern(unsigned index) const {
531 auto *results = def.getValueAsListInit("resultPatterns");
532 return DagNode(cast<llvm::DagInit>(results->getElement(index)));
533 }
534
collectSourcePatternBoundSymbols(SymbolInfoMap & infoMap)535 void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
536 LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
537 collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
538 LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");
539
540 LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
541 infoMap.assignUniqueAlternativeNames();
542 LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
543 }
544
collectResultPatternBoundSymbols(SymbolInfoMap & infoMap)545 void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
546 LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n");
547 for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
548 auto pattern = getResultPattern(i);
549 collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
550 }
551 LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n");
552 }
553
getSourceRootOp()554 const Operator &Pattern::getSourceRootOp() {
555 return getSourcePattern().getDialectOp(recordOpMap);
556 }
557
getDialectOp(DagNode node)558 Operator &Pattern::getDialectOp(DagNode node) {
559 return node.getDialectOp(recordOpMap);
560 }
561
getConstraints() const562 std::vector<AppliedConstraint> Pattern::getConstraints() const {
563 auto *listInit = def.getValueAsListInit("constraints");
564 std::vector<AppliedConstraint> ret;
565 ret.reserve(listInit->size());
566
567 for (auto it : *listInit) {
568 auto *dagInit = dyn_cast<llvm::DagInit>(it);
569 if (!dagInit)
570 PrintFatalError(&def, "all elements in Pattern multi-entity "
571 "constraints should be DAG nodes");
572
573 std::vector<std::string> entities;
574 entities.reserve(dagInit->arg_size());
575 for (auto *argName : dagInit->getArgNames()) {
576 if (!argName) {
577 PrintFatalError(
578 &def,
579 "operands to additional constraints can only be symbol references");
580 }
581 entities.push_back(std::string(argName->getValue()));
582 }
583
584 ret.emplace_back(cast<llvm::DefInit>(dagInit->getOperator())->getDef(),
585 dagInit->getNameStr(), std::move(entities));
586 }
587 return ret;
588 }
589
getBenefit() const590 int Pattern::getBenefit() const {
591 // The initial benefit value is a heuristic with number of ops in the source
592 // pattern.
593 int initBenefit = getSourcePattern().getNumOps();
594 llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
595 if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
596 PrintFatalError(&def,
597 "The 'addBenefit' takes and only takes one integer value");
598 }
599 return initBenefit + dyn_cast<llvm::IntInit>(delta->getArg(0))->getValue();
600 }
601
getLocation() const602 std::vector<Pattern::IdentifierLine> Pattern::getLocation() const {
603 std::vector<std::pair<StringRef, unsigned>> result;
604 result.reserve(def.getLoc().size());
605 for (auto loc : def.getLoc()) {
606 unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc);
607 assert(buf && "invalid source location");
608 result.emplace_back(
609 llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(),
610 llvm::SrcMgr.getLineAndColumn(loc, buf).first);
611 }
612 return result;
613 }
614
verifyBind(bool result,StringRef symbolName)615 void Pattern::verifyBind(bool result, StringRef symbolName) {
616 if (!result) {
617 auto err = formatv("symbol '{0}' bound more than once", symbolName);
618 PrintFatalError(&def, err);
619 }
620 }
621
collectBoundSymbols(DagNode tree,SymbolInfoMap & infoMap,bool isSrcPattern)622 void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
623 bool isSrcPattern) {
624 auto treeName = tree.getSymbol();
625 auto numTreeArgs = tree.getNumArgs();
626
627 if (tree.isNativeCodeCall()) {
628 if (!treeName.empty()) {
629 PrintFatalError(
630 &def,
631 formatv(
632 "binding symbol '{0}' to native code call unsupported right now",
633 treeName));
634 }
635
636 for (int i = 0; i != numTreeArgs; ++i) {
637 if (auto treeArg = tree.getArgAsNestedDag(i)) {
638 // This DAG node argument is a DAG node itself. Go inside recursively.
639 collectBoundSymbols(treeArg, infoMap, isSrcPattern);
640 continue;
641 }
642
643 if (!isSrcPattern)
644 continue;
645
646 // We can only bind symbols to arguments in source pattern. Those
647 // symbols are referenced in result patterns.
648 auto treeArgName = tree.getArgName(i);
649
650 // `$_` is a special symbol meaning ignore the current argument.
651 if (!treeArgName.empty() && treeArgName != "_") {
652 if (tree.isNestedDagArg(i)) {
653 auto err = formatv("cannot bind '{0}' for nested native call arg",
654 treeArgName);
655 PrintFatalError(&def, err);
656 }
657
658 DagLeaf leaf = tree.getArgAsLeaf(i);
659 auto constraint = leaf.getAsConstraint();
660 bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
661 leaf.isConstantAttr() ||
662 constraint.getKind() == Constraint::Kind::CK_Attr;
663
664 if (isAttr) {
665 verifyBind(infoMap.bindAttr(treeArgName), treeArgName);
666 continue;
667 }
668
669 verifyBind(infoMap.bindValue(treeArgName), treeArgName);
670 }
671 }
672
673 return;
674 }
675
676 if (tree.isOperation()) {
677 auto &op = getDialectOp(tree);
678 auto numOpArgs = op.getNumArgs();
679
680 // The pattern might have the last argument specifying the location.
681 bool hasLocDirective = false;
682 if (numTreeArgs != 0) {
683 if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1))
684 hasLocDirective = lastArg.isLocationDirective();
685 }
686
687 if (numOpArgs != numTreeArgs - hasLocDirective) {
688 auto err = formatv("op '{0}' argument number mismatch: "
689 "{1} in pattern vs. {2} in definition",
690 op.getOperationName(), numTreeArgs, numOpArgs);
691 PrintFatalError(&def, err);
692 }
693
694 // The name attached to the DAG node's operator is for representing the
695 // results generated from this op. It should be remembered as bound results.
696 if (!treeName.empty()) {
697 LLVM_DEBUG(llvm::dbgs()
698 << "found symbol bound to op result: " << treeName << '\n');
699 verifyBind(infoMap.bindOpResult(treeName, op), treeName);
700 }
701
702 for (int i = 0; i != numTreeArgs; ++i) {
703 if (auto treeArg = tree.getArgAsNestedDag(i)) {
704 // This DAG node argument is a DAG node itself. Go inside recursively.
705 collectBoundSymbols(treeArg, infoMap, isSrcPattern);
706 continue;
707 }
708
709 if (isSrcPattern) {
710 // We can only bind symbols to op arguments in source pattern. Those
711 // symbols are referenced in result patterns.
712 auto treeArgName = tree.getArgName(i);
713 // `$_` is a special symbol meaning ignore the current argument.
714 if (!treeArgName.empty() && treeArgName != "_") {
715 LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
716 << treeArgName << '\n');
717 verifyBind(infoMap.bindOpArgument(treeArgName, op, i), treeArgName);
718 }
719 }
720 }
721 return;
722 }
723
724 if (!treeName.empty()) {
725 PrintFatalError(
726 &def, formatv("binding symbol '{0}' to non-operation/native code call "
727 "unsupported right now",
728 treeName));
729 }
730 return;
731 }
732