1 //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "OpFormatGen.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Interfaces.h"
19 #include "mlir/TableGen/OpClass.h"
20 #include "mlir/TableGen/OpTrait.h"
21 #include "mlir/TableGen/Operator.h"
22 #include "mlir/TableGen/SideEffects.h"
23 #include "llvm/ADT/Sequence.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Regex.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Record.h"
30 #include "llvm/TableGen/TableGenBackend.h"
31
32 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
33
34 using namespace llvm;
35 using namespace mlir;
36 using namespace mlir::tblgen;
37
38 cl::OptionCategory opDefGenCat("Options for -gen-op-defs and -gen-op-decls");
39
40 static cl::opt<std::string> opIncFilter(
41 "op-include-regex",
42 cl::desc("Regex of name of op's to include (no filter if empty)"),
43 cl::cat(opDefGenCat));
44 static cl::opt<std::string> opExcFilter(
45 "op-exclude-regex",
46 cl::desc("Regex of name of op's to exclude (no filter if empty)"),
47 cl::cat(opDefGenCat));
48
49 static const char *const tblgenNamePrefix = "tblgen_";
50 static const char *const generatedArgName = "odsArg";
51 static const char *const builder = "odsBuilder";
52 static const char *const builderOpState = "odsState";
53
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic is not for
// general use; it assumes all variadic operands/results must have the same
// number of values.
//
// This is a formatv template; `{{` is the escape for a literal `{`.
//
// {0}: The list of whether each declared operand/result is variadic.
// {1}: The total number of non-variadic operands/results.
// {2}: The total number of variadic operands/results.
// {3}: The total number of actual values.
// {4}: "operand" or "result".
static const char *const sameVariadicSizeValueRangeCalcCode = R"(
bool isVariadic[] = {{{0}};
int prevVariadicCount = 0;
for (unsigned i = 0; i < index; ++i)
if (isVariadic[i]) ++prevVariadicCount;

// Calculate how many dynamic values a static variadic {4} corresponds to.
// This assumes all static variadic {4}s have the same dynamic value count.
int variadicSize = ({3} - {1}) / {2};
// `index` passed in as the parameter is the static index which counts each
// {4} (variadic or not) as size 1. So here for each previous static variadic
// {4}, we need to offset by (variadicSize - 1) to get where the dynamic
// value pack for this static {4} starts.
int start = index + (variadicSize - 1) * prevVariadicCount;
int size = isVariadic[index] ? variadicSize : 1;
return {{start, size};
)";

// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic assumes
// the op has an attribute specifying the size of each operand/result segment
// (variadic or not).
//
// The two `*SegmentSizeAttrInitCode` templates define a local `sizeAttr` from
// the attribute named by {0}; `attrSizedSegmentValueRangeCalcCode` then reads
// it. The latter is appended verbatim (not run through formatv), so its
// braces are not escaped.
//
// {0}: The name of the attribute specifying the segment sizes.
static const char *const adapterSegmentSizeAttrInitCode = R"(
assert(odsAttrs && "missing segment size attribute for op");
auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
)";
static const char *const opSegmentSizeAttrInitCode = R"(
auto sizeAttr =
getOperation()->getAttrOfType<::mlir::DenseIntElementsAttr>("{0}");
)";
static const char *const attrSizedSegmentValueRangeCalcCode = R"(
unsigned start = 0;
for (unsigned i = 0; i < index; ++i)
start += (*(sizeAttr.begin() + i)).getZExtValue();
unsigned size = (*(sizeAttr.begin() + index)).getZExtValue();
return {start, size};
)";

// The logic to build a range of either operand or result values.
//
// {0}: The begin iterator of the actual values.
// {1}: The call to generate the start and length of the value range.
static const char *const valueRangeReturnCode = R"(
auto valueRange = {1};
return {{std::next({0}, valueRange.first),
std::next({0}, valueRange.first + valueRange.second)};
)";

