1 //===- Operator.cpp - Operator class --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Operator wrapper to simplify using TableGen Record defining a MLIR Op.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/TableGen/Operator.h"
14 #include "mlir/TableGen/OpTrait.h"
15 #include "mlir/TableGen/Predicate.h"
16 #include "mlir/TableGen/Type.h"
17 #include "llvm/ADT/EquivalenceClasses.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/Sequence.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include "llvm/TableGen/Error.h"
26 #include "llvm/TableGen/Record.h"
27
28 #define DEBUG_TYPE "mlir-tblgen-operator"
29
30 using namespace mlir;
31 using namespace mlir::tblgen;
32
33 using llvm::DagInit;
34 using llvm::DefInit;
35 using llvm::Record;
36
Operator(const llvm::Record & def)37 Operator::Operator(const llvm::Record &def)
38 : dialect(def.getValueAsDef("opDialect")), def(def) {
39 // The first `_` in the op's TableGen def name is treated as separating the
40 // dialect prefix and the op class name. The dialect prefix will be ignored if
41 // not empty. Otherwise, if def name starts with a `_`, the `_` is considered
42 // as part of the class name.
43 StringRef prefix;
44 std::tie(prefix, cppClassName) = def.getName().split('_');
45 if (prefix.empty()) {
46 // Class name with a leading underscore and without dialect prefix
47 cppClassName = def.getName();
48 } else if (cppClassName.empty()) {
49 // Class name without dialect prefix
50 cppClassName = prefix;
51 }
52
53 populateOpStructure();
54 }
55
getOperationName() const56 std::string Operator::getOperationName() const {
57 auto prefix = dialect.getName();
58 auto opName = def.getValueAsString("opName");
59 if (prefix.empty())
60 return std::string(opName);
61 return std::string(llvm::formatv("{0}.{1}", prefix, opName));
62 }
63
getAdaptorName() const64 std::string Operator::getAdaptorName() const {
65 return std::string(llvm::formatv("{0}Adaptor", getCppClassName()));
66 }
67
getDialectName() const68 StringRef Operator::getDialectName() const { return dialect.getName(); }
69
getCppClassName() const70 StringRef Operator::getCppClassName() const { return cppClassName; }
71
getQualCppClassName() const72 std::string Operator::getQualCppClassName() const {
73 auto prefix = dialect.getCppNamespace();
74 if (prefix.empty())
75 return std::string(cppClassName);
76 return std::string(llvm::formatv("{0}::{1}", prefix, cppClassName));
77 }
78
getNumResults() const79 int Operator::getNumResults() const {
80 DagInit *results = def.getValueAsDag("results");
81 return results->getNumArgs();
82 }
83
getExtraClassDeclaration() const84 StringRef Operator::getExtraClassDeclaration() const {
85 constexpr auto attr = "extraClassDeclaration";
86 if (def.isValueUnset(attr))
87 return {};
88 return def.getValueAsString(attr);
89 }
90
getDef() const91 const llvm::Record &Operator::getDef() const { return def; }
92
skipDefaultBuilders() const93 bool Operator::skipDefaultBuilders() const {
94 return def.getValueAsBit("skipDefaultBuilders");
95 }
96
result_begin()97 auto Operator::result_begin() -> value_iterator { return results.begin(); }
98
result_end()99 auto Operator::result_end() -> value_iterator { return results.end(); }
100
getResults()101 auto Operator::getResults() -> value_range {
102 return {result_begin(), result_end()};
103 }
104
getResultTypeConstraint(int index) const105 TypeConstraint Operator::getResultTypeConstraint(int index) const {
106 DagInit *results = def.getValueAsDag("results");
107 return TypeConstraint(cast<DefInit>(results->getArg(index)));
108 }
109
getResultName(int index) const110 StringRef Operator::getResultName(int index) const {
111 DagInit *results = def.getValueAsDag("results");
112 return results->getArgNameStr(index);
113 }
114
getResultDecorators(int index) const115 auto Operator::getResultDecorators(int index) const -> var_decorator_range {
116 Record *result =
117 cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef();
118 if (!result->isSubClassOf("OpVariable"))
119 return var_decorator_range(nullptr, nullptr);
120 return *result->getValueAsListInit("decorators");
121 }
122
getNumVariableLengthResults() const123 unsigned Operator::getNumVariableLengthResults() const {
124 return llvm::count_if(results, [](const NamedTypeConstraint &c) {
125 return c.constraint.isVariableLength();
126 });
127 }
128
getNumVariableLengthOperands() const129 unsigned Operator::getNumVariableLengthOperands() const {
130 return llvm::count_if(operands, [](const NamedTypeConstraint &c) {
131 return c.constraint.isVariableLength();
132 });
133 }
134
hasSingleVariadicArg() const135 bool Operator::hasSingleVariadicArg() const {
136 return getNumArgs() == 1 && getArg(0).is<NamedTypeConstraint *>() &&
137 getOperand(0).isVariadic();
138 }
139
arg_begin() const140 Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); }
141
arg_end() const142 Operator::arg_iterator Operator::arg_end() const { return arguments.end(); }
143
getArgs() const144 Operator::arg_range Operator::getArgs() const {
145 return {arg_begin(), arg_end()};
146 }
147
getArgName(int index) const148 StringRef Operator::getArgName(int index) const {
149 DagInit *argumentValues = def.getValueAsDag("arguments");
150 return argumentValues->getArgNameStr(index);
151 }
152
getArgDecorators(int index) const153 auto Operator::getArgDecorators(int index) const -> var_decorator_range {
154 Record *arg =
155 cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef();
156 if (!arg->isSubClassOf("OpVariable"))
157 return var_decorator_range(nullptr, nullptr);
158 return *arg->getValueAsListInit("decorators");
159 }
160
getTrait(StringRef trait) const161 const OpTrait *Operator::getTrait(StringRef trait) const {
162 for (const auto &t : traits) {
163 if (const auto *opTrait = dyn_cast<NativeOpTrait>(&t)) {
164 if (opTrait->getTrait() == trait)
165 return opTrait;
166 } else if (const auto *opTrait = dyn_cast<InternalOpTrait>(&t)) {
167 if (opTrait->getTrait() == trait)
168 return opTrait;
169 } else if (const auto *opTrait = dyn_cast<InterfaceOpTrait>(&t)) {
170 if (opTrait->getTrait() == trait)
171 return opTrait;
172 }
173 }
174 return nullptr;
175 }
176
region_begin() const177 auto Operator::region_begin() const -> const_region_iterator {
178 return regions.begin();
179 }
region_end() const180 auto Operator::region_end() const -> const_region_iterator {
181 return regions.end();
182 }
getRegions() const183 auto Operator::getRegions() const
184 -> llvm::iterator_range<const_region_iterator> {
185 return {region_begin(), region_end()};
186 }
187
getNumRegions() const188 unsigned Operator::getNumRegions() const { return regions.size(); }
189
getRegion(unsigned index) const190 const NamedRegion &Operator::getRegion(unsigned index) const {
191 return regions[index];
192 }
193
getNumVariadicRegions() const194 unsigned Operator::getNumVariadicRegions() const {
195 return llvm::count_if(regions,
196 [](const NamedRegion &c) { return c.isVariadic(); });
197 }
198
successor_begin() const199 auto Operator::successor_begin() const -> const_successor_iterator {
200 return successors.begin();
201 }
successor_end() const202 auto Operator::successor_end() const -> const_successor_iterator {
203 return successors.end();
204 }
getSuccessors() const205 auto Operator::getSuccessors() const
206 -> llvm::iterator_range<const_successor_iterator> {
207 return {successor_begin(), successor_end()};
208 }
209
getNumSuccessors() const210 unsigned Operator::getNumSuccessors() const { return successors.size(); }
211
getSuccessor(unsigned index) const212 const NamedSuccessor &Operator::getSuccessor(unsigned index) const {
213 return successors[index];
214 }
215
getNumVariadicSuccessors() const216 unsigned Operator::getNumVariadicSuccessors() const {
217 return llvm::count_if(successors,
218 [](const NamedSuccessor &c) { return c.isVariadic(); });
219 }
220
trait_begin() const221 auto Operator::trait_begin() const -> const_trait_iterator {
222 return traits.begin();
223 }
trait_end() const224 auto Operator::trait_end() const -> const_trait_iterator {
225 return traits.end();
226 }
getTraits() const227 auto Operator::getTraits() const -> llvm::iterator_range<const_trait_iterator> {
228 return {trait_begin(), trait_end()};
229 }
230
attribute_begin() const231 auto Operator::attribute_begin() const -> attribute_iterator {
232 return attributes.begin();
233 }
attribute_end() const234 auto Operator::attribute_end() const -> attribute_iterator {
235 return attributes.end();
236 }
getAttributes() const237 auto Operator::getAttributes() const
238 -> llvm::iterator_range<attribute_iterator> {
239 return {attribute_begin(), attribute_end()};
240 }
241
operand_begin()242 auto Operator::operand_begin() -> value_iterator { return operands.begin(); }
operand_end()243 auto Operator::operand_end() -> value_iterator { return operands.end(); }
getOperands()244 auto Operator::getOperands() -> value_range {
245 return {operand_begin(), operand_end()};
246 }
247
getArg(int index) const248 auto Operator::getArg(int index) const -> Argument { return arguments[index]; }
249
// Maps result `i` into the index space shared with arguments. Arguments keep
// their getArg index (>= 0), while results are sent to negative indices
// (result 0 -> -1, result 1 -> -2, ...), so the two ranges never overlap.
static int resultIndex(int i) {
  const int firstResultSlot = -1;
  return firstResultSlot - i;
}
254
isVariadic() const255 bool Operator::isVariadic() const {
256 return any_of(llvm::concat<const NamedTypeConstraint>(operands, results),
257 [](const NamedTypeConstraint &op) { return op.isVariadic(); });
258 }
259
populateTypeInferenceInfo(const llvm::StringMap<int> & argumentsAndResultsIndex)260 void Operator::populateTypeInferenceInfo(
261 const llvm::StringMap<int> &argumentsAndResultsIndex) {
262 // If the type inference op interface is not registered, then do not attempt
263 // to determine if the result types an be inferred.
264 auto &recordKeeper = def.getRecords();
265 auto *inferTrait = recordKeeper.getDef(inferTypeOpInterface);
266 allResultsHaveKnownTypes = false;
267 if (!inferTrait)
268 return;
269
270 // If there are no results, the skip this else the build method generated
271 // overlaps with another autogenerated builder.
272 if (getNumResults() == 0)
273 return;
274
275 // Skip for ops with variadic operands/results.
276 // TODO: This can be relaxed.
277 if (isVariadic())
278 return;
279
280 // Skip cases currently being custom generated.
281 // TODO: Remove special cases.
282 if (getTrait("::mlir::OpTrait::SameOperandsAndResultType"))
283 return;
284
285 // We create equivalence classes of argument/result types where arguments
286 // and results are mapped into the same index space and indices corresponding
287 // to the same type are in the same equivalence class.
288 llvm::EquivalenceClasses<int> ecs;
289 resultTypeMapping.resize(getNumResults());
290 // Captures the argument whose type matches a given result type. Preference
291 // towards capturing operands first before attributes.
292 auto captureMapping = [&](int i) {
293 bool found = false;
294 ecs.insert(resultIndex(i));
295 auto mi = ecs.findLeader(resultIndex(i));
296 for (auto me = ecs.member_end(); mi != me; ++mi) {
297 if (*mi < 0) {
298 auto tc = getResultTypeConstraint(i);
299 if (tc.getBuilderCall().hasValue()) {
300 resultTypeMapping[i].emplace_back(tc);
301 found = true;
302 }
303 continue;
304 }
305
306 if (getArg(*mi).is<NamedAttribute *>()) {
307 // TODO: Handle attributes.
308 continue;
309 } else {
310 resultTypeMapping[i].emplace_back(*mi);
311 found = true;
312 }
313 }
314 return found;
315 };
316
317 for (const OpTrait &trait : traits) {
318 const llvm::Record &def = trait.getDef();
319 // If the infer type op interface was manually added, then treat it as
320 // intention that the op needs special handling.
321 // TODO: Reconsider whether to always generate, this is more conservative
322 // and keeps existing behavior so starting that way for now.
323 if (def.isSubClassOf(
324 llvm::formatv("{0}::Trait", inferTypeOpInterface).str()))
325 return;
326 if (const auto *opTrait = dyn_cast<InterfaceOpTrait>(&trait))
327 if (&opTrait->getDef() == inferTrait)
328 return;
329
330 if (!def.isSubClassOf("AllTypesMatch"))
331 continue;
332
333 auto values = def.getValueAsListOfStrings("values");
334 auto root = argumentsAndResultsIndex.lookup(values.front());
335 for (StringRef str : values)
336 ecs.unionSets(argumentsAndResultsIndex.lookup(str), root);
337 }
338
339 // Verifies that all output types have a corresponding known input type
340 // and chooses matching operand or attribute (in that order) that
341 // matches it.
342 allResultsHaveKnownTypes =
343 all_of(llvm::seq<int>(0, getNumResults()), captureMapping);
344
345 // If the types could be computed, then add type inference trait.
346 if (allResultsHaveKnownTypes)
347 traits.push_back(OpTrait::create(inferTrait->getDefInit()));
348 }
349
populateOpStructure()350 void Operator::populateOpStructure() {
351 auto &recordKeeper = def.getRecords();
352 auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint");
353 auto *attrClass = recordKeeper.getClass("Attr");
354 auto *derivedAttrClass = recordKeeper.getClass("DerivedAttr");
355 auto *opVarClass = recordKeeper.getClass("OpVariable");
356 numNativeAttributes = 0;
357
358 DagInit *argumentValues = def.getValueAsDag("arguments");
359 unsigned numArgs = argumentValues->getNumArgs();
360
361 // Mapping from name of to argument or result index. Arguments are indexed
362 // to match getArg index, while the results are negatively indexed.
363 llvm::StringMap<int> argumentsAndResultsIndex;
364
365 // Handle operands and native attributes.
366 for (unsigned i = 0; i != numArgs; ++i) {
367 auto *arg = argumentValues->getArg(i);
368 auto givenName = argumentValues->getArgNameStr(i);
369 auto *argDefInit = dyn_cast<DefInit>(arg);
370 if (!argDefInit)
371 PrintFatalError(def.getLoc(),
372 Twine("undefined type for argument #") + Twine(i));
373 Record *argDef = argDefInit->getDef();
374 if (argDef->isSubClassOf(opVarClass))
375 argDef = argDef->getValueAsDef("constraint");
376
377 if (argDef->isSubClassOf(typeConstraintClass)) {
378 operands.push_back(
379 NamedTypeConstraint{givenName, TypeConstraint(argDef)});
380 } else if (argDef->isSubClassOf(attrClass)) {
381 if (givenName.empty())
382 PrintFatalError(argDef->getLoc(), "attributes must be named");
383 if (argDef->isSubClassOf(derivedAttrClass))
384 PrintFatalError(argDef->getLoc(),
385 "derived attributes not allowed in argument list");
386 attributes.push_back({givenName, Attribute(argDef)});
387 ++numNativeAttributes;
388 } else {
389 PrintFatalError(def.getLoc(), "unexpected def type; only defs deriving "
390 "from TypeConstraint or Attr are allowed");
391 }
392 if (!givenName.empty())
393 argumentsAndResultsIndex[givenName] = i;
394 }
395
396 // Handle derived attributes.
397 for (const auto &val : def.getValues()) {
398 if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
399 if (!record->isSubClassOf(attrClass))
400 continue;
401 if (!record->isSubClassOf(derivedAttrClass))
402 PrintFatalError(def.getLoc(),
403 "unexpected Attr where only DerivedAttr is allowed");
404
405 if (record->getClasses().size() != 1) {
406 PrintFatalError(
407 def.getLoc(),
408 "unsupported attribute modelling, only single class expected");
409 }
410 attributes.push_back(
411 {cast<llvm::StringInit>(val.getNameInit())->getValue(),
412 Attribute(cast<DefInit>(val.getValue()))});
413 }
414 }
415
416 // Populate `arguments`. This must happen after we've finalized `operands` and
417 // `attributes` because we will put their elements' pointers in `arguments`.
418 // SmallVector may perform re-allocation under the hood when adding new
419 // elements.
420 int operandIndex = 0, attrIndex = 0;
421 for (unsigned i = 0; i != numArgs; ++i) {
422 Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
423 if (argDef->isSubClassOf(opVarClass))
424 argDef = argDef->getValueAsDef("constraint");
425
426 if (argDef->isSubClassOf(typeConstraintClass)) {
427 attrOrOperandMapping.push_back(
428 {OperandOrAttribute::Kind::Operand, operandIndex});
429 arguments.emplace_back(&operands[operandIndex++]);
430 } else {
431 assert(argDef->isSubClassOf(attrClass));
432 attrOrOperandMapping.push_back(
433 {OperandOrAttribute::Kind::Attribute, attrIndex});
434 arguments.emplace_back(&attributes[attrIndex++]);
435 }
436 }
437
438 auto *resultsDag = def.getValueAsDag("results");
439 auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
440 if (!outsOp || outsOp->getDef()->getName() != "outs") {
441 PrintFatalError(def.getLoc(), "'results' must have 'outs' directive");
442 }
443
444 // Handle results.
445 for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
446 auto name = resultsDag->getArgNameStr(i);
447 auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i));
448 if (!resultInit) {
449 PrintFatalError(def.getLoc(),
450 Twine("undefined type for result #") + Twine(i));
451 }
452 auto *resultDef = resultInit->getDef();
453 if (resultDef->isSubClassOf(opVarClass))
454 resultDef = resultDef->getValueAsDef("constraint");
455 results.push_back({name, TypeConstraint(resultDef)});
456 if (!name.empty())
457 argumentsAndResultsIndex[name] = resultIndex(i);
458 }
459
460 // Handle successors
461 auto *successorsDag = def.getValueAsDag("successors");
462 auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator());
463 if (!successorsOp || successorsOp->getDef()->getName() != "successor") {
464 PrintFatalError(def.getLoc(),
465 "'successors' must have 'successor' directive");
466 }
467
468 for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) {
469 auto name = successorsDag->getArgNameStr(i);
470 auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i));
471 if (!successorInit) {
472 PrintFatalError(def.getLoc(),
473 Twine("undefined kind for successor #") + Twine(i));
474 }
475 Successor successor(successorInit->getDef());
476
477 // Only support variadic successors if it is the last one for now.
478 if (i != e - 1 && successor.isVariadic())
479 PrintFatalError(def.getLoc(), "only the last successor can be variadic");
480 successors.push_back({name, successor});
481 }
482
483 // Create list of traits, skipping over duplicates: appending to lists in
484 // tablegen is easy, making them unique less so, so dedupe here.
485 if (auto *traitList = def.getValueAsListInit("traits")) {
486 // This is uniquing based on pointers of the trait.
487 SmallPtrSet<const llvm::Init *, 32> traitSet;
488 traits.reserve(traitSet.size());
489 for (auto *traitInit : *traitList) {
490 // Keep traits in the same order while skipping over duplicates.
491 if (traitSet.insert(traitInit).second)
492 traits.push_back(OpTrait::create(traitInit));
493 }
494 }
495
496 populateTypeInferenceInfo(argumentsAndResultsIndex);
497
498 // Handle regions
499 auto *regionsDag = def.getValueAsDag("regions");
500 auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
501 if (!regionsOp || regionsOp->getDef()->getName() != "region") {
502 PrintFatalError(def.getLoc(), "'regions' must have 'region' directive");
503 }
504
505 for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
506 auto name = regionsDag->getArgNameStr(i);
507 auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
508 if (!regionInit) {
509 PrintFatalError(def.getLoc(),
510 Twine("undefined kind for region #") + Twine(i));
511 }
512 Region region(regionInit->getDef());
513 if (region.isVariadic()) {
514 // Only support variadic regions if it is the last one for now.
515 if (i != e - 1)
516 PrintFatalError(def.getLoc(), "only the last region can be variadic");
517 if (name.empty())
518 PrintFatalError(def.getLoc(), "variadic regions must be named");
519 }
520
521 regions.push_back({name, region});
522 }
523
524 LLVM_DEBUG(print(llvm::dbgs()));
525 }
526
getSameTypeAsResult(int index) const527 auto Operator::getSameTypeAsResult(int index) const -> ArrayRef<ArgOrType> {
528 assert(allResultTypesKnown());
529 return resultTypeMapping[index];
530 }
531
getLoc() const532 ArrayRef<llvm::SMLoc> Operator::getLoc() const { return def.getLoc(); }
533
hasDescription() const534 bool Operator::hasDescription() const {
535 return def.getValue("description") != nullptr;
536 }
537
getDescription() const538 StringRef Operator::getDescription() const {
539 return def.getValueAsString("description");
540 }
541
hasSummary() const542 bool Operator::hasSummary() const { return def.getValue("summary") != nullptr; }
543
getSummary() const544 StringRef Operator::getSummary() const {
545 return def.getValueAsString("summary");
546 }
547
hasAssemblyFormat() const548 bool Operator::hasAssemblyFormat() const {
549 auto *valueInit = def.getValueInit("assemblyFormat");
550 return isa<llvm::StringInit>(valueInit);
551 }
552
getAssemblyFormat() const553 StringRef Operator::getAssemblyFormat() const {
554 return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat"))
555 .Case<llvm::StringInit>(
556 [&](auto *init) { return init->getValue(); });
557 }
558
print(llvm::raw_ostream & os) const559 void Operator::print(llvm::raw_ostream &os) const {
560 os << "op '" << getOperationName() << "'\n";
561 for (Argument arg : arguments) {
562 if (auto *attr = arg.dyn_cast<NamedAttribute *>())
563 os << "[attribute] " << attr->name << '\n';
564 else
565 os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
566 }
567 }
568
unwrap(llvm::Init * init)569 auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
570 -> VariableDecorator {
571 return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
572 }
573
getArgToOperandOrAttribute(int index) const574 auto Operator::getArgToOperandOrAttribute(int index) const
575 -> OperandOrAttribute {
576 return attrOrOperandMapping[index];
577 }
578