// Banner comment emitted before each op's generated code.
//
// {0}, {1}: The two words shown on the banner line.
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";
120
121 //===----------------------------------------------------------------------===//
122 // Utility structs and functions
123 //===----------------------------------------------------------------------===//
124
// Replaces all occurrences of `match` in `str` with `substitute` and returns the
// resulting string.
//
// Scanning resumes after each inserted `substitute`, so replacement text is
// never matched again, even if it contains `match`. An empty `match` leaves
// `str` unchanged: an empty pattern matches at every position, and the scan
// would never make progress.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  if (match.empty())
    return str;

  std::string::size_type scanLoc = 0;
  std::string::size_type matchLoc = std::string::npos;
  while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) {
    // replace() edits `str` in place; there is no need to reassign it.
    str.replace(matchLoc, match.size(), substitute);
    scanLoc = matchLoc + substitute.size();
  }
  return str;
}
135
136 // Returns whether the record has a value of the given name that can be returned
137 // via getValueAsString.
hasStringAttribute(const Record & record,StringRef fieldName)138 static inline bool hasStringAttribute(const Record &record,
139 StringRef fieldName) {
140 auto valueInit = record.getValueInit(fieldName);
141 return isa<StringInit>(valueInit);
142 }
143
getArgumentName(const Operator & op,int index)144 static std::string getArgumentName(const Operator &op, int index) {
145 const auto &operand = op.getOperand(index);
146 if (!operand.name.empty())
147 return std::string(operand.name);
148 else
149 return std::string(formatv("{0}_{1}", generatedArgName, index));
150 }
151
152 // Returns true if we can use unwrapped value for the given `attr` in builders.
canUseUnwrappedRawValue(const tblgen::Attribute & attr)153 static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
154 return attr.getReturnType() != attr.getStorageType() &&
155 // We need to wrap the raw value into an attribute in the builder impl
156 // so we need to make sure that the attribute specifies how to do that.
157 !attr.getConstBuilderTemplate().empty();
158 }
159
160 //===----------------------------------------------------------------------===//
161 // Op emitter
162 //===----------------------------------------------------------------------===//
163
164 namespace {
165 // Helper class to emit a record into the given output stream.
166 class OpEmitter {
167 public:
168 static void emitDecl(const Operator &op, raw_ostream &os);
169 static void emitDef(const Operator &op, raw_ostream &os);
170
171 private:
172 OpEmitter(const Operator &op);
173
174 void emitDecl(raw_ostream &os);
175 void emitDef(raw_ostream &os);
176
177 // Generates the OpAsmOpInterface for this operation if possible.
178 void genOpAsmInterface();
179
180 // Generates the `getOperationName` method for this op.
181 void genOpNameGetter();
182
183 // Generates getters for the attributes.
184 void genAttrGetters();
185
186 // Generates setter for the attributes.
187 void genAttrSetters();
188
189 // Generates getters for named operands.
190 void genNamedOperandGetters();
191
192 // Generates setters for named operands.
193 void genNamedOperandSetters();
194
195 // Generates getters for named results.
196 void genNamedResultGetters();
197
198 // Generates getters for named regions.
199 void genNamedRegionGetters();
200
201 // Generates getters for named successors.
202 void genNamedSuccessorGetters();
203
204 // Generates builder methods for the operation.
205 void genBuilder();
206
207 // Generates the build() method that takes each operand/attribute
208 // as a stand-alone parameter.
209 void genSeparateArgParamBuilder();
210
211 // Generates the build() method that takes each operand/attribute as a
212 // stand-alone parameter. The generated build() method uses first operand's
213 // type as all results' types.
214 void genUseOperandAsResultTypeSeparateParamBuilder();
215
216 // Generates the build() method that takes all operands/attributes
217 // collectively as one parameter. The generated build() method uses first
218 // operand's type as all results' types.
219 void genUseOperandAsResultTypeCollectiveParamBuilder();
220
221 // Generates the build() method that takes aggregate operands/attributes
222 // parameters. This build() method uses inferred types as result types.
223 // Requires: The type needs to be inferable via InferTypeOpInterface.
224 void genInferredTypeCollectiveParamBuilder();
225
226 // Generates the build() method that takes each operand/attribute as a
227 // stand-alone parameter. The generated build() method uses first attribute's
228 // type as all result's types.
229 void genUseAttrAsResultTypeBuilder();
230
231 // Generates the build() method that takes all result types collectively as
232 // one parameter. Similarly for operands and attributes.
233 void genCollectiveParamBuilder();
234
235 // The kind of parameter to generate for result types in builders.
236 enum class TypeParamKind {
237 None, // No result type in parameter list.
238 Separate, // A separate parameter for each result type.
239 Collective, // An ArrayRef<Type> for all result types.
240 };
241
242 // The kind of parameter to generate for attributes in builders.
243 enum class AttrParamKind {
244 WrappedAttr, // A wrapped MLIR Attribute instance.
245 UnwrappedValue, // A raw value without MLIR Attribute wrapper.
246 };
247
248 // Builds the parameter list for build() method of this op. This method writes
249 // to `paramList` the comma-separated parameter list and updates
250 // `resultTypeNames` with the names for parameters for specifying result
251 // types. The given `typeParamKind` and `attrParamKind` controls how result
252 // types and attributes are placed in the parameter list.
253 void buildParamList(llvm::SmallVectorImpl<OpMethodParameter> ¶mList,
254 SmallVectorImpl<std::string> &resultTypeNames,
255 TypeParamKind typeParamKind,
256 AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
257
258 // Adds op arguments and regions into operation state for build() methods.
259 void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
260 bool isRawValueAttr = false);
261
262 // Generates canonicalizer declaration for the operation.
263 void genCanonicalizerDecls();
264
265 // Generates the folder declaration for the operation.
266 void genFolderDecls();
267
268 // Generates the parser for the operation.
269 void genParser();
270
271 // Generates the printer for the operation.
272 void genPrinter();
273
274 // Generates verify method for the operation.
275 void genVerifier();
276
277 // Generates verify statements for operands and results in the operation.
278 // The generated code will be attached to `body`.
279 void genOperandResultVerifier(OpMethodBody &body,
280 Operator::value_range values,
281 StringRef valueKind);
282
283 // Generates verify statements for regions in the operation.
284 // The generated code will be attached to `body`.
285 void genRegionVerifier(OpMethodBody &body);
286
287 // Generates verify statements for successors in the operation.
288 // The generated code will be attached to `body`.
289 void genSuccessorVerifier(OpMethodBody &body);
290
291 // Generates the traits used by the object.
292 void genTraits();
293
294 // Generate the OpInterface methods for all interfaces.
295 void genOpInterfaceMethods();
296
297 // Generate op interface methods for the given interface.
298 void genOpInterfaceMethods(const tblgen::InterfaceOpTrait *trait);
299
300 // Generate op interface method for the given interface method. If
301 // 'declaration' is true, generates a declaration, else a definition.
302 OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
303 bool declaration = true);
304
305 // Generate the side effect interface methods.
306 void genSideEffectInterfaceMethods();
307
308 // Generate the type inference interface methods.
309 void genTypeInterfaceMethods();
310
311 private:
312 // The TableGen record for this op.
313 // TODO: OpEmitter should not have a Record directly,
314 // it should rather go through the Operator for better abstraction.
315 const Record &def;
316
317 // The wrapper operator class for querying information from this op.
318 Operator op;
319
320 // The C++ code builder for this op
321 OpClass opClass;
322
323 // The format context for verification code generation.
324 FmtContext verifyCtx;
325 };
326 } // end anonymous namespace
327
328 // Populate the format context `ctx` with substitutions of attributes, operands
329 // and results.
330 // - attrGet corresponds to the name of the function to call to get value of
331 // attribute (the generated function call returns an Attribute);
332 // - operandGet corresponds to the name of the function with which to retrieve
333 // an operand (the generated function call returns an OperandRange);
334 // - resultGet corresponds to the name of the function to get an result (the
335 // generated function call returns a ValueRange);
populateSubstitutions(const Operator & op,const char * attrGet,const char * operandGet,const char * resultGet,FmtContext & ctx)336 static void populateSubstitutions(const Operator &op, const char *attrGet,
337 const char *operandGet, const char *resultGet,
338 FmtContext &ctx) {
339 // Populate substitutions for attributes and named operands.
340 for (const auto &namedAttr : op.getAttributes())
341 ctx.addSubst(namedAttr.name,
342 formatv("{0}(\"{1}\")", attrGet, namedAttr.name));
343 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
344 auto &value = op.getOperand(i);
345 if (value.name.empty())
346 continue;
347
348 if (value.isVariadic())
349 ctx.addSubst(value.name, formatv("{0}({1})", operandGet, i));
350 else
351 ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", operandGet, i));
352 }
353
354 // Populate substitutions for results.
355 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
356 auto &value = op.getResult(i);
357 if (value.name.empty())
358 continue;
359
360 if (value.isVariadic())
361 ctx.addSubst(value.name, formatv("{0}({1})", resultGet, i));
362 else
363 ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", resultGet, i));
364 }
365 }
366
367 // Generate attribute verification. If emitVerificationRequiringOp is set then
368 // only verification for attributes whose value depend on op being known are
369 // emitted, else only verification that doesn't depend on the op being known are
370 // generated.
371 // - emitErrorPrefix is the prefix for the error emitting call which consists
372 // of the entire function call up to start of error message fragment;
373 // - emitVerificationRequiringOp specifies whether verification should be
374 // emitted for verification that require the op to exist;
genAttributeVerifier(const Operator & op,const char * attrGet,const Twine & emitErrorPrefix,bool emitVerificationRequiringOp,FmtContext & ctx,OpMethodBody & body)375 static void genAttributeVerifier(const Operator &op, const char *attrGet,
376 const Twine &emitErrorPrefix,
377 bool emitVerificationRequiringOp,
378 FmtContext &ctx, OpMethodBody &body) {
379 for (const auto &namedAttr : op.getAttributes()) {
380 const auto &attr = namedAttr.attr;
381 if (attr.isDerivedAttr())
382 continue;
383
384 auto attrName = namedAttr.name;
385 bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
386 auto attrPred = attr.getPredicate();
387 auto condition = attrPred.isNull() ? "" : attrPred.getCondition();
388 // There is a condition to emit only if the use of $_op and whether to
389 // emit verifications for op matches.
390 bool hasConditionToEmit = (!(condition.find("$_op") != StringRef::npos) ^
391 emitVerificationRequiringOp);
392
393 // Prefix with `tblgen_` to avoid hiding the attribute accessor.
394 auto varName = tblgenNamePrefix + attrName;
395
396 // If the attribute is
397 // 1. Required (not allowed missing) and not in op verification, or
398 // 2. Has a condition that will get verified
399 // then the variable will be used.
400 //
401 // Therefore, for optional attributes whose verification requires that an
402 // op already exists for verification/emitVerificationRequiringOp is set
403 // has nothing that can be verified here.
404 if ((allowMissingAttr || emitVerificationRequiringOp) &&
405 !hasConditionToEmit)
406 continue;
407
408 body << formatv(" {\n auto {0} = {1}(\"{2}\");\n", varName, attrGet,
409 attrName);
410
411 if (!emitVerificationRequiringOp && !allowMissingAttr) {
412 body << " if (!" << varName << ") return " << emitErrorPrefix
413 << "\"requires attribute '" << attrName << "'\");\n";
414 }
415
416 if (!hasConditionToEmit) {
417 body << " }\n";
418 continue;
419 }
420
421 if (allowMissingAttr) {
422 // If the attribute has a default value, then only verify the predicate if
423 // set. This does effectively assume that the default value is valid.
424 // TODO: verify the debug value is valid (perhaps in debug mode only).
425 body << " if (" << varName << ") {\n";
426 }
427
428 body << tgfmt(" if (!($0)) return $1\"attribute '$2' "
429 "failed to satisfy constraint: $3\");\n",
430 /*ctx=*/nullptr, tgfmt(condition, &ctx.withSelf(varName)),
431 emitErrorPrefix, attrName, attr.getDescription());
432 if (allowMissingAttr)
433 body << " }\n";
434 body << " }\n";
435 }
436 }
437
// Builds the C++ class model (`opClass`) for `op`. All code generation happens
// eagerly here; emitDecl/emitDef only print the already-populated `opClass`.
OpEmitter::OpEmitter(const Operator &op)
    : def(op.getDef()), op(op),
      opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
  // Expression substituted for the op in generated verification code.
  verifyCtx.withOp("(*this->getOperation())");

  // Traits are collected before any method is generated.
  genTraits();

  // Generate C++ code for various op methods. The order here determines the
  // methods in the generated file.
  genOpAsmInterface();
  genOpNameGetter();
  genNamedOperandGetters();
  genNamedOperandSetters();
  genNamedResultGetters();
  genNamedRegionGetters();
  genNamedSuccessorGetters();
  genAttrGetters();
  genAttrSetters();
  genBuilder();
  genParser();
  genPrinter();
  genVerifier();
  genCanonicalizerDecls();
  genFolderDecls();
  genTypeInterfaceMethods();
  genOpInterfaceMethods();
  generateOpFormat(op, opClass);
  genSideEffectInterfaceMethods();
}
467
emitDecl(const Operator & op,raw_ostream & os)468 void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
469 OpEmitter(op).emitDecl(os);
470 }
471
emitDef(const Operator & op,raw_ostream & os)472 void OpEmitter::emitDef(const Operator &op, raw_ostream &os) {
473 OpEmitter(op).emitDef(os);
474 }
475
emitDecl(raw_ostream & os)476 void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
477
emitDef(raw_ostream & os)478 void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
479
genAttrGetters()480 void OpEmitter::genAttrGetters() {
481 FmtContext fctx;
482 fctx.withBuilder("::mlir::Builder(this->getContext())");
483
484 Dialect opDialect = op.getDialect();
485 // Emit the derived attribute body.
486 auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
487 auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
488 if (!method)
489 return;
490 auto &body = method->body();
491 body << " " << attr.getDerivedCodeBody() << "\n";
492 };
493
494 // Emit with return type specified.
495 auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
496 auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
497 auto &body = method->body();
498 body << " auto attr = " << name << "Attr();\n";
499 if (attr.hasDefaultValue()) {
500 // Returns the default value if not set.
501 // TODO: this is inefficient, we are recreating the attribute for every
502 // call. This should be set instead.
503 std::string defaultValue = std::string(
504 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
505 body << " if (!attr)\n return "
506 << tgfmt(attr.getConvertFromStorageCall(),
507 &fctx.withSelf(defaultValue))
508 << ";\n";
509 }
510 body << " return "
511 << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
512 << ";\n";
513 };
514
515 // Generate raw named accessor type. This is a wrapper class that allows
516 // referring to the attributes via accessors instead of having to use
517 // the string interface for better compile time verification.
518 auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
519 auto *method =
520 opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str());
521 if (!method)
522 return;
523 auto &body = method->body();
524 body << " return this->getAttr(\"" << name << "\").";
525 if (attr.isOptional() || attr.hasDefaultValue())
526 body << "dyn_cast_or_null<";
527 else
528 body << "cast<";
529 body << attr.getStorageType() << ">();";
530 };
531
532 for (auto &namedAttr : op.getAttributes()) {
533 const auto &name = namedAttr.name;
534 const auto &attr = namedAttr.attr;
535 if (attr.isDerivedAttr()) {
536 emitDerivedAttr(name, attr);
537 } else {
538 emitAttrWithStorageType(name, attr);
539 emitAttrWithReturnType(name, attr);
540 }
541 }
542
543 auto derivedAttrs = make_filter_range(op.getAttributes(),
544 [](const NamedAttribute &namedAttr) {
545 return namedAttr.attr.isDerivedAttr();
546 });
547 if (!derivedAttrs.empty()) {
548 opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
549 // Generate helper method to query whether a named attribute is a derived
550 // attribute. This enables, for example, avoiding adding an attribute that
551 // overlaps with a derived attribute.
552 {
553 auto *method = opClass.addMethodAndPrune("bool", "isDerivedAttribute",
554 OpMethod::MP_Static,
555 "::llvm::StringRef", "name");
556 auto &body = method->body();
557 for (auto namedAttr : derivedAttrs)
558 body << " if (name == \"" << namedAttr.name << "\") return true;\n";
559 body << " return false;";
560 }
561 // Generate method to materialize derived attributes as a DictionaryAttr.
562 {
563 auto *method = opClass.addMethodAndPrune("::mlir::DictionaryAttr",
564 "materializeDerivedAttributes");
565 auto &body = method->body();
566
567 auto nonMaterializable =
568 make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
569 return namedAttr.attr.getConvertFromStorageCall().empty();
570 });
571 if (!nonMaterializable.empty()) {
572 std::string attrs;
573 llvm::raw_string_ostream os(attrs);
574 interleaveComma(nonMaterializable, os,
575 [&](const NamedAttribute &attr) { os << attr.name; });
576 PrintWarning(
577 op.getLoc(),
578 formatv(
579 "op has non-materializable derived attributes '{0}', skipping",
580 os.str()));
581 body << formatv(" emitOpError(\"op has non-materializable derived "
582 "attributes '{0}'\");\n",
583 attrs);
584 body << " return nullptr;";
585 return;
586 }
587
588 body << " ::mlir::MLIRContext* ctx = getContext();\n";
589 body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
590 body << " return ::mlir::DictionaryAttr::get({\n";
591 interleave(
592 derivedAttrs, body,
593 [&](const NamedAttribute &namedAttr) {
594 auto tmpl = namedAttr.attr.getConvertFromStorageCall();
595 body << " {::mlir::Identifier::get(\"" << namedAttr.name
596 << "\", ctx),\n"
597 << tgfmt(tmpl, &fctx.withSelf(namedAttr.name + "()")
598 .withBuilder("odsBuilder")
599 .addSubst("_ctx", "ctx"))
600 << "}";
601 },
602 ",\n");
603 body << "\n }, ctx);";
604 }
605 }
606 }
607
genAttrSetters()608 void OpEmitter::genAttrSetters() {
609 // Generate raw named setter type. This is a wrapper class that allows setting
610 // to the attributes via setters instead of having to use the string interface
611 // for better compile time verification.
612 auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
613 auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(),
614 attr.getStorageType(), "attr");
615 if (!method)
616 return;
617 auto &body = method->body();
618 body << " (*this)->setAttr(\"" << name << "\", attr);";
619 };
620
621 for (auto &namedAttr : op.getAttributes()) {
622 const auto &name = namedAttr.name;
623 const auto &attr = namedAttr.attr;
624 if (!attr.isDerivedAttr())
625 emitAttrWithStorageType(name, attr);
626 }
627 }
628
629 // Generates the code to compute the start and end index of an operand or result
630 // range.
631 template <typename RangeT>
632 static void
generateValueRangeStartAndEnd(Class & opClass,StringRef methodName,int numVariadic,int numNonVariadic,StringRef rangeSizeCall,bool hasAttrSegmentSize,StringRef sizeAttrInit,RangeT && odsValues)633 generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
634 int numVariadic, int numNonVariadic,
635 StringRef rangeSizeCall, bool hasAttrSegmentSize,
636 StringRef sizeAttrInit, RangeT &&odsValues) {
637 auto *method = opClass.addMethodAndPrune("std::pair<unsigned, unsigned>",
638 methodName, "unsigned", "index");
639 if (!method)
640 return;
641 auto &body = method->body();
642 if (numVariadic == 0) {
643 body << " return {index, 1};\n";
644 } else if (hasAttrSegmentSize) {
645 body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
646 } else {
647 // Because the op can have arbitrarily interleaved variadic and non-variadic
648 // operands, we need to embed a list in the "sink" getter method for
649 // calculation at run-time.
650 llvm::SmallVector<StringRef, 4> isVariadic;
651 isVariadic.reserve(llvm::size(odsValues));
652 for (auto &it : odsValues)
653 isVariadic.push_back(it.isVariableLength() ? "true" : "false");
654 std::string isVariadicList = llvm::join(isVariadic, ", ");
655 body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
656 numNonVariadic, numVariadic, rangeSizeCall, "operand");
657 }
658 }
659
660 // Generates the named operand getter methods for the given Operator `op` and
661 // puts them in `opClass`. Uses `rangeType` as the return type of getters that
662 // return a range of operands (individual operands are `Value ` and each
663 // element in the range must also be `Value `); use `rangeBeginCall` to get
664 // an iterator to the beginning of the operand range; use `rangeSizeCall` to
665 // obtain the number of operands. `getOperandCallPattern` contains the code
666 // necessary to obtain a single operand whose position will be substituted
667 // instead of
668 // "{0}" marker in the pattern. Note that the pattern should work for any kind
669 // of ops, in particular for one-operand ops that may not have the
670 // `getOperand(unsigned)` method.
generateNamedOperandGetters(const Operator & op,Class & opClass,StringRef sizeAttrInit,StringRef rangeType,StringRef rangeBeginCall,StringRef rangeSizeCall,StringRef getOperandCallPattern)671 static void generateNamedOperandGetters(const Operator &op, Class &opClass,
672 StringRef sizeAttrInit,
673 StringRef rangeType,
674 StringRef rangeBeginCall,
675 StringRef rangeSizeCall,
676 StringRef getOperandCallPattern) {
677 const int numOperands = op.getNumOperands();
678 const int numVariadicOperands = op.getNumVariableLengthOperands();
679 const int numNormalOperands = numOperands - numVariadicOperands;
680
681 const auto *sameVariadicSize =
682 op.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
683 const auto *attrSizedOperands =
684 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
685
686 if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
687 PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
688 "specification over their sizes");
689 }
690
691 if (numVariadicOperands < 2 && attrSizedOperands) {
692 PrintFatalError(op.getLoc(), "op must have at least two variadic operands "
693 "to use 'AttrSizedOperandSegments' trait");
694 }
695
696 if (attrSizedOperands && sameVariadicSize) {
697 PrintFatalError(op.getLoc(),
698 "op cannot have both 'AttrSizedOperandSegments' and "
699 "'SameVariadicOperandSize' traits");
700 }
701
702 // First emit a few "sink" getter methods upon which we layer all nicer named
703 // getter methods.
704 generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
705 numVariadicOperands, numNormalOperands,
706 rangeSizeCall, attrSizedOperands, sizeAttrInit,
707 const_cast<Operator &>(op).getOperands());
708
709 auto *m = opClass.addMethodAndPrune(rangeType, "getODSOperands", "unsigned",
710 "index");
711 auto &body = m->body();
712 body << formatv(valueRangeReturnCode, rangeBeginCall,
713 "getODSOperandIndexAndLength(index)");
714
715 // Then we emit nicer named getter methods by redirecting to the "sink" getter
716 // method.
717 for (int i = 0; i != numOperands; ++i) {
718 const auto &operand = op.getOperand(i);
719 if (operand.name.empty())
720 continue;
721
722 if (operand.isOptional()) {
723 m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
724 m->body() << " auto operands = getODSOperands(" << i << ");\n"
725 << " return operands.empty() ? Value() : *operands.begin();";
726 } else if (operand.isVariadic()) {
727 m = opClass.addMethodAndPrune(rangeType, operand.name);
728 m->body() << " return getODSOperands(" << i << ");";
729 } else {
730 m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
731 m->body() << " return *getODSOperands(" << i << ").begin();";
732 }
733 }
734 }
735
genNamedOperandGetters()736 void OpEmitter::genNamedOperandGetters() {
737 generateNamedOperandGetters(
738 op, opClass,
739 /*sizeAttrInit=*/
740 formatv(opSegmentSizeAttrInitCode, "operand_segment_sizes").str(),
741 /*rangeType=*/"::mlir::Operation::operand_range",
742 /*rangeBeginCall=*/"getOperation()->operand_begin()",
743 /*rangeSizeCall=*/"getOperation()->getNumOperands()",
744 /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
745 }
746
genNamedOperandSetters()747 void OpEmitter::genNamedOperandSetters() {
748 auto *attrSizedOperands =
749 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
750 for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
751 const auto &operand = op.getOperand(i);
752 if (operand.name.empty())
753 continue;
754 auto *m = opClass.addMethodAndPrune("::mlir::MutableOperandRange",
755 (operand.name + "Mutable").str());
756 auto &body = m->body();
757 body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
758 << " return ::mlir::MutableOperandRange(getOperation(), "
759 "range.first, range.second";
760 if (attrSizedOperands)
761 body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
762 << "u, *getOperation()->getMutableAttrDict().getNamed("
763 "\"operand_segment_sizes\"))";
764 body << ");\n";
765 }
766 }
767
genNamedResultGetters()768 void OpEmitter::genNamedResultGetters() {
769 const int numResults = op.getNumResults();
770 const int numVariadicResults = op.getNumVariableLengthResults();
771 const int numNormalResults = numResults - numVariadicResults;
772
773 // If we have more than one variadic results, we need more complicated logic
774 // to calculate the value range for each result.
775
776 const auto *sameVariadicSize =
777 op.getTrait("::mlir::OpTrait::SameVariadicResultSize");
778 const auto *attrSizedResults =
779 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments");
780
781 if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
782 PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
783 "specification over their sizes");
784 }
785
786 if (numVariadicResults < 2 && attrSizedResults) {
787 PrintFatalError(op.getLoc(), "op must have at least two variadic results "
788 "to use 'AttrSizedResultSegments' trait");
789 }
790
791 if (attrSizedResults && sameVariadicSize) {
792 PrintFatalError(op.getLoc(),
793 "op cannot have both 'AttrSizedResultSegments' and "
794 "'SameVariadicResultSize' traits");
795 }
796
797 generateValueRangeStartAndEnd(
798 opClass, "getODSResultIndexAndLength", numVariadicResults,
799 numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
800 formatv(opSegmentSizeAttrInitCode, "result_segment_sizes").str(),
801 op.getResults());
802
803 auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
804 "getODSResults", "unsigned", "index");
805 m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
806 "getODSResultIndexAndLength(index)");
807
808 for (int i = 0; i != numResults; ++i) {
809 const auto &result = op.getResult(i);
810 if (result.name.empty())
811 continue;
812
813 if (result.isOptional()) {
814 m = opClass.addMethodAndPrune("::mlir::Value", result.name);
815 m->body()
816 << " auto results = getODSResults(" << i << ");\n"
817 << " return results.empty() ? ::mlir::Value() : *results.begin();";
818 } else if (result.isVariadic()) {
819 m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
820 result.name);
821 m->body() << " return getODSResults(" << i << ");";
822 } else {
823 m = opClass.addMethodAndPrune("::mlir::Value", result.name);
824 m->body() << " return *getODSResults(" << i << ").begin();";
825 }
826 }
827 }
828
genNamedRegionGetters()829 void OpEmitter::genNamedRegionGetters() {
830 unsigned numRegions = op.getNumRegions();
831 for (unsigned i = 0; i < numRegions; ++i) {
832 const auto ®ion = op.getRegion(i);
833 if (region.name.empty())
834 continue;
835
836 // Generate the accessors for a varidiadic region.
837 if (region.isVariadic()) {
838 auto *m = opClass.addMethodAndPrune("::mlir::MutableArrayRef<Region>",
839 region.name);
840 m->body() << formatv(
841 " return this->getOperation()->getRegions().drop_front({0});", i);
842 continue;
843 }
844
845 auto *m = opClass.addMethodAndPrune("::mlir::Region &", region.name);
846 m->body() << formatv(" return this->getOperation()->getRegion({0});", i);
847 }
848 }
849
genNamedSuccessorGetters()850 void OpEmitter::genNamedSuccessorGetters() {
851 unsigned numSuccessors = op.getNumSuccessors();
852 for (unsigned i = 0; i < numSuccessors; ++i) {
853 const NamedSuccessor &successor = op.getSuccessor(i);
854 if (successor.name.empty())
855 continue;
856
857 // Generate the accessors for a variadic successor list.
858 if (successor.isVariadic()) {
859 auto *m =
860 opClass.addMethodAndPrune("::mlir::SuccessorRange", successor.name);
861 m->body() << formatv(
862 " return {std::next(this->getOperation()->successor_begin(), {0}), "
863 "this->getOperation()->successor_end()};",
864 i);
865 continue;
866 }
867
868 auto *m = opClass.addMethodAndPrune("::mlir::Block *", successor.name);
869 m->body() << formatv(" return this->getOperation()->getSuccessor({0});",
870 i);
871 }
872 }
873
canGenerateUnwrappedBuilder(Operator & op)874 static bool canGenerateUnwrappedBuilder(Operator &op) {
875 // If this op does not have native attributes at all, return directly to avoid
876 // redefining builders.
877 if (op.getNumNativeAttributes() == 0)
878 return false;
879
880 bool canGenerate = false;
881 // We are generating builders that take raw values for attributes. We need to
882 // make sure the native attributes have a meaningful "unwrapped" value type
883 // different from the wrapped mlir::Attribute type to avoid redefining
884 // builders. This checks for the op has at least one such native attribute.
885 for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
886 NamedAttribute &namedAttr = op.getAttribute(i);
887 if (canUseUnwrappedRawValue(namedAttr.attr)) {
888 canGenerate = true;
889 break;
890 }
891 }
892 return canGenerate;
893 }
894
canInferType(Operator & op)895 static bool canInferType(Operator &op) {
896 return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
897 op.getNumRegions() == 0;
898 }
899
/// Emits the `build` methods that take one parameter per operand and
/// attribute. For each attribute flavor (wrapped attributes always; unwrapped
/// raw values only when canGenerateUnwrappedBuilder(op)) up to three builders
/// are emitted:
///  - one with a separate parameter per result type,
///  - one that infers the result types (only when canInferType(op)),
///  - one with a single collective `resultTypes` parameter.
/// Builders that addMethodAndPrune reports as redundant are skipped.
void OpEmitter::genSeparateArgParamBuilder() {
  SmallVector<AttrParamKind, 2> attrBuilderType;
  attrBuilderType.push_back(AttrParamKind::WrappedAttr);
  if (canGenerateUnwrappedBuilder(op))
    attrBuilderType.push_back(AttrParamKind::UnwrappedValue);

  // Emit with separate builders with or without unwrapped attributes and/or
  // inferring result type.
  auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
                  bool inferType) {
    llvm::SmallVector<OpMethodParameter, 4> paramList;
    llvm::SmallVector<std::string, 4> resultNames;
    buildParamList(paramList, resultNames, paramKind, attrType);

    auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
                                        std::move(paramList));
    // If the builder is redundant, skip generating the method.
    if (!m)
      return;
    auto &body = m->body();
    // Operands, attributes, regions and successors are added the same way for
    // every flavor; only result-type handling differs below.
    genCodeForAddingArgAndRegionForBuilder(
        body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue);

    // Push all result types to the operation state

    if (inferType) {
      // Generate builder that infers type too.
      // TODO: Subsume this with general checking if type can be
      // inferred automatically.
      // TODO: Expand to handle regions.
      // The generated code calls the op's static inferReturnTypes with the
      // operands and attributes already added to the state and no regions,
      // and aborts via report_fatal_error if inference fails.
      body << formatv(R"(
::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
{1}.location, {1}.operands,
{1}.attributes.getDictionary({1}.getContext()),
/*regions=*/{{}, inferredReturnTypes)))
{1}.addTypes(inferredReturnTypes);
else
::llvm::report_fatal_error("Failed to infer result type(s).");)",
                      opClass.getClassName(), builderOpState);
      return;
    }

    switch (paramKind) {
    case TypeParamKind::None:
      return;
    case TypeParamKind::Separate:
      // One addTypes call per result-type parameter; optional results are
      // only added when the caller passed a non-null type.
      for (int i = 0, e = op.getNumResults(); i < e; ++i) {
        if (op.getResult(i).isOptional())
          body << " if (" << resultNames[i] << ")\n ";
        body << " " << builderOpState << ".addTypes(" << resultNames[i]
             << ");\n";
      }
      return;
    case TypeParamKind::Collective: {
      int numResults = op.getNumResults();
      int numVariadicResults = op.getNumVariableLengthResults();
      int numNonVariadicResults = numResults - numVariadicResults;
      bool hasVariadicResult = numVariadicResults != 0;

      // Avoid emitting "resultTypes.size() >= 0u" which is always true.
      if (!(hasVariadicResult && numNonVariadicResults == 0))
        body << " "
             << "assert(resultTypes.size() "
             << (hasVariadicResult ? ">=" : "==") << " "
             << numNonVariadicResults
             << "u && \"mismatched number of results\");\n";
      body << " " << builderOpState << ".addTypes(resultTypes);\n";
    }
      return;
    }
    llvm_unreachable("unhandled TypeParamKind");
  };

  // Some of the build methods generated here may be ambiguous, but TableGen's
  // ambiguous function detection will elide those ones.
  for (auto attrType : attrBuilderType) {
    emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
    if (canInferType(op))
      emit(attrType, TypeParamKind::None, /*inferType=*/true);
    emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
  }
}
983
genUseOperandAsResultTypeCollectiveParamBuilder()984 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
985 int numResults = op.getNumResults();
986
987 // Signature
988 llvm::SmallVector<OpMethodParameter, 4> paramList;
989 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
990 paramList.emplace_back("::mlir::OperationState &", builderOpState);
991 paramList.emplace_back("::mlir::ValueRange", "operands");
992 // Provide default value for `attributes` when its the last parameter
993 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
994 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
995 "attributes", attributesDefaultValue);
996 if (op.getNumVariadicRegions())
997 paramList.emplace_back("unsigned", "numRegions");
998
999 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1000 std::move(paramList));
1001 // If the builder is redundant, skip generating the method
1002 if (!m)
1003 return;
1004 auto &body = m->body();
1005
1006 // Operands
1007 body << " " << builderOpState << ".addOperands(operands);\n";
1008
1009 // Attributes
1010 body << " " << builderOpState << ".addAttributes(attributes);\n";
1011
1012 // Create the correct number of regions
1013 if (int numRegions = op.getNumRegions()) {
1014 body << llvm::formatv(
1015 " for (unsigned i = 0; i != {0}; ++i)\n",
1016 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1017 body << " (void)" << builderOpState << ".addRegion();\n";
1018 }
1019
1020 // Result types
1021 SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
1022 body << " " << builderOpState << ".addTypes({"
1023 << llvm::join(resultTypes, ", ") << "});\n\n";
1024 }
1025
genInferredTypeCollectiveParamBuilder()1026 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
1027 // TODO: Expand to support regions.
1028 SmallVector<OpMethodParameter, 4> paramList;
1029 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1030 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1031 paramList.emplace_back("::mlir::ValueRange", "operands");
1032 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1033 "attributes", "{}");
1034 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1035 std::move(paramList));
1036 // If the builder is redundant, skip generating the method
1037 if (!m)
1038 return;
1039 auto &body = m->body();
1040
1041 int numResults = op.getNumResults();
1042 int numVariadicResults = op.getNumVariableLengthResults();
1043 int numNonVariadicResults = numResults - numVariadicResults;
1044
1045 int numOperands = op.getNumOperands();
1046 int numVariadicOperands = op.getNumVariableLengthOperands();
1047 int numNonVariadicOperands = numOperands - numVariadicOperands;
1048
1049 // Operands
1050 if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1051 body << " assert(operands.size()"
1052 << (numVariadicOperands != 0 ? " >= " : " == ")
1053 << numNonVariadicOperands
1054 << "u && \"mismatched number of parameters\");\n";
1055 body << " " << builderOpState << ".addOperands(operands);\n";
1056 body << " " << builderOpState << ".addAttributes(attributes);\n";
1057
1058 // Create the correct number of regions
1059 if (int numRegions = op.getNumRegions()) {
1060 body << llvm::formatv(
1061 " for (unsigned i = 0; i != {0}; ++i)\n",
1062 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1063 body << " (void)" << builderOpState << ".addRegion();\n";
1064 }
1065
1066 // Result types
1067 body << formatv(R"(
1068 ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
1069 if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
1070 {1}.location, operands,
1071 {1}.attributes.getDictionary({1}.getContext()),
1072 /*regions=*/{{}, inferredReturnTypes))) {{)",
1073 opClass.getClassName(), builderOpState);
1074 if (numVariadicResults == 0 || numNonVariadicResults != 0)
1075 body << " assert(inferredReturnTypes.size()"
1076 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1077 << "u && \"mismatched number of return types\");\n";
1078 body << " " << builderOpState << ".addTypes(inferredReturnTypes);";
1079
1080 body << formatv(R"(
1081 } else
1082 ::llvm::report_fatal_error("Failed to infer result type(s).");)",
1083 opClass.getClassName(), builderOpState);
1084 }
1085
genUseOperandAsResultTypeSeparateParamBuilder()1086 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
1087 llvm::SmallVector<OpMethodParameter, 4> paramList;
1088 llvm::SmallVector<std::string, 4> resultNames;
1089 buildParamList(paramList, resultNames, TypeParamKind::None);
1090
1091 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1092 std::move(paramList));
1093 // If the builder is redundant, skip generating the method
1094 if (!m)
1095 return;
1096 auto &body = m->body();
1097 genCodeForAddingArgAndRegionForBuilder(body);
1098
1099 auto numResults = op.getNumResults();
1100 if (numResults == 0)
1101 return;
1102
1103 // Push all result types to the operation state
1104 const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
1105 std::string resultType =
1106 formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
1107 body << " " << builderOpState << ".addTypes({" << resultType;
1108 for (int i = 1; i != numResults; ++i)
1109 body << ", " << resultType;
1110 body << "});\n\n";
1111 }
1112
genUseAttrAsResultTypeBuilder()1113 void OpEmitter::genUseAttrAsResultTypeBuilder() {
1114 SmallVector<OpMethodParameter, 4> paramList;
1115 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1116 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1117 paramList.emplace_back("::mlir::ValueRange", "operands");
1118 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1119 "attributes", "{}");
1120 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1121 std::move(paramList));
1122 // If the builder is redundant, skip generating the method
1123 if (!m)
1124 return;
1125
1126 auto &body = m->body();
1127
1128 // Push all result types to the operation state
1129 std::string resultType;
1130 const auto &namedAttr = op.getAttribute(0);
1131
1132 body << " for (auto attr : attributes) {\n";
1133 body << " if (attr.first != \"" << namedAttr.name << "\") continue;\n";
1134 if (namedAttr.attr.isTypeAttr()) {
1135 resultType = "attr.second.cast<::mlir::TypeAttr>().getValue()";
1136 } else {
1137 resultType = "attr.second.getType()";
1138 }
1139
1140 // Operands
1141 body << " " << builderOpState << ".addOperands(operands);\n";
1142
1143 // Attributes
1144 body << " " << builderOpState << ".addAttributes(attributes);\n";
1145
1146 // Result types
1147 SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
1148 body << " " << builderOpState << ".addTypes({"
1149 << llvm::join(resultTypes, ", ") << "});\n";
1150 body << " }\n";
1151 }
1152
1153 /// Returns a signature of the builder as defined by a dag-typed initializer.
1154 /// Updates the context `fctx` to enable replacement of $_builder and $_state
1155 /// in the body. Reports errors at `loc`.
builderSignatureFromDAG(const DagInit * init,ArrayRef<llvm::SMLoc> loc,FmtContext & fctx)1156 static std::string builderSignatureFromDAG(const DagInit *init,
1157 ArrayRef<llvm::SMLoc> loc,
1158 FmtContext &fctx) {
1159 auto *defInit = dyn_cast<DefInit>(init->getOperator());
1160 if (!defInit || !defInit->getDef()->getName().equals("ins"))
1161 PrintFatalError(loc, "expected 'ins' in builders");
1162
1163 // Inject builder and state arguments.
1164 llvm::SmallVector<std::string, 8> arguments;
1165 arguments.reserve(init->getNumArgs() + 2);
1166 arguments.push_back(llvm::formatv("::mlir::OpBuilder &{0}", builder).str());
1167 arguments.push_back(
1168 llvm::formatv("::mlir::OperationState &{0}", builderOpState).str());
1169
1170 // Accept either a StringInit or a DefInit with two string values as dag
1171 // arguments. The former corresponds to the type, the latter to the type and
1172 // the default value. Similarly to C++, once an argument with a default value
1173 // is detected, the following arguments must have default values as well.
1174 bool seenDefaultValue = false;
1175 for (unsigned i = 0, e = init->getNumArgs(); i < e; ++i) {
1176 // If no name is provided, generate one.
1177 StringInit *argName = init->getArgName(i);
1178 std::string name =
1179 argName ? argName->getValue().str() : "odsArg" + std::to_string(i);
1180
1181 Init *argInit = init->getArg(i);
1182 StringRef type;
1183 std::string defaultValue;
1184 if (StringInit *strType = dyn_cast<StringInit>(argInit)) {
1185 type = strType->getValue();
1186 } else {
1187 const Record *typeAndDefaultValue = cast<DefInit>(argInit)->getDef();
1188 type = typeAndDefaultValue->getValueAsString("type");
1189 StringRef defaultValueRef =
1190 typeAndDefaultValue->getValueAsString("defaultValue");
1191 if (!defaultValueRef.empty()) {
1192 seenDefaultValue = true;
1193 defaultValue = llvm::formatv(" = {0}", defaultValueRef).str();
1194 }
1195 }
1196 if (seenDefaultValue && defaultValue.empty())
1197 PrintFatalError(loc,
1198 "expected an argument with default value after other "
1199 "arguments with default values");
1200 arguments.push_back(
1201 llvm::formatv("{0} {1}{2}", type, name, defaultValue).str());
1202 }
1203
1204 fctx.withBuilder(builder);
1205 fctx.addSubst("_state", builderOpState);
1206
1207 return llvm::join(arguments, ", ");
1208 }
1209
1210 // Returns a signature fo the builder as defined by a string initializer,
1211 // optionally injecting the builder and state arguments.
1212 // TODO: to be removed after the transition is complete.
builderSignatureFromString(StringRef params,FmtContext & fctx)1213 static std::string builderSignatureFromString(StringRef params,
1214 FmtContext &fctx) {
1215 bool skipParamGen = params.startswith("OpBuilder") ||
1216 params.startswith("mlir::OpBuilder") ||
1217 params.startswith("::mlir::OpBuilder");
1218 if (skipParamGen)
1219 return params.str();
1220
1221 fctx.withBuilder(builder);
1222 fctx.addSubst("_state", builderOpState);
1223 return std::string(llvm::formatv("::mlir::OpBuilder &{0}, "
1224 "::mlir::OperationState &{1}{2}{3}",
1225 builder, builderOpState,
1226 params.empty() ? "" : ", ", params));
1227 }
1228
genBuilder()1229 void OpEmitter::genBuilder() {
1230 // Handle custom builders if provided.
1231 // TODO: Create wrapper class for OpBuilder to hide the native
1232 // TableGen API calls here.
1233 {
1234 auto *listInit = dyn_cast_or_null<ListInit>(def.getValueInit("builders"));
1235 if (listInit) {
1236 for (Init *init : listInit->getValues()) {
1237 Record *builderDef = cast<DefInit>(init)->getDef();
1238 llvm::Optional<StringRef> params =
1239 builderDef->getValueAsOptionalString("params");
1240 FmtContext fctx;
1241 if (params.hasValue()) {
1242 PrintWarning(op.getLoc(),
1243 "Op uses a deprecated, string-based OpBuilder format; "
1244 "use OpBuilderDAG with '(ins <...>)' instead");
1245 }
1246 std::string paramStr =
1247 params.hasValue() ? builderSignatureFromString(params->trim(), fctx)
1248 : builderSignatureFromDAG(
1249 builderDef->getValueAsDag("dagParams"),
1250 op.getLoc(), fctx);
1251
1252 StringRef body = builderDef->getValueAsString("body");
1253 bool hasBody = !body.empty();
1254 OpMethod::Property properties =
1255 hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
1256 auto *method =
1257 opClass.addMethodAndPrune("void", "build", properties, paramStr);
1258 if (hasBody)
1259 method->body() << tgfmt(body, &fctx);
1260 }
1261 }
1262 if (op.skipDefaultBuilders()) {
1263 if (!listInit || listInit->empty())
1264 PrintFatalError(
1265 op.getLoc(),
1266 "default builders are skipped and no custom builders provided");
1267 return;
1268 }
1269 }
1270
1271 // Generate default builders that requires all result type, operands, and
1272 // attributes as parameters.
1273
1274 // We generate three classes of builders here:
1275 // 1. one having a stand-alone parameter for each operand / attribute, and
1276 genSeparateArgParamBuilder();
1277 // 2. one having an aggregated parameter for all result types / operands /
1278 // attributes, and
1279 genCollectiveParamBuilder();
1280 // 3. one having a stand-alone parameter for each operand and attribute,
1281 // use the first operand or attribute's type as all result types
1282 // to facilitate different call patterns.
1283 if (op.getNumVariableLengthResults() == 0) {
1284 if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
1285 genUseOperandAsResultTypeSeparateParamBuilder();
1286 genUseOperandAsResultTypeCollectiveParamBuilder();
1287 }
1288 if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
1289 genUseAttrAsResultTypeBuilder();
1290 }
1291 }
1292
genCollectiveParamBuilder()1293 void OpEmitter::genCollectiveParamBuilder() {
1294 int numResults = op.getNumResults();
1295 int numVariadicResults = op.getNumVariableLengthResults();
1296 int numNonVariadicResults = numResults - numVariadicResults;
1297
1298 int numOperands = op.getNumOperands();
1299 int numVariadicOperands = op.getNumVariableLengthOperands();
1300 int numNonVariadicOperands = numOperands - numVariadicOperands;
1301
1302 SmallVector<OpMethodParameter, 4> paramList;
1303 paramList.emplace_back("::mlir::OpBuilder &", "");
1304 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1305 paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1306 paramList.emplace_back("::mlir::ValueRange", "operands");
1307 // Provide default value for `attributes` when its the last parameter
1308 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1309 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1310 "attributes", attributesDefaultValue);
1311 if (op.getNumVariadicRegions())
1312 paramList.emplace_back("unsigned", "numRegions");
1313
1314 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1315 std::move(paramList));
1316 // If the builder is redundant, skip generating the method
1317 if (!m)
1318 return;
1319 auto &body = m->body();
1320
1321 // Operands
1322 if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1323 body << " assert(operands.size()"
1324 << (numVariadicOperands != 0 ? " >= " : " == ")
1325 << numNonVariadicOperands
1326 << "u && \"mismatched number of parameters\");\n";
1327 body << " " << builderOpState << ".addOperands(operands);\n";
1328
1329 // Attributes
1330 body << " " << builderOpState << ".addAttributes(attributes);\n";
1331
1332 // Create the correct number of regions
1333 if (int numRegions = op.getNumRegions()) {
1334 body << llvm::formatv(
1335 " for (unsigned i = 0; i != {0}; ++i)\n",
1336 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1337 body << " (void)" << builderOpState << ".addRegion();\n";
1338 }
1339
1340 // Result types
1341 if (numVariadicResults == 0 || numNonVariadicResults != 0)
1342 body << " assert(resultTypes.size()"
1343 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1344 << "u && \"mismatched number of return types\");\n";
1345 body << " " << builderOpState << ".addTypes(resultTypes);\n";
1346
1347 // Generate builder that infers type too.
1348 // TODO: Expand to handle regions and successors.
1349 if (canInferType(op) && op.getNumSuccessors() == 0)
1350 genInferredTypeCollectiveParamBuilder();
1351 }
1352
buildParamList(SmallVectorImpl<OpMethodParameter> & paramList,SmallVectorImpl<std::string> & resultTypeNames,TypeParamKind typeParamKind,AttrParamKind attrParamKind)1353 void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
1354 SmallVectorImpl<std::string> &resultTypeNames,
1355 TypeParamKind typeParamKind,
1356 AttrParamKind attrParamKind) {
1357 resultTypeNames.clear();
1358 auto numResults = op.getNumResults();
1359 resultTypeNames.reserve(numResults);
1360
1361 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1362 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1363
1364 switch (typeParamKind) {
1365 case TypeParamKind::None:
1366 break;
1367 case TypeParamKind::Separate: {
1368 // Add parameters for all return types
1369 for (int i = 0; i < numResults; ++i) {
1370 const auto &result = op.getResult(i);
1371 std::string resultName = std::string(result.name);
1372 if (resultName.empty())
1373 resultName = std::string(formatv("resultType{0}", i));
1374
1375 StringRef type =
1376 result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
1377 OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1378 if (result.isOptional())
1379 properties = OpMethodParameter::PP_Optional;
1380
1381 paramList.emplace_back(type, resultName, properties);
1382 resultTypeNames.emplace_back(std::move(resultName));
1383 }
1384 } break;
1385 case TypeParamKind::Collective: {
1386 paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1387 resultTypeNames.push_back("resultTypes");
1388 } break;
1389 }
1390
1391 // Add parameters for all arguments (operands and attributes).
1392
1393 int numOperands = 0;
1394 int numAttrs = 0;
1395
1396 int defaultValuedAttrStartIndex = op.getNumArgs();
1397 if (attrParamKind == AttrParamKind::UnwrappedValue) {
1398 // Calculate the start index from which we can attach default values in the
1399 // builder declaration.
1400 for (int i = op.getNumArgs() - 1; i >= 0; --i) {
1401 auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
1402 if (!namedAttr || !namedAttr->attr.hasDefaultValue())
1403 break;
1404
1405 if (!canUseUnwrappedRawValue(namedAttr->attr))
1406 break;
1407
1408 // Creating an APInt requires us to provide bitwidth, value, and
1409 // signedness, which is complicated compared to others. Similarly
1410 // for APFloat.
1411 // TODO: Adjust the 'returnType' field of such attributes
1412 // to support them.
1413 StringRef retType = namedAttr->attr.getReturnType();
1414 if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
1415 break;
1416
1417 defaultValuedAttrStartIndex = i;
1418 }
1419 }
1420
1421 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
1422 auto argument = op.getArg(i);
1423 if (argument.is<tblgen::NamedTypeConstraint *>()) {
1424 const auto &operand = op.getOperand(numOperands);
1425 StringRef type =
1426 operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value";
1427 OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1428 if (operand.isOptional())
1429 properties = OpMethodParameter::PP_Optional;
1430
1431 paramList.emplace_back(type, getArgumentName(op, numOperands),
1432 properties);
1433 ++numOperands;
1434 } else {
1435 const auto &namedAttr = op.getAttribute(numAttrs);
1436 const auto &attr = namedAttr.attr;
1437
1438 OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1439 if (attr.isOptional())
1440 properties = OpMethodParameter::PP_Optional;
1441
1442 StringRef type;
1443 switch (attrParamKind) {
1444 case AttrParamKind::WrappedAttr:
1445 type = attr.getStorageType();
1446 break;
1447 case AttrParamKind::UnwrappedValue:
1448 if (canUseUnwrappedRawValue(attr))
1449 type = attr.getReturnType();
1450 else
1451 type = attr.getStorageType();
1452 break;
1453 }
1454
1455 std::string defaultValue;
1456 // Attach default value if requested and possible.
1457 if (attrParamKind == AttrParamKind::UnwrappedValue &&
1458 i >= defaultValuedAttrStartIndex) {
1459 bool isString = attr.getReturnType() == "::llvm::StringRef";
1460 if (isString)
1461 defaultValue.append("\"");
1462 defaultValue += attr.getDefaultValue();
1463 if (isString)
1464 defaultValue.append("\"");
1465 }
1466 paramList.emplace_back(type, namedAttr.name, defaultValue, properties);
1467 ++numAttrs;
1468 }
1469 }
1470
1471 /// Insert parameters for each successor.
1472 for (const NamedSuccessor &succ : op.getSuccessors()) {
1473 StringRef type =
1474 succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
1475 paramList.emplace_back(type, succ.name);
1476 }
1477
1478 /// Insert parameters for variadic regions.
1479 for (const NamedRegion ®ion : op.getRegions())
1480 if (region.isVariadic())
1481 paramList.emplace_back("unsigned",
1482 llvm::formatv("{0}Count", region.name).str());
1483 }
1484
/// Emits into `body` the builder statements shared by the separate-parameter
/// builders. In order, it adds:
///  - each operand (optional operands only when non-null),
///  - the `operand_segment_sizes` attribute for AttrSizedOperandSegments ops,
///  - every non-derived attribute (guarded by a null check when optional),
///  - the op's regions (a variadic region repeated `<name>Count` times),
///  - the successors.
/// When `isRawValueAttr` is set, attributes with a usable unwrapped type are
/// built from the raw parameter via the attribute's constant builder template.
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
                                                       bool isRawValueAttr) {
  // Push all operands to the result.
  for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
    std::string argName = getArgumentName(op, i);
    if (op.getOperand(i).isOptional())
      body << " if (" << argName << ")\n ";
    body << " " << builderOpState << ".addOperands(" << argName << ");\n";
  }

  // If the operation has the operand segment size attribute, add it here.
  // Each segment is 0/1 for an optional operand, the range size for a
  // variadic one, and 1 otherwise.
  if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
    body << " " << builderOpState
         << ".addAttribute(\"operand_segment_sizes\", "
            "odsBuilder.getI32VectorAttr({";
    interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
      if (op.getOperand(i).isOptional())
        body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
      else if (op.getOperand(i).isVariadic())
        body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
      else
        body << "1";
    });
    body << "}));\n";
  }

  // Push all attributes to the result.
  for (const auto &namedAttr : op.getAttributes()) {
    auto &attr = namedAttr.attr;
    if (!attr.isDerivedAttr()) {
      bool emitNotNullCheck = attr.isOptional();
      if (emitNotNullCheck) {
        body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
      }
      if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
        // If this is a raw value, then we need to wrap it in an Attribute
        // instance.
        FmtContext fctx;
        fctx.withBuilder("odsBuilder");

        std::string builderTemplate =
            std::string(attr.getConstBuilderTemplate());

        // For StringAttr, its constant builder call will wrap the input in
        // quotes, which is correct for normal string literals, but incorrect
        // here given we use function arguments. So we need to strip the
        // wrapping quotes.
        if (StringRef(builderTemplate).contains("\"$0\""))
          builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");

        std::string value =
            std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
        body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
                        namedAttr.name, value);
      } else {
        // The parameter already holds an Attribute; add it as-is.
        body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
                        namedAttr.name);
      }
      if (emitNotNullCheck) {
        body << " }\n";
      }
    }
  }

  // Create the correct number of regions.
  for (const NamedRegion &region : op.getRegions()) {
    if (region.isVariadic())
      body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
                      region.name);

    body << " (void)" << builderOpState << ".addRegion();\n";
  }

  // Push all successors to the result.
  for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
    body << formatv(" {0}.addSuccessors({1});\n", builderOpState,
                    namedSuccessor.name);
  }
}
1564
genCanonicalizerDecls()1565 void OpEmitter::genCanonicalizerDecls() {
1566 if (!def.getValueAsBit("hasCanonicalizer"))
1567 return;
1568
1569 SmallVector<OpMethodParameter, 2> paramList;
1570 paramList.emplace_back("::mlir::OwningRewritePatternList &", "results");
1571 paramList.emplace_back("::mlir::MLIRContext *", "context");
1572 opClass.addMethodAndPrune("void", "getCanonicalizationPatterns",
1573 OpMethod::MP_StaticDeclaration,
1574 std::move(paramList));
1575 }
1576
genFolderDecls()1577 void OpEmitter::genFolderDecls() {
1578 bool hasSingleResult =
1579 op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
1580
1581 if (def.getValueAsBit("hasFolder")) {
1582 if (hasSingleResult) {
1583 opClass.addMethodAndPrune(
1584 "::mlir::OpFoldResult", "fold", OpMethod::MP_Declaration,
1585 "::llvm::ArrayRef<::mlir::Attribute>", "operands");
1586 } else {
1587 SmallVector<OpMethodParameter, 2> paramList;
1588 paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
1589 paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
1590 "results");
1591 opClass.addMethodAndPrune("::mlir::LogicalResult", "fold",
1592 OpMethod::MP_Declaration, std::move(paramList));
1593 }
1594 }
1595 }
1596
genOpInterfaceMethods(const tblgen::InterfaceOpTrait * opTrait)1597 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceOpTrait *opTrait) {
1598 auto interface = opTrait->getOpInterface();
1599
1600 // Get the set of methods that should always be declared.
1601 auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
1602 llvm::StringSet<> alwaysDeclaredMethods;
1603 alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
1604 alwaysDeclaredMethodsVec.end());
1605
1606 for (const InterfaceMethod &method : interface.getMethods()) {
1607 // Don't declare if the method has a body.
1608 if (method.getBody())
1609 continue;
1610 // Don't declare if the method has a default implementation and the op
1611 // didn't request that it always be declared.
1612 if (method.getDefaultImplementation() &&
1613 !alwaysDeclaredMethods.count(method.getName()))
1614 continue;
1615 genOpInterfaceMethod(method);
1616 }
1617 }
1618
genOpInterfaceMethod(const InterfaceMethod & method,bool declaration)1619 OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
1620 bool declaration) {
1621 SmallVector<OpMethodParameter, 4> paramList;
1622 for (const InterfaceMethod::Argument &arg : method.getArguments())
1623 paramList.emplace_back(arg.type, arg.name);
1624
1625 auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None;
1626 if (declaration)
1627 properties =
1628 static_cast<OpMethod::Property>(properties | OpMethod::MP_Declaration);
1629 return opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
1630 properties, std::move(paramList));
1631 }
1632
genOpInterfaceMethods()1633 void OpEmitter::genOpInterfaceMethods() {
1634 for (const auto &trait : op.getTraits()) {
1635 if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
1636 if (opTrait->shouldDeclareMethods())
1637 genOpInterfaceMethods(opTrait);
1638 }
1639 }
1640
/// Generates a `getEffects` method for each side-effect interface referenced
/// by the op. Effects are collected from three sources:
///   - side-effect traits (static effects with no attached value),
///   - decorators on operands and on symbol-reference attributes,
///   - decorators on results.
/// They are grouped by base effect name, and one `getEffects` method is added
/// per group.
void OpEmitter::genSideEffectInterfaceMethods() {
  enum EffectKind { Operand, Result, Symbol, Static };
  struct EffectLocation {
    /// The effect applied.
    SideEffect effect;

    /// The index if the kind is not static:
    ///   - the ODS operand index for operands,
    ///   - the ODS argument index for symbols (it is looked up with
    ///     op.getArg below),
    ///   - the ODS result index for results.
    unsigned index : 30;

    /// The kind of the location.
    unsigned kind : 2;
  };

  StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
  // Records every SideEffect decorator in `decorators` under its base effect
  // name, and adds the decorator's interface trait to the op class.
  auto resolveDecorators = [&](Operator::var_decorator_range decorators,
                               unsigned index, unsigned kind) {
    for (auto decorator : decorators)
      if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) {
        opClass.addTrait(effect->getInterfaceTrait());
        interfaceEffects[effect->getBaseEffectName()].push_back(
            EffectLocation{*effect, index, kind});
      }
  };

  // Collect effects that were specified via:
  // - Traits.
  for (const auto &trait : op.getTraits()) {
    const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
    if (!opTrait)
      continue;
    auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
    for (auto decorator : opTrait->getEffects())
      effects.push_back(EffectLocation{cast<SideEffect>(decorator),
                                       /*index=*/0, EffectKind::Static});
  }
  // - Attributes and Operands. `operandIt` counts only operand arguments,
  //   while `i` indexes all ODS arguments (operands and attributes).
  for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
    Argument arg = op.getArg(i);
    if (arg.is<NamedTypeConstraint *>()) {
      resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
      ++operandIt;
      continue;
    }
    // Only symbol-reference attributes can carry effect decorators here.
    const NamedAttribute *attr = arg.get<NamedAttribute *>();
    if (attr->attr.getBaseAttr().isSymbolRefAttr())
      resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
  }
  // - Results.
  for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
    resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);

  // The code used to add an effect instance.
  // {0}: The effect class.
  // {1}: Optional value or symbol reference (including a trailing ", ").
  // {2}: The resource class.
  const char *addEffectCode =
      " effects.emplace_back({0}::get(), {1}{2}::get());\n";

  for (auto &it : interfaceEffects) {
    // Generate the 'getEffects' method for this base effect.
    std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::"
                                     "SideEffects::EffectInstance<{0}>> &",
                                     it.first())
                           .str();
    auto *getEffects =
        opClass.addMethodAndPrune("void", "getEffects", type, "effects");
    auto &body = getEffects->body();

    // Add effect instances for each of the locations marked on the operation.
    for (auto &location : it.second) {
      StringRef effect = location.effect.getName();
      StringRef resource = location.effect.getResource();
      if (location.kind == EffectKind::Static) {
        // A static instance has no attached value.
        body << llvm::formatv(addEffectCode, effect, "", resource).str();
      } else if (location.kind == EffectKind::Symbol) {
        // A symbol reference requires adding the proper attribute. For an
        // optional attribute, the effect is only added when the attribute is
        // present.
        const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
        if (attr->attr.isOptional()) {
          body << " if (auto symbolRef = " << attr->name << "Attr())\n "
               << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
                      .str();
        } else {
          body << llvm::formatv(addEffectCode, effect, attr->name + "(), ",
                                resource)
                      .str();
        }
      } else {
        // Otherwise this is an operand/result, so we need to attach the Value.
        // One effect instance is emitted per value in the ODS value group.
        body << " for (::mlir::Value value : getODS"
             << (location.kind == EffectKind::Operand ? "Operands" : "Results")
             << "(" << location.index << "))\n "
             << llvm::formatv(addEffectCode, effect, "value, ", resource).str();
      }
    }
  }
}
1738
genTypeInterfaceMethods()1739 void OpEmitter::genTypeInterfaceMethods() {
1740 if (!op.allResultTypesKnown())
1741 return;
1742 // Generate 'inferReturnTypes' method declaration using the interface method
1743 // declared in 'InferTypeOpInterface' op interface.
1744 const auto *trait = dyn_cast<InterfaceOpTrait>(
1745 op.getTrait("::mlir::InferTypeOpInterface::Trait"));
1746 auto interface = trait->getOpInterface();
1747 OpMethod *method = [&]() -> OpMethod * {
1748 for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
1749 if (interfaceMethod.getName() == "inferReturnTypes") {
1750 return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
1751 }
1752 }
1753 assert(0 && "unable to find inferReturnTypes interface method");
1754 return nullptr;
1755 }();
1756 auto &body = method->body();
1757 body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
1758
1759 FmtContext fctx;
1760 fctx.withBuilder("odsBuilder");
1761 body << " ::mlir::Builder odsBuilder(context);\n";
1762
1763 auto emitType =
1764 [&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & {
1765 if (type.isArg()) {
1766 auto argIndex = type.getArg();
1767 assert(!op.getArg(argIndex).is<NamedAttribute *>());
1768 auto arg = op.getArgToOperandOrAttribute(argIndex);
1769 if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
1770 return body << "operands[" << arg.operandOrAttributeIndex()
1771 << "].getType()";
1772 return body << "attributes[" << arg.operandOrAttributeIndex()
1773 << "].getType()";
1774 } else {
1775 return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
1776 }
1777 };
1778
1779 for (int i = 0, e = op.getNumResults(); i != e; ++i) {
1780 body << " inferredReturnTypes[" << i << "] = ";
1781 auto types = op.getSameTypeAsResult(i);
1782 emitType(types[0]) << ";\n";
1783 if (types.size() == 1)
1784 continue;
1785 // TODO: We could verify equality here, but skipping that for verification.
1786 }
1787 body << " return ::mlir::success();";
1788 }
1789
genParser()1790 void OpEmitter::genParser() {
1791 if (!hasStringAttribute(def, "parser") ||
1792 hasStringAttribute(def, "assemblyFormat"))
1793 return;
1794
1795 SmallVector<OpMethodParameter, 2> paramList;
1796 paramList.emplace_back("::mlir::OpAsmParser &", "parser");
1797 paramList.emplace_back("::mlir::OperationState &", "result");
1798 auto *method =
1799 opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
1800 OpMethod::MP_Static, std::move(paramList));
1801
1802 FmtContext fctx;
1803 fctx.addSubst("cppClass", opClass.getClassName());
1804 auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
1805 method->body() << " " << tgfmt(parser, &fctx);
1806 }
1807
genPrinter()1808 void OpEmitter::genPrinter() {
1809 if (hasStringAttribute(def, "assemblyFormat"))
1810 return;
1811
1812 auto valueInit = def.getValueInit("printer");
1813 StringInit *stringInit = dyn_cast<StringInit>(valueInit);
1814 if (!stringInit)
1815 return;
1816
1817 auto *method =
1818 opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &", "p");
1819 FmtContext fctx;
1820 fctx.addSubst("cppClass", opClass.getClassName());
1821 auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
1822 method->body() << " " << tgfmt(printer, &fctx);
1823 }
1824
genVerifier()1825 void OpEmitter::genVerifier() {
1826 auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify");
1827 auto &body = method->body();
1828 body << " if (failed(" << op.getAdaptorName()
1829 << "(*this).verify(this->getLoc()))) "
1830 << "return ::mlir::failure();\n";
1831
1832 auto *valueInit = def.getValueInit("verifier");
1833 StringInit *stringInit = dyn_cast<StringInit>(valueInit);
1834 bool hasCustomVerify = stringInit && !stringInit->getValue().empty();
1835 populateSubstitutions(op, "this->getAttr", "this->getODSOperands",
1836 "this->getODSResults", verifyCtx);
1837
1838 genAttributeVerifier(op, "this->getAttr", "emitOpError(",
1839 /*emitVerificationRequiringOp=*/true, verifyCtx, body);
1840 genOperandResultVerifier(body, op.getOperands(), "operand");
1841 genOperandResultVerifier(body, op.getResults(), "result");
1842
1843 for (auto &trait : op.getTraits()) {
1844 if (auto *t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
1845 body << tgfmt(" if (!($0))\n "
1846 "return emitOpError(\"failed to verify that $1\");\n",
1847 &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
1848 t->getDescription());
1849 }
1850 }
1851
1852 genRegionVerifier(body);
1853 genSuccessorVerifier(body);
1854
1855 if (hasCustomVerify) {
1856 FmtContext fctx;
1857 fctx.addSubst("cppClass", opClass.getClassName());
1858 auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
1859 body << " " << tgfmt(printer, &fctx);
1860 } else {
1861 body << " return ::mlir::success();\n";
1862 }
1863 }
1864
/// Emits verification of the ODS type constraints on `values` (the op's
/// operands or results, as named by `valueKind`) into `body`. Groups that are
/// neither optional nor have a predicate are skipped entirely. For the rest,
/// the generated code fetches the group with getODS<Kind>s(i). Optional groups
/// are checked to contain at most one value, and groups with a predicate have
/// every value's type checked against the constraint.
void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
                                         Operator::value_range values,
                                         StringRef valueKind) {
  FmtContext fctx;

  body << " {\n";
  // `index` is the running value position used in the emitted diagnostics.
  // NOTE(review): it is only incremented inside the per-value loop emitted
  // for groups with a predicate. Groups skipped below therefore do not advance
  // it, and the reported position may be off when such a group precedes a
  // checked one. Confirm whether this matters.
  body << " unsigned index = 0; (void)index;\n";

  for (auto staticValue : llvm::enumerate(values)) {
    bool hasPredicate = staticValue.value().hasPredicate();
    bool isOptional = staticValue.value().isOptional();
    if (!hasPredicate && !isOptional)
      continue;
    body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n",
                    // Capitalize the first letter to match the function name
                    valueKind.substr(0, 1).upper(), valueKind.substr(1),
                    staticValue.index());

    // If the constraint is optional check that the value group has at most 1
    // value.
    if (isOptional) {
      body << formatv(" if (valueGroup{0}.size() > 1)\n"
                      " return emitOpError(\"{1} group starting at #\") "
                      "<< index << \" requires 0 or 1 element, but found \" << "
                      "valueGroup{0}.size();\n",
                      staticValue.index(), valueKind);
    }

    // Otherwise, if there is no predicate there is nothing left to do.
    if (!hasPredicate)
      continue;

    // Emit a loop to check all the dynamic values in the pack.
    body << " for (::mlir::Value v : valueGroup" << staticValue.index()
         << ") {\n";

    // The constraint's condition template is instantiated with the value's
    // type as the self placeholder.
    auto constraint = staticValue.value().constraint;
    body << " (void)v;\n"
         << " if (!("
         << tgfmt(constraint.getConditionTemplate(),
                  &fctx.withSelf("v.getType()"))
         << ")) {\n"
         << formatv(" return emitOpError(\"{0} #\") << index "
                    "<< \" must be {1}, but got \" << v.getType();\n",
                    valueKind, constraint.getDescription())
         << " }\n" // if
         << " ++index;\n"
         << " }\n"; // for
  }

  body << " }\n";
}
1917
genRegionVerifier(OpMethodBody & body)1918 void OpEmitter::genRegionVerifier(OpMethodBody &body) {
1919 // If we have no regions, there is nothing more to do.
1920 unsigned numRegions = op.getNumRegions();
1921 if (numRegions == 0)
1922 return;
1923
1924 body << "{\n";
1925 body << " unsigned index = 0; (void)index;\n";
1926
1927 for (unsigned i = 0; i < numRegions; ++i) {
1928 const auto ®ion = op.getRegion(i);
1929 if (region.constraint.getPredicate().isNull())
1930 continue;
1931
1932 body << " for (::mlir::Region ®ion : ";
1933 body << formatv(region.isVariadic()
1934 ? "{0}()"
1935 : "::mlir::MutableArrayRef<::mlir::Region>(this->"
1936 "getOperation()->getRegion({1}))",
1937 region.name, i);
1938 body << ") {\n";
1939 auto constraint = tgfmt(region.constraint.getConditionTemplate(),
1940 &verifyCtx.withSelf("region"))
1941 .str();
1942
1943 body << formatv(" (void)region;\n"
1944 " if (!({0})) {\n "
1945 "return emitOpError(\"region #\") << index << \" {1}"
1946 "failed to "
1947 "verify constraint: {2}\";\n }\n",
1948 constraint,
1949 region.name.empty() ? "" : "('" + region.name + "') ",
1950 region.constraint.getDescription())
1951 << " ++index;\n"
1952 << " }\n";
1953 }
1954 body << " }\n";
1955 }
1956
genSuccessorVerifier(OpMethodBody & body)1957 void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
1958 // If we have no successors, there is nothing more to do.
1959 unsigned numSuccessors = op.getNumSuccessors();
1960 if (numSuccessors == 0)
1961 return;
1962
1963 body << "{\n";
1964 body << " unsigned index = 0; (void)index;\n";
1965
1966 for (unsigned i = 0; i < numSuccessors; ++i) {
1967 const auto &successor = op.getSuccessor(i);
1968 if (successor.constraint.getPredicate().isNull())
1969 continue;
1970
1971 if (successor.isVariadic()) {
1972 body << formatv(" for (::mlir::Block *successor : {0}()) {\n",
1973 successor.name);
1974 } else {
1975 body << " {\n";
1976 body << formatv(" ::mlir::Block *successor = {0}();\n",
1977 successor.name);
1978 }
1979 auto constraint = tgfmt(successor.constraint.getConditionTemplate(),
1980 &verifyCtx.withSelf("successor"))
1981 .str();
1982
1983 body << formatv(" (void)successor;\n"
1984 " if (!({0})) {\n "
1985 "return emitOpError(\"successor #\") << index << \"('{1}') "
1986 "failed to "
1987 "verify constraint: {2}\";\n }\n",
1988 constraint, successor.name,
1989 successor.constraint.getDescription())
1990 << " ++index;\n"
1991 << " }\n";
1992 }
1993 body << " }\n";
1994 }
1995
1996 /// Add a size count trait to the given operation class.
addSizeCountTrait(OpClass & opClass,StringRef traitKind,int numTotal,int numVariadic)1997 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
1998 int numTotal, int numVariadic) {
1999 if (numVariadic != 0) {
2000 if (numTotal == numVariadic)
2001 opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s");
2002 else
2003 opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
2004 Twine(numTotal - numVariadic) + ">::Impl");
2005 return;
2006 }
2007 switch (numTotal) {
2008 case 0:
2009 opClass.addTrait("::mlir::OpTrait::Zero" + traitKind);
2010 break;
2011 case 1:
2012 opClass.addTrait("::mlir::OpTrait::One" + traitKind);
2013 break;
2014 default:
2015 opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
2016 ">::Impl");
2017 break;
2018 }
2019 }
2020
genTraits()2021 void OpEmitter::genTraits() {
2022 // Add region size trait.
2023 unsigned numRegions = op.getNumRegions();
2024 unsigned numVariadicRegions = op.getNumVariadicRegions();
2025 addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
2026
2027 // Add result size trait.
2028 int numResults = op.getNumResults();
2029 int numVariadicResults = op.getNumVariableLengthResults();
2030 addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
2031
2032 // Add successor size trait.
2033 unsigned numSuccessors = op.getNumSuccessors();
2034 unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
2035 addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
2036
2037 // Add variadic size trait and normal op traits.
2038 int numOperands = op.getNumOperands();
2039 int numVariadicOperands = op.getNumVariableLengthOperands();
2040
2041 // Add operand size trait.
2042 if (numVariadicOperands != 0) {
2043 if (numOperands == numVariadicOperands)
2044 opClass.addTrait("::mlir::OpTrait::VariadicOperands");
2045 else
2046 opClass.addTrait("::mlir::OpTrait::AtLeastNOperands<" +
2047 Twine(numOperands - numVariadicOperands) + ">::Impl");
2048 } else {
2049 switch (numOperands) {
2050 case 0:
2051 opClass.addTrait("::mlir::OpTrait::ZeroOperands");
2052 break;
2053 case 1:
2054 opClass.addTrait("::mlir::OpTrait::OneOperand");
2055 break;
2056 default:
2057 opClass.addTrait("::mlir::OpTrait::NOperands<" + Twine(numOperands) +
2058 ">::Impl");
2059 break;
2060 }
2061 }
2062
2063 // Add the native and interface traits.
2064 for (const auto &trait : op.getTraits()) {
2065 if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
2066 opClass.addTrait(opTrait->getTrait());
2067 else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
2068 opClass.addTrait(opTrait->getTrait());
2069 }
2070 }
2071
genOpNameGetter()2072 void OpEmitter::genOpNameGetter() {
2073 auto *method = opClass.addMethodAndPrune(
2074 "::llvm::StringRef", "getOperationName", OpMethod::MP_Static);
2075 method->body() << " return \"" << op.getOperationName() << "\";\n";
2076 }
2077
genOpAsmInterface()2078 void OpEmitter::genOpAsmInterface() {
2079 // If the user only has one results or specifically added the Asm trait,
2080 // then don't generate it for them. We specifically only handle multi result
2081 // operations, because the name of a single result in the common case is not
2082 // interesting(generally 'result'/'output'/etc.).
2083 // TODO: We could also add a flag to allow operations to opt in to this
2084 // generation, even if they only have a single operation.
2085 int numResults = op.getNumResults();
2086 if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
2087 return;
2088
2089 SmallVector<StringRef, 4> resultNames(numResults);
2090 for (int i = 0; i != numResults; ++i)
2091 resultNames[i] = op.getResultName(i);
2092
2093 // Don't add the trait if none of the results have a valid name.
2094 if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
2095 return;
2096 opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
2097
2098 // Generate the right accessor for the number of results.
2099 auto *method = opClass.addMethodAndPrune(
2100 "void", "getAsmResultNames", "::mlir::OpAsmSetValueNameFn", "setNameFn");
2101 auto &body = method->body();
2102 for (int i = 0; i != numResults; ++i) {
2103 body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n"
2104 << " if (!llvm::empty(resultGroup" << i << "))\n"
2105 << " setNameFn(*resultGroup" << i << ".begin(), \""
2106 << resultNames[i] << "\");\n";
2107 }
2108 }
2109
2110 //===----------------------------------------------------------------------===//
2111 // OpOperandAdaptor emitter
2112 //===----------------------------------------------------------------------===//
2113
namespace {
// Helper class to emit Op operand adaptors to an output stream. An adaptor
// holds a `::mlir::ValueRange` of operands (`odsOperands`) and a
// `::mlir::DictionaryAttr` of attributes (`odsAttrs`). It provides named
// getters for operands and for non-derived attributes that mirror those on the
// Op, plus a `verify` method for the op-independent attribute constraints.
class OpOperandAdaptorEmitter {
public:
  // Writes the C++ declaration of `op`'s adaptor class to `os`.
  static void emitDecl(const Operator &op, raw_ostream &os);
  // Writes the C++ definitions of `op`'s adaptor class members to `os`.
  static void emitDef(const Operator &op, raw_ostream &os);

private:
  // Builds the in-memory adaptor class (fields, constructors, getters and
  // verifier) for `op`.
  explicit OpOperandAdaptorEmitter(const Operator &op);

  // Add verification function. This generates a verify method for the adaptor
  // which verifies all the op-independent attribute constraints.
  void addVerification();

  // The op the adaptor is generated for.
  const Operator &op;
  // Model of the adaptor class being generated.
  Class adaptor;
};
} // end namespace
2134
// Populates `adaptor` for `op`, in this order:
//   - storage fields;
//   - two constructors (from values + attributes, and from an op instance);
//   - named operand getters;
//   - getters for every non-derived attribute;
//   - the `verify` method.
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
    : op(op), adaptor(op.getAdaptorName()) {
  adaptor.newField("::mlir::ValueRange", "odsOperands");
  adaptor.newField("::mlir::DictionaryAttr", "odsAttrs");
  const auto *attrSizedOperands =
      op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
  {
    // Constructor taking explicit operands and attributes. The attribute
    // dictionary has no default value when the op has
    // AttrSizedOperandSegments (presumably because the operand getters read
    // the segment sizes from it); otherwise it defaults to nullptr.
    SmallVector<OpMethodParameter, 2> paramList;
    paramList.emplace_back("::mlir::ValueRange", "values");
    paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
                           attrSizedOperands ? "" : "nullptr");
    auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList));

    constructor->addMemberInitializer("odsOperands", "values");
    constructor->addMemberInitializer("odsAttrs", "attrs");
  }

  {
    // Constructor taking an instance of the op class; operands and the
    // attribute dictionary are read from it.
    auto *constructor = adaptor.addConstructorAndPrune(
        llvm::formatv("{0}&", op.getCppClassName()).str(), "op");
    constructor->addMemberInitializer("odsOperands", "op->getOperands()");
    constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
  }

  // Named operand getters over `odsOperands`, using the
  // operand_segment_sizes attribute where segment sizes are needed.
  std::string sizeAttrInit =
      formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
  generateNamedOperandGetters(op, adaptor, sizeAttrInit,
                              /*rangeType=*/"::mlir::ValueRange",
                              /*rangeBeginCall=*/"odsOperands.begin()",
                              /*rangeSizeCall=*/"odsOperands.size()",
                              /*getOperandCallPattern=*/"odsOperands[{0}]");

  FmtContext fctx;
  fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");

  // Emits a getter named `name` returning the attribute's storage type. The
  // lookup uses dyn_cast_or_null for optional or defaulted attributes and cast
  // otherwise. Defaulted attributes fall back to their default value when
  // absent.
  auto emitAttr = [&](StringRef name, Attribute attr) {
    auto &body = adaptor.addMethodAndPrune(attr.getStorageType(), name)->body();
    body << " assert(odsAttrs && \"no attributes when constructing adapter\");"
         << "\n " << attr.getStorageType() << " attr = "
         << "odsAttrs.get(\"" << name << "\").";
    if (attr.hasDefaultValue() || attr.isOptional())
      body << "dyn_cast_or_null<";
    else
      body << "cast<";
    body << attr.getStorageType() << ">();\n";

    if (attr.hasDefaultValue()) {
      // Use the default value if attribute is not set.
      // TODO: this is inefficient, we are recreating the attribute for every
      // call. This should be set instead.
      std::string defaultValue = std::string(
          tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
      body << " if (!attr)\n attr = " << defaultValue << ";\n";
    }
    body << " return attr;\n";
  };

  // Derived attributes get no adaptor getter (presumably because they are
  // computed rather than stored in `odsAttrs`).
  for (auto &namedAttr : op.getAttributes()) {
    const auto &name = namedAttr.name;
    const auto &attr = namedAttr.attr;
    if (!attr.isDerivedAttr())
      emitAttr(name, attr);
  }

  // Add verification function.
  addVerification();
}
2202
addVerification()2203 void OpOperandAdaptorEmitter::addVerification() {
2204 auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify",
2205 "::mlir::Location", "loc");
2206 auto &body = method->body();
2207
2208 const char *checkAttrSizedValueSegmentsCode = R"(
2209 {
2210 auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
2211 auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
2212 if (numElements != {1})
2213 return emitError(loc, "'{0}' attribute for specifying {2} segments "
2214 "must have {1} elements");
2215 }
2216 )";
2217
2218 // Verify a few traits first so that we can use
2219 // getODSOperands()/getODSResults() in the rest of the verifier.
2220 for (auto &trait : op.getTraits()) {
2221 if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
2222 if (t->getTrait() == "::mlir::OpTrait::AttrSizedOperandSegments") {
2223 body << formatv(checkAttrSizedValueSegmentsCode,
2224 "operand_segment_sizes", op.getNumOperands(),
2225 "operand");
2226 } else if (t->getTrait() == "::mlir::OpTrait::AttrSizedResultSegments") {
2227 body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
2228 op.getNumResults(), "result");
2229 }
2230 }
2231 }
2232
2233 FmtContext verifyCtx;
2234 populateSubstitutions(op, "odsAttrs.get", "getODSOperands",
2235 "<no results should be genarated>", verifyCtx);
2236 genAttributeVerifier(op, "odsAttrs.get",
2237 Twine("emitError(loc, \"'") + op.getOperationName() +
2238 "' op \"",
2239 /*emitVerificationRequiringOp*/ false, verifyCtx, body);
2240
2241 body << " return ::mlir::success();";
2242 }
2243
emitDecl(const Operator & op,raw_ostream & os)2244 void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
2245 OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os);
2246 }
2247
emitDef(const Operator & op,raw_ostream & os)2248 void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
2249 OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os);
2250 }
2251
2252 // Emits the opcode enum and op classes.
emitOpClasses(const std::vector<Record * > & defs,raw_ostream & os,bool emitDecl)2253 static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
2254 bool emitDecl) {
2255 // First emit forward declaration for each class, this allows them to refer
2256 // to each others in traits for example.
2257 if (emitDecl) {
2258 os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
2259 os << "#undef GET_OP_FWD_DEFINES\n";
2260 for (auto *def : defs) {
2261 Operator op(*def);
2262 NamespaceEmitter emitter(os, op.getDialect());
2263 os << "class " << op.getCppClassName() << ";\n";
2264 }
2265 os << "#endif\n\n";
2266 }
2267
2268 IfDefScope scope("GET_OP_CLASSES", os);
2269 for (auto *def : defs) {
2270 Operator op(*def);
2271 NamespaceEmitter emitter(os, op.getDialect());
2272 if (emitDecl) {
2273 os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
2274 OpOperandAdaptorEmitter::emitDecl(op, os);
2275 OpEmitter::emitDecl(op, os);
2276 } else {
2277 os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
2278 OpOperandAdaptorEmitter::emitDef(op, os);
2279 OpEmitter::emitDef(op, os);
2280 }
2281 }
2282 }
2283
2284 // Emits a comma-separated list of the ops.
emitOpList(const std::vector<Record * > & defs,raw_ostream & os)2285 static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
2286 IfDefScope scope("GET_OP_LIST", os);
2287
2288 interleave(
2289 // TODO: We are constructing the Operator wrapper instance just for
2290 // getting it's qualified class name here. Reduce the overhead by having a
2291 // lightweight version of Operator class just for that purpose.
2292 defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
2293 [&os]() { os << ",\n"; });
2294 }
2295
getOperationName(const Record & def)2296 static std::string getOperationName(const Record &def) {
2297 auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
2298 auto opName = def.getValueAsString("opName");
2299 if (prefix.empty())
2300 return std::string(opName);
2301 return std::string(llvm::formatv("{0}.{1}", prefix, opName));
2302 }
2303
2304 static std::vector<Record *>
getAllDerivedDefinitions(const RecordKeeper & recordKeeper,StringRef className)2305 getAllDerivedDefinitions(const RecordKeeper &recordKeeper,
2306 StringRef className) {
2307 Record *classDef = recordKeeper.getClass(className);
2308 if (!classDef)
2309 PrintFatalError("ERROR: Couldn't find the `" + className + "' class!\n");
2310
2311 llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
2312 std::vector<Record *> defs;
2313 for (const auto &def : recordKeeper.getDefs()) {
2314 if (!def.second->isSubClassOf(classDef))
2315 continue;
2316 // Include if no include filter or include filter matches.
2317 if (!opIncFilter.empty() &&
2318 !includeRegex.match(getOperationName(*def.second)))
2319 continue;
2320 // Unless there is an exclude filter and it matches.
2321 if (!opExcFilter.empty() &&
2322 excludeRegex.match(getOperationName(*def.second)))
2323 continue;
2324 defs.push_back(def.second.get());
2325 }
2326
2327 return defs;
2328 }
2329
emitOpDecls(const RecordKeeper & recordKeeper,raw_ostream & os)2330 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
2331 emitSourceFileHeader("Op Declarations", os);
2332
2333 const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
2334 emitOpClasses(defs, os, /*emitDecl=*/true);
2335
2336 return false;
2337 }
2338
emitOpDefs(const RecordKeeper & recordKeeper,raw_ostream & os)2339 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
2340 emitSourceFileHeader("Op Definitions", os);
2341
2342 const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
2343 emitOpList(defs, os);
2344 emitOpClasses(defs, os, /*emitDecl=*/false);
2345
2346 return false;
2347 }
2348
2349 static mlir::GenRegistration
2350 genOpDecls("gen-op-decls", "Generate op declarations",
__anonec93f51b1402(const RecordKeeper &records, raw_ostream &os) 2351 [](const RecordKeeper &records, raw_ostream &os) {
2352 return emitOpDecls(records, os);
2353 });
2354
2355 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
2356 [](const RecordKeeper &records,
__anonec93f51b1502(const RecordKeeper &records, raw_ostream &os) 2357 raw_ostream &os) {
2358 return emitOpDefs(records, os);
2359 });
2360