• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "OpFormatGen.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Interfaces.h"
19 #include "mlir/TableGen/OpClass.h"
20 #include "mlir/TableGen/OpTrait.h"
21 #include "mlir/TableGen/Operator.h"
22 #include "mlir/TableGen/SideEffects.h"
23 #include "llvm/ADT/Sequence.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Regex.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Record.h"
30 #include "llvm/TableGen/TableGenBackend.h"
31 
32 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
33 
34 using namespace llvm;
35 using namespace mlir;
36 using namespace mlir::tblgen;
37 
38 cl::OptionCategory opDefGenCat("Options for -gen-op-defs and -gen-op-decls");
39 
40 static cl::opt<std::string> opIncFilter(
41     "op-include-regex",
42     cl::desc("Regex of name of op's to include (no filter if empty)"),
43     cl::cat(opDefGenCat));
44 static cl::opt<std::string> opExcFilter(
45     "op-exclude-regex",
46     cl::desc("Regex of name of op's to exclude (no filter if empty)"),
47     cl::cat(opDefGenCat));
48 
// Prefix given to generated local variables so they do not hide accessors.
static constexpr const char *tblgenNamePrefix = "tblgen_";
// Stem used when generating a name for an unnamed op argument.
static constexpr const char *generatedArgName = "odsArg";
// Identifier strings used by the generated build() methods.
static constexpr const char *builder = "odsBuilder";
static constexpr const char *builderOpState = "odsState";
53 
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic is not for
// general use; it assumes all variadic operands/results must have the same
// number of values.
//
// This is a formatv template, so literal braces in the emitted code are
// written doubled (`{{`).
//
// {0}: The list of whether each declared operand/result is variadic.
// {1}: The total number of non-variadic operands/results.
// {2}: The total number of variadic operands/results.
// {3}: The total number of actual values.
// {4}: "operand" or "result".
const char *sameVariadicSizeValueRangeCalcCode = R"(
  bool isVariadic[] = {{{0}};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic {4} corresponds to.
  // This assumes all static variadic {4}s have the same dynamic value count.
  int variadicSize = ({3} - {1}) / {2};
  // `index` passed in as the parameter is the static index which counts each
  // {4} (variadic or not) as size 1. So here for each previous static variadic
  // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static {4} starts.
  int start = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;
  return {{start, size};
)";
81 
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic assumes the op
// has an attribute specifying the size of each operand/result segment
// (variadic or not).
//
// The two `...SegmentSizeAttrInitCode` templates emit code defining a local
// `sizeAttr`; `attrSizedSegmentValueRangeCalcCode` is emitted right after one
// of them and reads that local.
//
// {0}: The name of the attribute specifying the segment sizes.
const char *adapterSegmentSizeAttrInitCode = R"(
  assert(odsAttrs && "missing segment size attribute for op");
  auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
)";
const char *opSegmentSizeAttrInitCode = R"(
  auto sizeAttr =
      getOperation()->getAttrOfType<::mlir::DenseIntElementsAttr>("{0}");
)";
// Not a formatv template: it is streamed verbatim, hence the single braces.
const char *attrSizedSegmentValueRangeCalcCode = R"(
  unsigned start = 0;
  for (unsigned i = 0; i < index; ++i)
    start += (*(sizeAttr.begin() + i)).getZExtValue();
  unsigned size = (*(sizeAttr.begin() + index)).getZExtValue();
  return {start, size};
)";
103 
// The logic to build a range of either operand or result values.
//
// This is a formatv template, so the literal opening brace of the returned
// initializer list is written doubled (`{{`).
//
// {0}: The begin iterator of the actual values.
// {1}: The call to generate the start and length of the value range.
const char *valueRangeReturnCode = R"(
  auto valueRange = {1};
  return {{std::next({0}, valueRange.first),
           std::next({0}, valueRange.first + valueRange.second)};
)";
113 
// Banner comment template emitted into generated code.
//
// {0}, {1}: The two pieces of text placed, space-separated, on the banner's
// title line.
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";
120 
121 //===----------------------------------------------------------------------===//
122 // Utility structs and functions
123 //===----------------------------------------------------------------------===//
124 
// Replaces all occurrences of `match` in `str` with `substitute` and returns
// the resulting string.
//
// Occurrences are located left to right, and scanning resumes just past each
// inserted `substitute`, so text introduced by a replacement is never matched
// again. An empty `match` leaves `str` unchanged.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  // An empty pattern is found at every position, so without this guard the
  // loop below would never terminate.
  if (match.empty())
    return str;

  std::string::size_type scanLoc = 0;
  std::string::size_type matchLoc = std::string::npos;
  while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) {
    str.replace(matchLoc, match.size(), substitute);
    // Skip over the text just inserted.
    scanLoc = matchLoc + substitute.size();
  }
  return str;
}
135 
136 // Returns whether the record has a value of the given name that can be returned
137 // via getValueAsString.
hasStringAttribute(const Record & record,StringRef fieldName)138 static inline bool hasStringAttribute(const Record &record,
139                                       StringRef fieldName) {
140   auto valueInit = record.getValueInit(fieldName);
141   return isa<StringInit>(valueInit);
142 }
143 
getArgumentName(const Operator & op,int index)144 static std::string getArgumentName(const Operator &op, int index) {
145   const auto &operand = op.getOperand(index);
146   if (!operand.name.empty())
147     return std::string(operand.name);
148   else
149     return std::string(formatv("{0}_{1}", generatedArgName, index));
150 }
151 
152 // Returns true if we can use unwrapped value for the given `attr` in builders.
canUseUnwrappedRawValue(const tblgen::Attribute & attr)153 static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
154   return attr.getReturnType() != attr.getStorageType() &&
155          // We need to wrap the raw value into an attribute in the builder impl
156          // so we need to make sure that the attribute specifies how to do that.
157          !attr.getConstBuilderTemplate().empty();
158 }
159 
160 //===----------------------------------------------------------------------===//
161 // Op emitter
162 //===----------------------------------------------------------------------===//
163 
164 namespace {
165 // Helper class to emit a record into the given output stream.
166 class OpEmitter {
167 public:
168   static void emitDecl(const Operator &op, raw_ostream &os);
169   static void emitDef(const Operator &op, raw_ostream &os);
170 
171 private:
172   OpEmitter(const Operator &op);
173 
174   void emitDecl(raw_ostream &os);
175   void emitDef(raw_ostream &os);
176 
177   // Generates the OpAsmOpInterface for this operation if possible.
178   void genOpAsmInterface();
179 
180   // Generates the `getOperationName` method for this op.
181   void genOpNameGetter();
182 
183   // Generates getters for the attributes.
184   void genAttrGetters();
185 
186   // Generates setter for the attributes.
187   void genAttrSetters();
188 
189   // Generates getters for named operands.
190   void genNamedOperandGetters();
191 
192   // Generates setters for named operands.
193   void genNamedOperandSetters();
194 
195   // Generates getters for named results.
196   void genNamedResultGetters();
197 
198   // Generates getters for named regions.
199   void genNamedRegionGetters();
200 
201   // Generates getters for named successors.
202   void genNamedSuccessorGetters();
203 
204   // Generates builder methods for the operation.
205   void genBuilder();
206 
207   // Generates the build() method that takes each operand/attribute
208   // as a stand-alone parameter.
209   void genSeparateArgParamBuilder();
210 
211   // Generates the build() method that takes each operand/attribute as a
212   // stand-alone parameter. The generated build() method uses first operand's
213   // type as all results' types.
214   void genUseOperandAsResultTypeSeparateParamBuilder();
215 
216   // Generates the build() method that takes all operands/attributes
217   // collectively as one parameter. The generated build() method uses first
218   // operand's type as all results' types.
219   void genUseOperandAsResultTypeCollectiveParamBuilder();
220 
221   // Generates the build() method that takes aggregate operands/attributes
222   // parameters. This build() method uses inferred types as result types.
223   // Requires: The type needs to be inferable via InferTypeOpInterface.
224   void genInferredTypeCollectiveParamBuilder();
225 
226   // Generates the build() method that takes each operand/attribute as a
227   // stand-alone parameter. The generated build() method uses first attribute's
228   // type as all result's types.
229   void genUseAttrAsResultTypeBuilder();
230 
231   // Generates the build() method that takes all result types collectively as
232   // one parameter. Similarly for operands and attributes.
233   void genCollectiveParamBuilder();
234 
235   // The kind of parameter to generate for result types in builders.
236   enum class TypeParamKind {
237     None,       // No result type in parameter list.
238     Separate,   // A separate parameter for each result type.
239     Collective, // An ArrayRef<Type> for all result types.
240   };
241 
242   // The kind of parameter to generate for attributes in builders.
243   enum class AttrParamKind {
244     WrappedAttr,    // A wrapped MLIR Attribute instance.
245     UnwrappedValue, // A raw value without MLIR Attribute wrapper.
246   };
247 
248   // Builds the parameter list for build() method of this op. This method writes
249   // to `paramList` the comma-separated parameter list and updates
250   // `resultTypeNames` with the names for parameters for specifying result
251   // types. The given `typeParamKind` and `attrParamKind` controls how result
252   // types and attributes are placed in the parameter list.
253   void buildParamList(llvm::SmallVectorImpl<OpMethodParameter> &paramList,
254                       SmallVectorImpl<std::string> &resultTypeNames,
255                       TypeParamKind typeParamKind,
256                       AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
257 
258   // Adds op arguments and regions into operation state for build() methods.
259   void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
260                                               bool isRawValueAttr = false);
261 
262   // Generates canonicalizer declaration for the operation.
263   void genCanonicalizerDecls();
264 
265   // Generates the folder declaration for the operation.
266   void genFolderDecls();
267 
268   // Generates the parser for the operation.
269   void genParser();
270 
271   // Generates the printer for the operation.
272   void genPrinter();
273 
274   // Generates verify method for the operation.
275   void genVerifier();
276 
277   // Generates verify statements for operands and results in the operation.
278   // The generated code will be attached to `body`.
279   void genOperandResultVerifier(OpMethodBody &body,
280                                 Operator::value_range values,
281                                 StringRef valueKind);
282 
283   // Generates verify statements for regions in the operation.
284   // The generated code will be attached to `body`.
285   void genRegionVerifier(OpMethodBody &body);
286 
287   // Generates verify statements for successors in the operation.
288   // The generated code will be attached to `body`.
289   void genSuccessorVerifier(OpMethodBody &body);
290 
291   // Generates the traits used by the object.
292   void genTraits();
293 
294   // Generate the OpInterface methods for all interfaces.
295   void genOpInterfaceMethods();
296 
297   // Generate op interface methods for the given interface.
298   void genOpInterfaceMethods(const tblgen::InterfaceOpTrait *trait);
299 
300   // Generate op interface method for the given interface method. If
301   // 'declaration' is true, generates a declaration, else a definition.
302   OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
303                                  bool declaration = true);
304 
305   // Generate the side effect interface methods.
306   void genSideEffectInterfaceMethods();
307 
308   // Generate the type inference interface methods.
309   void genTypeInterfaceMethods();
310 
311 private:
312   // The TableGen record for this op.
313   // TODO: OpEmitter should not have a Record directly,
314   // it should rather go through the Operator for better abstraction.
315   const Record &def;
316 
317   // The wrapper operator class for querying information from this op.
318   Operator op;
319 
320   // The C++ code builder for this op
321   OpClass opClass;
322 
323   // The format context for verification code generation.
324   FmtContext verifyCtx;
325 };
326 } // end anonymous namespace
327 
328 // Populate the format context `ctx` with substitutions of attributes, operands
329 // and results.
330 // - attrGet corresponds to the name of the function to call to get value of
331 //   attribute (the generated function call returns an Attribute);
332 // - operandGet corresponds to the name of the function with which to retrieve
333 //   an operand (the generated function call returns an OperandRange);
334 // - resultGet corresponds to the name of the function to get an result (the
335 //   generated function call returns a ValueRange);
populateSubstitutions(const Operator & op,const char * attrGet,const char * operandGet,const char * resultGet,FmtContext & ctx)336 static void populateSubstitutions(const Operator &op, const char *attrGet,
337                                   const char *operandGet, const char *resultGet,
338                                   FmtContext &ctx) {
339   // Populate substitutions for attributes and named operands.
340   for (const auto &namedAttr : op.getAttributes())
341     ctx.addSubst(namedAttr.name,
342                  formatv("{0}(\"{1}\")", attrGet, namedAttr.name));
343   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
344     auto &value = op.getOperand(i);
345     if (value.name.empty())
346       continue;
347 
348     if (value.isVariadic())
349       ctx.addSubst(value.name, formatv("{0}({1})", operandGet, i));
350     else
351       ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", operandGet, i));
352   }
353 
354   // Populate substitutions for results.
355   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
356     auto &value = op.getResult(i);
357     if (value.name.empty())
358       continue;
359 
360     if (value.isVariadic())
361       ctx.addSubst(value.name, formatv("{0}({1})", resultGet, i));
362     else
363       ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", resultGet, i));
364   }
365 }
366 
367 // Generate attribute verification. If emitVerificationRequiringOp is set then
368 // only verification for attributes whose value depend on op being known are
369 // emitted, else only verification that doesn't depend on the op being known are
370 // generated.
371 // - emitErrorPrefix is the prefix for the error emitting call which consists
372 //   of the entire function call up to start of error message fragment;
373 // - emitVerificationRequiringOp specifies whether verification should be
374 //   emitted for verification that require the op to exist;
genAttributeVerifier(const Operator & op,const char * attrGet,const Twine & emitErrorPrefix,bool emitVerificationRequiringOp,FmtContext & ctx,OpMethodBody & body)375 static void genAttributeVerifier(const Operator &op, const char *attrGet,
376                                  const Twine &emitErrorPrefix,
377                                  bool emitVerificationRequiringOp,
378                                  FmtContext &ctx, OpMethodBody &body) {
379   for (const auto &namedAttr : op.getAttributes()) {
380     const auto &attr = namedAttr.attr;
381     if (attr.isDerivedAttr())
382       continue;
383 
384     auto attrName = namedAttr.name;
385     bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
386     auto attrPred = attr.getPredicate();
387     auto condition = attrPred.isNull() ? "" : attrPred.getCondition();
388     // There is a condition to emit only if the use of $_op and whether to
389     // emit verifications for op matches.
390     bool hasConditionToEmit = (!(condition.find("$_op") != StringRef::npos) ^
391                                emitVerificationRequiringOp);
392 
393     // Prefix with `tblgen_` to avoid hiding the attribute accessor.
394     auto varName = tblgenNamePrefix + attrName;
395 
396     // If the attribute is
397     //  1. Required (not allowed missing) and not in op verification, or
398     //  2. Has a condition that will get verified
399     // then the variable will be used.
400     //
401     // Therefore, for optional attributes whose verification requires that an
402     // op already exists for verification/emitVerificationRequiringOp is set
403     // has nothing that can be verified here.
404     if ((allowMissingAttr || emitVerificationRequiringOp) &&
405         !hasConditionToEmit)
406       continue;
407 
408     body << formatv("  {\n  auto {0} = {1}(\"{2}\");\n", varName, attrGet,
409                     attrName);
410 
411     if (!emitVerificationRequiringOp && !allowMissingAttr) {
412       body << "  if (!" << varName << ") return " << emitErrorPrefix
413            << "\"requires attribute '" << attrName << "'\");\n";
414     }
415 
416     if (!hasConditionToEmit) {
417       body << "  }\n";
418       continue;
419     }
420 
421     if (allowMissingAttr) {
422       // If the attribute has a default value, then only verify the predicate if
423       // set. This does effectively assume that the default value is valid.
424       // TODO: verify the debug value is valid (perhaps in debug mode only).
425       body << "  if (" << varName << ") {\n";
426     }
427 
428     body << tgfmt("    if (!($0)) return $1\"attribute '$2' "
429                   "failed to satisfy constraint: $3\");\n",
430                   /*ctx=*/nullptr, tgfmt(condition, &ctx.withSelf(varName)),
431                   emitErrorPrefix, attrName, attr.getDescription());
432     if (allowMissingAttr)
433       body << "  }\n";
434     body << "  }\n";
435   }
436 }
437 
OpEmitter(const Operator & op)438 OpEmitter::OpEmitter(const Operator &op)
439     : def(op.getDef()), op(op),
440       opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
441   verifyCtx.withOp("(*this->getOperation())");
442 
443   genTraits();
444 
445   // Generate C++ code for various op methods. The order here determines the
446   // methods in the generated file.
447   genOpAsmInterface();
448   genOpNameGetter();
449   genNamedOperandGetters();
450   genNamedOperandSetters();
451   genNamedResultGetters();
452   genNamedRegionGetters();
453   genNamedSuccessorGetters();
454   genAttrGetters();
455   genAttrSetters();
456   genBuilder();
457   genParser();
458   genPrinter();
459   genVerifier();
460   genCanonicalizerDecls();
461   genFolderDecls();
462   genTypeInterfaceMethods();
463   genOpInterfaceMethods();
464   generateOpFormat(op, opClass);
465   genSideEffectInterfaceMethods();
466 }
467 
emitDecl(const Operator & op,raw_ostream & os)468 void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
469   OpEmitter(op).emitDecl(os);
470 }
471 
emitDef(const Operator & op,raw_ostream & os)472 void OpEmitter::emitDef(const Operator &op, raw_ostream &os) {
473   OpEmitter(op).emitDef(os);
474 }
475 
emitDecl(raw_ostream & os)476 void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
477 
emitDef(raw_ostream & os)478 void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
479 
genAttrGetters()480 void OpEmitter::genAttrGetters() {
481   FmtContext fctx;
482   fctx.withBuilder("::mlir::Builder(this->getContext())");
483 
484   Dialect opDialect = op.getDialect();
485   // Emit the derived attribute body.
486   auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
487     auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
488     if (!method)
489       return;
490     auto &body = method->body();
491     body << "  " << attr.getDerivedCodeBody() << "\n";
492   };
493 
494   // Emit with return type specified.
495   auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
496     auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
497     auto &body = method->body();
498     body << "  auto attr = " << name << "Attr();\n";
499     if (attr.hasDefaultValue()) {
500       // Returns the default value if not set.
501       // TODO: this is inefficient, we are recreating the attribute for every
502       // call. This should be set instead.
503       std::string defaultValue = std::string(
504           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
505       body << "    if (!attr)\n      return "
506            << tgfmt(attr.getConvertFromStorageCall(),
507                     &fctx.withSelf(defaultValue))
508            << ";\n";
509     }
510     body << "  return "
511          << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
512          << ";\n";
513   };
514 
515   // Generate raw named accessor type. This is a wrapper class that allows
516   // referring to the attributes via accessors instead of having to use
517   // the string interface for better compile time verification.
518   auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
519     auto *method =
520         opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str());
521     if (!method)
522       return;
523     auto &body = method->body();
524     body << "  return this->getAttr(\"" << name << "\").";
525     if (attr.isOptional() || attr.hasDefaultValue())
526       body << "dyn_cast_or_null<";
527     else
528       body << "cast<";
529     body << attr.getStorageType() << ">();";
530   };
531 
532   for (auto &namedAttr : op.getAttributes()) {
533     const auto &name = namedAttr.name;
534     const auto &attr = namedAttr.attr;
535     if (attr.isDerivedAttr()) {
536       emitDerivedAttr(name, attr);
537     } else {
538       emitAttrWithStorageType(name, attr);
539       emitAttrWithReturnType(name, attr);
540     }
541   }
542 
543   auto derivedAttrs = make_filter_range(op.getAttributes(),
544                                         [](const NamedAttribute &namedAttr) {
545                                           return namedAttr.attr.isDerivedAttr();
546                                         });
547   if (!derivedAttrs.empty()) {
548     opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
549     // Generate helper method to query whether a named attribute is a derived
550     // attribute. This enables, for example, avoiding adding an attribute that
551     // overlaps with a derived attribute.
552     {
553       auto *method = opClass.addMethodAndPrune("bool", "isDerivedAttribute",
554                                                OpMethod::MP_Static,
555                                                "::llvm::StringRef", "name");
556       auto &body = method->body();
557       for (auto namedAttr : derivedAttrs)
558         body << "  if (name == \"" << namedAttr.name << "\") return true;\n";
559       body << " return false;";
560     }
561     // Generate method to materialize derived attributes as a DictionaryAttr.
562     {
563       auto *method = opClass.addMethodAndPrune("::mlir::DictionaryAttr",
564                                                "materializeDerivedAttributes");
565       auto &body = method->body();
566 
567       auto nonMaterializable =
568           make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
569             return namedAttr.attr.getConvertFromStorageCall().empty();
570           });
571       if (!nonMaterializable.empty()) {
572         std::string attrs;
573         llvm::raw_string_ostream os(attrs);
574         interleaveComma(nonMaterializable, os,
575                         [&](const NamedAttribute &attr) { os << attr.name; });
576         PrintWarning(
577             op.getLoc(),
578             formatv(
579                 "op has non-materializable derived attributes '{0}', skipping",
580                 os.str()));
581         body << formatv("  emitOpError(\"op has non-materializable derived "
582                         "attributes '{0}'\");\n",
583                         attrs);
584         body << "  return nullptr;";
585         return;
586       }
587 
588       body << "  ::mlir::MLIRContext* ctx = getContext();\n";
589       body << "  ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
590       body << "  return ::mlir::DictionaryAttr::get({\n";
591       interleave(
592           derivedAttrs, body,
593           [&](const NamedAttribute &namedAttr) {
594             auto tmpl = namedAttr.attr.getConvertFromStorageCall();
595             body << "    {::mlir::Identifier::get(\"" << namedAttr.name
596                  << "\", ctx),\n"
597                  << tgfmt(tmpl, &fctx.withSelf(namedAttr.name + "()")
598                                      .withBuilder("odsBuilder")
599                                      .addSubst("_ctx", "ctx"))
600                  << "}";
601           },
602           ",\n");
603       body << "\n    }, ctx);";
604     }
605   }
606 }
607 
genAttrSetters()608 void OpEmitter::genAttrSetters() {
609   // Generate raw named setter type. This is a wrapper class that allows setting
610   // to the attributes via setters instead of having to use the string interface
611   // for better compile time verification.
612   auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
613     auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(),
614                                              attr.getStorageType(), "attr");
615     if (!method)
616       return;
617     auto &body = method->body();
618     body << "  (*this)->setAttr(\"" << name << "\", attr);";
619   };
620 
621   for (auto &namedAttr : op.getAttributes()) {
622     const auto &name = namedAttr.name;
623     const auto &attr = namedAttr.attr;
624     if (!attr.isDerivedAttr())
625       emitAttrWithStorageType(name, attr);
626   }
627 }
628 
629 // Generates the code to compute the start and end index of an operand or result
630 // range.
631 template <typename RangeT>
632 static void
generateValueRangeStartAndEnd(Class & opClass,StringRef methodName,int numVariadic,int numNonVariadic,StringRef rangeSizeCall,bool hasAttrSegmentSize,StringRef sizeAttrInit,RangeT && odsValues)633 generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
634                               int numVariadic, int numNonVariadic,
635                               StringRef rangeSizeCall, bool hasAttrSegmentSize,
636                               StringRef sizeAttrInit, RangeT &&odsValues) {
637   auto *method = opClass.addMethodAndPrune("std::pair<unsigned, unsigned>",
638                                            methodName, "unsigned", "index");
639   if (!method)
640     return;
641   auto &body = method->body();
642   if (numVariadic == 0) {
643     body << "  return {index, 1};\n";
644   } else if (hasAttrSegmentSize) {
645     body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
646   } else {
647     // Because the op can have arbitrarily interleaved variadic and non-variadic
648     // operands, we need to embed a list in the "sink" getter method for
649     // calculation at run-time.
650     llvm::SmallVector<StringRef, 4> isVariadic;
651     isVariadic.reserve(llvm::size(odsValues));
652     for (auto &it : odsValues)
653       isVariadic.push_back(it.isVariableLength() ? "true" : "false");
654     std::string isVariadicList = llvm::join(isVariadic, ", ");
655     body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
656                     numNonVariadic, numVariadic, rangeSizeCall, "operand");
657   }
658 }
659 
660 // Generates the named operand getter methods for the given Operator `op` and
661 // puts them in `opClass`.  Uses `rangeType` as the return type of getters that
662 // return a range of operands (individual operands are `Value ` and each
663 // element in the range must also be `Value `); use `rangeBeginCall` to get
664 // an iterator to the beginning of the operand range; use `rangeSizeCall` to
665 // obtain the number of operands. `getOperandCallPattern` contains the code
666 // necessary to obtain a single operand whose position will be substituted
667 // instead of
668 // "{0}" marker in the pattern.  Note that the pattern should work for any kind
669 // of ops, in particular for one-operand ops that may not have the
670 // `getOperand(unsigned)` method.
generateNamedOperandGetters(const Operator & op,Class & opClass,StringRef sizeAttrInit,StringRef rangeType,StringRef rangeBeginCall,StringRef rangeSizeCall,StringRef getOperandCallPattern)671 static void generateNamedOperandGetters(const Operator &op, Class &opClass,
672                                         StringRef sizeAttrInit,
673                                         StringRef rangeType,
674                                         StringRef rangeBeginCall,
675                                         StringRef rangeSizeCall,
676                                         StringRef getOperandCallPattern) {
677   const int numOperands = op.getNumOperands();
678   const int numVariadicOperands = op.getNumVariableLengthOperands();
679   const int numNormalOperands = numOperands - numVariadicOperands;
680 
681   const auto *sameVariadicSize =
682       op.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
683   const auto *attrSizedOperands =
684       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
685 
686   if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
687     PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
688                                  "specification over their sizes");
689   }
690 
691   if (numVariadicOperands < 2 && attrSizedOperands) {
692     PrintFatalError(op.getLoc(), "op must have at least two variadic operands "
693                                  "to use 'AttrSizedOperandSegments' trait");
694   }
695 
696   if (attrSizedOperands && sameVariadicSize) {
697     PrintFatalError(op.getLoc(),
698                     "op cannot have both 'AttrSizedOperandSegments' and "
699                     "'SameVariadicOperandSize' traits");
700   }
701 
702   // First emit a few "sink" getter methods upon which we layer all nicer named
703   // getter methods.
704   generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
705                                 numVariadicOperands, numNormalOperands,
706                                 rangeSizeCall, attrSizedOperands, sizeAttrInit,
707                                 const_cast<Operator &>(op).getOperands());
708 
709   auto *m = opClass.addMethodAndPrune(rangeType, "getODSOperands", "unsigned",
710                                       "index");
711   auto &body = m->body();
712   body << formatv(valueRangeReturnCode, rangeBeginCall,
713                   "getODSOperandIndexAndLength(index)");
714 
715   // Then we emit nicer named getter methods by redirecting to the "sink" getter
716   // method.
717   for (int i = 0; i != numOperands; ++i) {
718     const auto &operand = op.getOperand(i);
719     if (operand.name.empty())
720       continue;
721 
722     if (operand.isOptional()) {
723       m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
724       m->body() << "  auto operands = getODSOperands(" << i << ");\n"
725                 << "  return operands.empty() ? Value() : *operands.begin();";
726     } else if (operand.isVariadic()) {
727       m = opClass.addMethodAndPrune(rangeType, operand.name);
728       m->body() << "  return getODSOperands(" << i << ");";
729     } else {
730       m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
731       m->body() << "  return *getODSOperands(" << i << ").begin();";
732     }
733   }
734 }
735 
genNamedOperandGetters()736 void OpEmitter::genNamedOperandGetters() {
737   generateNamedOperandGetters(
738       op, opClass,
739       /*sizeAttrInit=*/
740       formatv(opSegmentSizeAttrInitCode, "operand_segment_sizes").str(),
741       /*rangeType=*/"::mlir::Operation::operand_range",
742       /*rangeBeginCall=*/"getOperation()->operand_begin()",
743       /*rangeSizeCall=*/"getOperation()->getNumOperands()",
744       /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
745 }
746 
genNamedOperandSetters()747 void OpEmitter::genNamedOperandSetters() {
748   auto *attrSizedOperands =
749       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
750   for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
751     const auto &operand = op.getOperand(i);
752     if (operand.name.empty())
753       continue;
754     auto *m = opClass.addMethodAndPrune("::mlir::MutableOperandRange",
755                                         (operand.name + "Mutable").str());
756     auto &body = m->body();
757     body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n"
758          << "  return ::mlir::MutableOperandRange(getOperation(), "
759             "range.first, range.second";
760     if (attrSizedOperands)
761       body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
762            << "u, *getOperation()->getMutableAttrDict().getNamed("
763               "\"operand_segment_sizes\"))";
764     body << ");\n";
765   }
766 }
767 
genNamedResultGetters()768 void OpEmitter::genNamedResultGetters() {
769   const int numResults = op.getNumResults();
770   const int numVariadicResults = op.getNumVariableLengthResults();
771   const int numNormalResults = numResults - numVariadicResults;
772 
773   // If we have more than one variadic results, we need more complicated logic
774   // to calculate the value range for each result.
775 
776   const auto *sameVariadicSize =
777       op.getTrait("::mlir::OpTrait::SameVariadicResultSize");
778   const auto *attrSizedResults =
779       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments");
780 
781   if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
782     PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
783                                  "specification over their sizes");
784   }
785 
786   if (numVariadicResults < 2 && attrSizedResults) {
787     PrintFatalError(op.getLoc(), "op must have at least two variadic results "
788                                  "to use 'AttrSizedResultSegments' trait");
789   }
790 
791   if (attrSizedResults && sameVariadicSize) {
792     PrintFatalError(op.getLoc(),
793                     "op cannot have both 'AttrSizedResultSegments' and "
794                     "'SameVariadicResultSize' traits");
795   }
796 
797   generateValueRangeStartAndEnd(
798       opClass, "getODSResultIndexAndLength", numVariadicResults,
799       numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
800       formatv(opSegmentSizeAttrInitCode, "result_segment_sizes").str(),
801       op.getResults());
802 
803   auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
804                                       "getODSResults", "unsigned", "index");
805   m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
806                        "getODSResultIndexAndLength(index)");
807 
808   for (int i = 0; i != numResults; ++i) {
809     const auto &result = op.getResult(i);
810     if (result.name.empty())
811       continue;
812 
813     if (result.isOptional()) {
814       m = opClass.addMethodAndPrune("::mlir::Value", result.name);
815       m->body()
816           << "  auto results = getODSResults(" << i << ");\n"
817           << "  return results.empty() ? ::mlir::Value() : *results.begin();";
818     } else if (result.isVariadic()) {
819       m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
820                                     result.name);
821       m->body() << "  return getODSResults(" << i << ");";
822     } else {
823       m = opClass.addMethodAndPrune("::mlir::Value", result.name);
824       m->body() << "  return *getODSResults(" << i << ").begin();";
825     }
826   }
827 }
828 
genNamedRegionGetters()829 void OpEmitter::genNamedRegionGetters() {
830   unsigned numRegions = op.getNumRegions();
831   for (unsigned i = 0; i < numRegions; ++i) {
832     const auto &region = op.getRegion(i);
833     if (region.name.empty())
834       continue;
835 
836     // Generate the accessors for a varidiadic region.
837     if (region.isVariadic()) {
838       auto *m = opClass.addMethodAndPrune("::mlir::MutableArrayRef<Region>",
839                                           region.name);
840       m->body() << formatv(
841           "  return this->getOperation()->getRegions().drop_front({0});", i);
842       continue;
843     }
844 
845     auto *m = opClass.addMethodAndPrune("::mlir::Region &", region.name);
846     m->body() << formatv("  return this->getOperation()->getRegion({0});", i);
847   }
848 }
849 
genNamedSuccessorGetters()850 void OpEmitter::genNamedSuccessorGetters() {
851   unsigned numSuccessors = op.getNumSuccessors();
852   for (unsigned i = 0; i < numSuccessors; ++i) {
853     const NamedSuccessor &successor = op.getSuccessor(i);
854     if (successor.name.empty())
855       continue;
856 
857     // Generate the accessors for a variadic successor list.
858     if (successor.isVariadic()) {
859       auto *m =
860           opClass.addMethodAndPrune("::mlir::SuccessorRange", successor.name);
861       m->body() << formatv(
862           "  return {std::next(this->getOperation()->successor_begin(), {0}), "
863           "this->getOperation()->successor_end()};",
864           i);
865       continue;
866     }
867 
868     auto *m = opClass.addMethodAndPrune("::mlir::Block *", successor.name);
869     m->body() << formatv("  return this->getOperation()->getSuccessor({0});",
870                          i);
871   }
872 }
873 
canGenerateUnwrappedBuilder(Operator & op)874 static bool canGenerateUnwrappedBuilder(Operator &op) {
875   // If this op does not have native attributes at all, return directly to avoid
876   // redefining builders.
877   if (op.getNumNativeAttributes() == 0)
878     return false;
879 
880   bool canGenerate = false;
881   // We are generating builders that take raw values for attributes. We need to
882   // make sure the native attributes have a meaningful "unwrapped" value type
883   // different from the wrapped mlir::Attribute type to avoid redefining
884   // builders. This checks for the op has at least one such native attribute.
885   for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
886     NamedAttribute &namedAttr = op.getAttribute(i);
887     if (canUseUnwrappedRawValue(namedAttr.attr)) {
888       canGenerate = true;
889       break;
890     }
891   }
892   return canGenerate;
893 }
894 
canInferType(Operator & op)895 static bool canInferType(Operator &op) {
896   return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
897          op.getNumRegions() == 0;
898 }
899 
genSeparateArgParamBuilder()900 void OpEmitter::genSeparateArgParamBuilder() {
901   SmallVector<AttrParamKind, 2> attrBuilderType;
902   attrBuilderType.push_back(AttrParamKind::WrappedAttr);
903   if (canGenerateUnwrappedBuilder(op))
904     attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
905 
906   // Emit with separate builders with or without unwrapped attributes and/or
907   // inferring result type.
908   auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
909                   bool inferType) {
910     llvm::SmallVector<OpMethodParameter, 4> paramList;
911     llvm::SmallVector<std::string, 4> resultNames;
912     buildParamList(paramList, resultNames, paramKind, attrType);
913 
914     auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
915                                         std::move(paramList));
916     // If the builder is redundant, skip generating the method.
917     if (!m)
918       return;
919     auto &body = m->body();
920     genCodeForAddingArgAndRegionForBuilder(
921         body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue);
922 
923     // Push all result types to the operation state
924 
925     if (inferType) {
926       // Generate builder that infers type too.
927       // TODO: Subsume this with general checking if type can be
928       // inferred automatically.
929       // TODO: Expand to handle regions.
930       body << formatv(R"(
931         ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
932         if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
933                       {1}.location, {1}.operands,
934                       {1}.attributes.getDictionary({1}.getContext()),
935                       /*regions=*/{{}, inferredReturnTypes)))
936           {1}.addTypes(inferredReturnTypes);
937         else
938           ::llvm::report_fatal_error("Failed to infer result type(s).");)",
939                       opClass.getClassName(), builderOpState);
940       return;
941     }
942 
943     switch (paramKind) {
944     case TypeParamKind::None:
945       return;
946     case TypeParamKind::Separate:
947       for (int i = 0, e = op.getNumResults(); i < e; ++i) {
948         if (op.getResult(i).isOptional())
949           body << "  if (" << resultNames[i] << ")\n  ";
950         body << "  " << builderOpState << ".addTypes(" << resultNames[i]
951              << ");\n";
952       }
953       return;
954     case TypeParamKind::Collective: {
955       int numResults = op.getNumResults();
956       int numVariadicResults = op.getNumVariableLengthResults();
957       int numNonVariadicResults = numResults - numVariadicResults;
958       bool hasVariadicResult = numVariadicResults != 0;
959 
960       // Avoid emitting "resultTypes.size() >= 0u" which is always true.
961       if (!(hasVariadicResult && numNonVariadicResults == 0))
962         body << "  "
963              << "assert(resultTypes.size() "
964              << (hasVariadicResult ? ">=" : "==") << " "
965              << numNonVariadicResults
966              << "u && \"mismatched number of results\");\n";
967       body << "  " << builderOpState << ".addTypes(resultTypes);\n";
968     }
969       return;
970     }
971     llvm_unreachable("unhandled TypeParamKind");
972   };
973 
974   // Some of the build methods generated here may be ambiguous, but TableGen's
975   // ambiguous function detection will elide those ones.
976   for (auto attrType : attrBuilderType) {
977     emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
978     if (canInferType(op))
979       emit(attrType, TypeParamKind::None, /*inferType=*/true);
980     emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
981   }
982 }
983 
genUseOperandAsResultTypeCollectiveParamBuilder()984 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
985   int numResults = op.getNumResults();
986 
987   // Signature
988   llvm::SmallVector<OpMethodParameter, 4> paramList;
989   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
990   paramList.emplace_back("::mlir::OperationState &", builderOpState);
991   paramList.emplace_back("::mlir::ValueRange", "operands");
992   // Provide default value for `attributes` when its the last parameter
993   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
994   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
995                          "attributes", attributesDefaultValue);
996   if (op.getNumVariadicRegions())
997     paramList.emplace_back("unsigned", "numRegions");
998 
999   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1000                                       std::move(paramList));
1001   // If the builder is redundant, skip generating the method
1002   if (!m)
1003     return;
1004   auto &body = m->body();
1005 
1006   // Operands
1007   body << "  " << builderOpState << ".addOperands(operands);\n";
1008 
1009   // Attributes
1010   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1011 
1012   // Create the correct number of regions
1013   if (int numRegions = op.getNumRegions()) {
1014     body << llvm::formatv(
1015         "  for (unsigned i = 0; i != {0}; ++i)\n",
1016         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1017     body << "    (void)" << builderOpState << ".addRegion();\n";
1018   }
1019 
1020   // Result types
1021   SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
1022   body << "  " << builderOpState << ".addTypes({"
1023        << llvm::join(resultTypes, ", ") << "});\n\n";
1024 }
1025 
genInferredTypeCollectiveParamBuilder()1026 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
1027   // TODO: Expand to support regions.
1028   SmallVector<OpMethodParameter, 4> paramList;
1029   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1030   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1031   paramList.emplace_back("::mlir::ValueRange", "operands");
1032   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1033                          "attributes", "{}");
1034   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1035                                       std::move(paramList));
1036   // If the builder is redundant, skip generating the method
1037   if (!m)
1038     return;
1039   auto &body = m->body();
1040 
1041   int numResults = op.getNumResults();
1042   int numVariadicResults = op.getNumVariableLengthResults();
1043   int numNonVariadicResults = numResults - numVariadicResults;
1044 
1045   int numOperands = op.getNumOperands();
1046   int numVariadicOperands = op.getNumVariableLengthOperands();
1047   int numNonVariadicOperands = numOperands - numVariadicOperands;
1048 
1049   // Operands
1050   if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1051     body << "  assert(operands.size()"
1052          << (numVariadicOperands != 0 ? " >= " : " == ")
1053          << numNonVariadicOperands
1054          << "u && \"mismatched number of parameters\");\n";
1055   body << "  " << builderOpState << ".addOperands(operands);\n";
1056   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1057 
1058   // Create the correct number of regions
1059   if (int numRegions = op.getNumRegions()) {
1060     body << llvm::formatv(
1061         "  for (unsigned i = 0; i != {0}; ++i)\n",
1062         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1063     body << "    (void)" << builderOpState << ".addRegion();\n";
1064   }
1065 
1066   // Result types
1067   body << formatv(R"(
1068     ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
1069     if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
1070                   {1}.location, operands,
1071                   {1}.attributes.getDictionary({1}.getContext()),
1072                   /*regions=*/{{}, inferredReturnTypes))) {{)",
1073                   opClass.getClassName(), builderOpState);
1074   if (numVariadicResults == 0 || numNonVariadicResults != 0)
1075     body << "  assert(inferredReturnTypes.size()"
1076          << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1077          << "u && \"mismatched number of return types\");\n";
1078   body << "      " << builderOpState << ".addTypes(inferredReturnTypes);";
1079 
1080   body << formatv(R"(
1081     } else
1082       ::llvm::report_fatal_error("Failed to infer result type(s).");)",
1083                   opClass.getClassName(), builderOpState);
1084 }
1085 
genUseOperandAsResultTypeSeparateParamBuilder()1086 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
1087   llvm::SmallVector<OpMethodParameter, 4> paramList;
1088   llvm::SmallVector<std::string, 4> resultNames;
1089   buildParamList(paramList, resultNames, TypeParamKind::None);
1090 
1091   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1092                                       std::move(paramList));
1093   // If the builder is redundant, skip generating the method
1094   if (!m)
1095     return;
1096   auto &body = m->body();
1097   genCodeForAddingArgAndRegionForBuilder(body);
1098 
1099   auto numResults = op.getNumResults();
1100   if (numResults == 0)
1101     return;
1102 
1103   // Push all result types to the operation state
1104   const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
1105   std::string resultType =
1106       formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
1107   body << "  " << builderOpState << ".addTypes({" << resultType;
1108   for (int i = 1; i != numResults; ++i)
1109     body << ", " << resultType;
1110   body << "});\n\n";
1111 }
1112 
genUseAttrAsResultTypeBuilder()1113 void OpEmitter::genUseAttrAsResultTypeBuilder() {
1114   SmallVector<OpMethodParameter, 4> paramList;
1115   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1116   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1117   paramList.emplace_back("::mlir::ValueRange", "operands");
1118   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1119                          "attributes", "{}");
1120   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1121                                       std::move(paramList));
1122   // If the builder is redundant, skip generating the method
1123   if (!m)
1124     return;
1125 
1126   auto &body = m->body();
1127 
1128   // Push all result types to the operation state
1129   std::string resultType;
1130   const auto &namedAttr = op.getAttribute(0);
1131 
1132   body << "  for (auto attr : attributes) {\n";
1133   body << "    if (attr.first != \"" << namedAttr.name << "\") continue;\n";
1134   if (namedAttr.attr.isTypeAttr()) {
1135     resultType = "attr.second.cast<::mlir::TypeAttr>().getValue()";
1136   } else {
1137     resultType = "attr.second.getType()";
1138   }
1139 
1140   // Operands
1141   body << "  " << builderOpState << ".addOperands(operands);\n";
1142 
1143   // Attributes
1144   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1145 
1146   // Result types
1147   SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
1148   body << "    " << builderOpState << ".addTypes({"
1149        << llvm::join(resultTypes, ", ") << "});\n";
1150   body << "  }\n";
1151 }
1152 
1153 /// Returns a signature of the builder as defined by a dag-typed initializer.
1154 /// Updates the context `fctx` to enable replacement of $_builder and $_state
1155 /// in the body. Reports errors at `loc`.
builderSignatureFromDAG(const DagInit * init,ArrayRef<llvm::SMLoc> loc,FmtContext & fctx)1156 static std::string builderSignatureFromDAG(const DagInit *init,
1157                                            ArrayRef<llvm::SMLoc> loc,
1158                                            FmtContext &fctx) {
1159   auto *defInit = dyn_cast<DefInit>(init->getOperator());
1160   if (!defInit || !defInit->getDef()->getName().equals("ins"))
1161     PrintFatalError(loc, "expected 'ins' in builders");
1162 
1163   // Inject builder and state arguments.
1164   llvm::SmallVector<std::string, 8> arguments;
1165   arguments.reserve(init->getNumArgs() + 2);
1166   arguments.push_back(llvm::formatv("::mlir::OpBuilder &{0}", builder).str());
1167   arguments.push_back(
1168       llvm::formatv("::mlir::OperationState &{0}", builderOpState).str());
1169 
1170   // Accept either a StringInit or a DefInit with two string values as dag
1171   // arguments. The former corresponds to the type, the latter to the type and
1172   // the default value. Similarly to C++, once an argument with a default value
1173   // is detected, the following arguments must have default values as well.
1174   bool seenDefaultValue = false;
1175   for (unsigned i = 0, e = init->getNumArgs(); i < e; ++i) {
1176     // If no name is provided, generate one.
1177     StringInit *argName = init->getArgName(i);
1178     std::string name =
1179         argName ? argName->getValue().str() : "odsArg" + std::to_string(i);
1180 
1181     Init *argInit = init->getArg(i);
1182     StringRef type;
1183     std::string defaultValue;
1184     if (StringInit *strType = dyn_cast<StringInit>(argInit)) {
1185       type = strType->getValue();
1186     } else {
1187       const Record *typeAndDefaultValue = cast<DefInit>(argInit)->getDef();
1188       type = typeAndDefaultValue->getValueAsString("type");
1189       StringRef defaultValueRef =
1190           typeAndDefaultValue->getValueAsString("defaultValue");
1191       if (!defaultValueRef.empty()) {
1192         seenDefaultValue = true;
1193         defaultValue = llvm::formatv(" = {0}", defaultValueRef).str();
1194       }
1195     }
1196     if (seenDefaultValue && defaultValue.empty())
1197       PrintFatalError(loc,
1198                       "expected an argument with default value after other "
1199                       "arguments with default values");
1200     arguments.push_back(
1201         llvm::formatv("{0} {1}{2}", type, name, defaultValue).str());
1202   }
1203 
1204   fctx.withBuilder(builder);
1205   fctx.addSubst("_state", builderOpState);
1206 
1207   return llvm::join(arguments, ", ");
1208 }
1209 
1210 // Returns a signature fo the builder as defined by a string initializer,
1211 // optionally injecting the builder and state arguments.
1212 // TODO: to be removed after the transition is complete.
builderSignatureFromString(StringRef params,FmtContext & fctx)1213 static std::string builderSignatureFromString(StringRef params,
1214                                               FmtContext &fctx) {
1215   bool skipParamGen = params.startswith("OpBuilder") ||
1216                       params.startswith("mlir::OpBuilder") ||
1217                       params.startswith("::mlir::OpBuilder");
1218   if (skipParamGen)
1219     return params.str();
1220 
1221   fctx.withBuilder(builder);
1222   fctx.addSubst("_state", builderOpState);
1223   return std::string(llvm::formatv("::mlir::OpBuilder &{0}, "
1224                                    "::mlir::OperationState &{1}{2}{3}",
1225                                    builder, builderOpState,
1226                                    params.empty() ? "" : ", ", params));
1227 }
1228 
genBuilder()1229 void OpEmitter::genBuilder() {
1230   // Handle custom builders if provided.
1231   // TODO: Create wrapper class for OpBuilder to hide the native
1232   // TableGen API calls here.
1233   {
1234     auto *listInit = dyn_cast_or_null<ListInit>(def.getValueInit("builders"));
1235     if (listInit) {
1236       for (Init *init : listInit->getValues()) {
1237         Record *builderDef = cast<DefInit>(init)->getDef();
1238         llvm::Optional<StringRef> params =
1239             builderDef->getValueAsOptionalString("params");
1240         FmtContext fctx;
1241         if (params.hasValue()) {
1242           PrintWarning(op.getLoc(),
1243                        "Op uses a deprecated, string-based OpBuilder format; "
1244                        "use OpBuilderDAG with '(ins <...>)' instead");
1245         }
1246         std::string paramStr =
1247             params.hasValue() ? builderSignatureFromString(params->trim(), fctx)
1248                               : builderSignatureFromDAG(
1249                                     builderDef->getValueAsDag("dagParams"),
1250                                     op.getLoc(), fctx);
1251 
1252         StringRef body = builderDef->getValueAsString("body");
1253         bool hasBody = !body.empty();
1254         OpMethod::Property properties =
1255             hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
1256         auto *method =
1257             opClass.addMethodAndPrune("void", "build", properties, paramStr);
1258         if (hasBody)
1259           method->body() << tgfmt(body, &fctx);
1260       }
1261     }
1262     if (op.skipDefaultBuilders()) {
1263       if (!listInit || listInit->empty())
1264         PrintFatalError(
1265             op.getLoc(),
1266             "default builders are skipped and no custom builders provided");
1267       return;
1268     }
1269   }
1270 
1271   // Generate default builders that requires all result type, operands, and
1272   // attributes as parameters.
1273 
1274   // We generate three classes of builders here:
1275   // 1. one having a stand-alone parameter for each operand / attribute, and
1276   genSeparateArgParamBuilder();
1277   // 2. one having an aggregated parameter for all result types / operands /
1278   //    attributes, and
1279   genCollectiveParamBuilder();
1280   // 3. one having a stand-alone parameter for each operand and attribute,
1281   //    use the first operand or attribute's type as all result types
1282   //    to facilitate different call patterns.
1283   if (op.getNumVariableLengthResults() == 0) {
1284     if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
1285       genUseOperandAsResultTypeSeparateParamBuilder();
1286       genUseOperandAsResultTypeCollectiveParamBuilder();
1287     }
1288     if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
1289       genUseAttrAsResultTypeBuilder();
1290   }
1291 }
1292 
genCollectiveParamBuilder()1293 void OpEmitter::genCollectiveParamBuilder() {
1294   int numResults = op.getNumResults();
1295   int numVariadicResults = op.getNumVariableLengthResults();
1296   int numNonVariadicResults = numResults - numVariadicResults;
1297 
1298   int numOperands = op.getNumOperands();
1299   int numVariadicOperands = op.getNumVariableLengthOperands();
1300   int numNonVariadicOperands = numOperands - numVariadicOperands;
1301 
1302   SmallVector<OpMethodParameter, 4> paramList;
1303   paramList.emplace_back("::mlir::OpBuilder &", "");
1304   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1305   paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1306   paramList.emplace_back("::mlir::ValueRange", "operands");
1307   // Provide default value for `attributes` when its the last parameter
1308   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1309   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1310                          "attributes", attributesDefaultValue);
1311   if (op.getNumVariadicRegions())
1312     paramList.emplace_back("unsigned", "numRegions");
1313 
1314   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1315                                       std::move(paramList));
1316   // If the builder is redundant, skip generating the method
1317   if (!m)
1318     return;
1319   auto &body = m->body();
1320 
1321   // Operands
1322   if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1323     body << "  assert(operands.size()"
1324          << (numVariadicOperands != 0 ? " >= " : " == ")
1325          << numNonVariadicOperands
1326          << "u && \"mismatched number of parameters\");\n";
1327   body << "  " << builderOpState << ".addOperands(operands);\n";
1328 
1329   // Attributes
1330   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1331 
1332   // Create the correct number of regions
1333   if (int numRegions = op.getNumRegions()) {
1334     body << llvm::formatv(
1335         "  for (unsigned i = 0; i != {0}; ++i)\n",
1336         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1337     body << "    (void)" << builderOpState << ".addRegion();\n";
1338   }
1339 
1340   // Result types
1341   if (numVariadicResults == 0 || numNonVariadicResults != 0)
1342     body << "  assert(resultTypes.size()"
1343          << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1344          << "u && \"mismatched number of return types\");\n";
1345   body << "  " << builderOpState << ".addTypes(resultTypes);\n";
1346 
1347   // Generate builder that infers type too.
1348   // TODO: Expand to handle regions and successors.
1349   if (canInferType(op) && op.getNumSuccessors() == 0)
1350     genInferredTypeCollectiveParamBuilder();
1351 }
1352 
buildParamList(SmallVectorImpl<OpMethodParameter> & paramList,SmallVectorImpl<std::string> & resultTypeNames,TypeParamKind typeParamKind,AttrParamKind attrParamKind)1353 void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
1354                                SmallVectorImpl<std::string> &resultTypeNames,
1355                                TypeParamKind typeParamKind,
1356                                AttrParamKind attrParamKind) {
1357   resultTypeNames.clear();
1358   auto numResults = op.getNumResults();
1359   resultTypeNames.reserve(numResults);
1360 
1361   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1362   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1363 
1364   switch (typeParamKind) {
1365   case TypeParamKind::None:
1366     break;
1367   case TypeParamKind::Separate: {
1368     // Add parameters for all return types
1369     for (int i = 0; i < numResults; ++i) {
1370       const auto &result = op.getResult(i);
1371       std::string resultName = std::string(result.name);
1372       if (resultName.empty())
1373         resultName = std::string(formatv("resultType{0}", i));
1374 
1375       StringRef type =
1376           result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
1377       OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1378       if (result.isOptional())
1379         properties = OpMethodParameter::PP_Optional;
1380 
1381       paramList.emplace_back(type, resultName, properties);
1382       resultTypeNames.emplace_back(std::move(resultName));
1383     }
1384   } break;
1385   case TypeParamKind::Collective: {
1386     paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1387     resultTypeNames.push_back("resultTypes");
1388   } break;
1389   }
1390 
1391   // Add parameters for all arguments (operands and attributes).
1392 
1393   int numOperands = 0;
1394   int numAttrs = 0;
1395 
1396   int defaultValuedAttrStartIndex = op.getNumArgs();
1397   if (attrParamKind == AttrParamKind::UnwrappedValue) {
1398     // Calculate the start index from which we can attach default values in the
1399     // builder declaration.
1400     for (int i = op.getNumArgs() - 1; i >= 0; --i) {
1401       auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
1402       if (!namedAttr || !namedAttr->attr.hasDefaultValue())
1403         break;
1404 
1405       if (!canUseUnwrappedRawValue(namedAttr->attr))
1406         break;
1407 
1408       // Creating an APInt requires us to provide bitwidth, value, and
1409       // signedness, which is complicated compared to others. Similarly
1410       // for APFloat.
1411       // TODO: Adjust the 'returnType' field of such attributes
1412       // to support them.
1413       StringRef retType = namedAttr->attr.getReturnType();
1414       if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
1415         break;
1416 
1417       defaultValuedAttrStartIndex = i;
1418     }
1419   }
1420 
1421   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
1422     auto argument = op.getArg(i);
1423     if (argument.is<tblgen::NamedTypeConstraint *>()) {
1424       const auto &operand = op.getOperand(numOperands);
1425       StringRef type =
1426           operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value";
1427       OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1428       if (operand.isOptional())
1429         properties = OpMethodParameter::PP_Optional;
1430 
1431       paramList.emplace_back(type, getArgumentName(op, numOperands),
1432                              properties);
1433       ++numOperands;
1434     } else {
1435       const auto &namedAttr = op.getAttribute(numAttrs);
1436       const auto &attr = namedAttr.attr;
1437 
1438       OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1439       if (attr.isOptional())
1440         properties = OpMethodParameter::PP_Optional;
1441 
1442       StringRef type;
1443       switch (attrParamKind) {
1444       case AttrParamKind::WrappedAttr:
1445         type = attr.getStorageType();
1446         break;
1447       case AttrParamKind::UnwrappedValue:
1448         if (canUseUnwrappedRawValue(attr))
1449           type = attr.getReturnType();
1450         else
1451           type = attr.getStorageType();
1452         break;
1453       }
1454 
1455       std::string defaultValue;
1456       // Attach default value if requested and possible.
1457       if (attrParamKind == AttrParamKind::UnwrappedValue &&
1458           i >= defaultValuedAttrStartIndex) {
1459         bool isString = attr.getReturnType() == "::llvm::StringRef";
1460         if (isString)
1461           defaultValue.append("\"");
1462         defaultValue += attr.getDefaultValue();
1463         if (isString)
1464           defaultValue.append("\"");
1465       }
1466       paramList.emplace_back(type, namedAttr.name, defaultValue, properties);
1467       ++numAttrs;
1468     }
1469   }
1470 
1471   /// Insert parameters for each successor.
1472   for (const NamedSuccessor &succ : op.getSuccessors()) {
1473     StringRef type =
1474         succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
1475     paramList.emplace_back(type, succ.name);
1476   }
1477 
1478   /// Insert parameters for variadic regions.
1479   for (const NamedRegion &region : op.getRegions())
1480     if (region.isVariadic())
1481       paramList.emplace_back("unsigned",
1482                              llvm::formatv("{0}Count", region.name).str());
1483 }
1484 
genCodeForAddingArgAndRegionForBuilder(OpMethodBody & body,bool isRawValueAttr)1485 void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
1486                                                        bool isRawValueAttr) {
1487   // Push all operands to the result.
1488   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
1489     std::string argName = getArgumentName(op, i);
1490     if (op.getOperand(i).isOptional())
1491       body << "  if (" << argName << ")\n  ";
1492     body << "  " << builderOpState << ".addOperands(" << argName << ");\n";
1493   }
1494 
1495   // If the operation has the operand segment size attribute, add it here.
1496   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1497     body << "  " << builderOpState
1498          << ".addAttribute(\"operand_segment_sizes\", "
1499             "odsBuilder.getI32VectorAttr({";
1500     interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1501       if (op.getOperand(i).isOptional())
1502         body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
1503       else if (op.getOperand(i).isVariadic())
1504         body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
1505       else
1506         body << "1";
1507     });
1508     body << "}));\n";
1509   }
1510 
1511   // Push all attributes to the result.
1512   for (const auto &namedAttr : op.getAttributes()) {
1513     auto &attr = namedAttr.attr;
1514     if (!attr.isDerivedAttr()) {
1515       bool emitNotNullCheck = attr.isOptional();
1516       if (emitNotNullCheck) {
1517         body << formatv("  if ({0}) ", namedAttr.name) << "{\n";
1518       }
1519       if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
1520         // If this is a raw value, then we need to wrap it in an Attribute
1521         // instance.
1522         FmtContext fctx;
1523         fctx.withBuilder("odsBuilder");
1524 
1525         std::string builderTemplate =
1526             std::string(attr.getConstBuilderTemplate());
1527 
1528         // For StringAttr, its constant builder call will wrap the input in
1529         // quotes, which is correct for normal string literals, but incorrect
1530         // here given we use function arguments. So we need to strip the
1531         // wrapping quotes.
1532         if (StringRef(builderTemplate).contains("\"$0\""))
1533           builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
1534 
1535         std::string value =
1536             std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
1537         body << formatv("  {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
1538                         namedAttr.name, value);
1539       } else {
1540         body << formatv("  {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
1541                         namedAttr.name);
1542       }
1543       if (emitNotNullCheck) {
1544         body << "  }\n";
1545       }
1546     }
1547   }
1548 
1549   // Create the correct number of regions.
1550   for (const NamedRegion &region : op.getRegions()) {
1551     if (region.isVariadic())
1552       body << formatv("  for (unsigned i = 0; i < {0}Count; ++i)\n  ",
1553                       region.name);
1554 
1555     body << "  (void)" << builderOpState << ".addRegion();\n";
1556   }
1557 
1558   // Push all successors to the result.
1559   for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
1560     body << formatv("  {0}.addSuccessors({1});\n", builderOpState,
1561                     namedSuccessor.name);
1562   }
1563 }
1564 
genCanonicalizerDecls()1565 void OpEmitter::genCanonicalizerDecls() {
1566   if (!def.getValueAsBit("hasCanonicalizer"))
1567     return;
1568 
1569   SmallVector<OpMethodParameter, 2> paramList;
1570   paramList.emplace_back("::mlir::OwningRewritePatternList &", "results");
1571   paramList.emplace_back("::mlir::MLIRContext *", "context");
1572   opClass.addMethodAndPrune("void", "getCanonicalizationPatterns",
1573                             OpMethod::MP_StaticDeclaration,
1574                             std::move(paramList));
1575 }
1576 
genFolderDecls()1577 void OpEmitter::genFolderDecls() {
1578   bool hasSingleResult =
1579       op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
1580 
1581   if (def.getValueAsBit("hasFolder")) {
1582     if (hasSingleResult) {
1583       opClass.addMethodAndPrune(
1584           "::mlir::OpFoldResult", "fold", OpMethod::MP_Declaration,
1585           "::llvm::ArrayRef<::mlir::Attribute>", "operands");
1586     } else {
1587       SmallVector<OpMethodParameter, 2> paramList;
1588       paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
1589       paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
1590                              "results");
1591       opClass.addMethodAndPrune("::mlir::LogicalResult", "fold",
1592                                 OpMethod::MP_Declaration, std::move(paramList));
1593     }
1594   }
1595 }
1596 
genOpInterfaceMethods(const tblgen::InterfaceOpTrait * opTrait)1597 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceOpTrait *opTrait) {
1598   auto interface = opTrait->getOpInterface();
1599 
1600   // Get the set of methods that should always be declared.
1601   auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
1602   llvm::StringSet<> alwaysDeclaredMethods;
1603   alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
1604                                alwaysDeclaredMethodsVec.end());
1605 
1606   for (const InterfaceMethod &method : interface.getMethods()) {
1607     // Don't declare if the method has a body.
1608     if (method.getBody())
1609       continue;
1610     // Don't declare if the method has a default implementation and the op
1611     // didn't request that it always be declared.
1612     if (method.getDefaultImplementation() &&
1613         !alwaysDeclaredMethods.count(method.getName()))
1614       continue;
1615     genOpInterfaceMethod(method);
1616   }
1617 }
1618 
genOpInterfaceMethod(const InterfaceMethod & method,bool declaration)1619 OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
1620                                           bool declaration) {
1621   SmallVector<OpMethodParameter, 4> paramList;
1622   for (const InterfaceMethod::Argument &arg : method.getArguments())
1623     paramList.emplace_back(arg.type, arg.name);
1624 
1625   auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None;
1626   if (declaration)
1627     properties =
1628         static_cast<OpMethod::Property>(properties | OpMethod::MP_Declaration);
1629   return opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
1630                                    properties, std::move(paramList));
1631 }
1632 
genOpInterfaceMethods()1633 void OpEmitter::genOpInterfaceMethods() {
1634   for (const auto &trait : op.getTraits()) {
1635     if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
1636       if (opTrait->shouldDeclareMethods())
1637         genOpInterfaceMethods(opTrait);
1638   }
1639 }
1640 
genSideEffectInterfaceMethods()1641 void OpEmitter::genSideEffectInterfaceMethods() {
1642   enum EffectKind { Operand, Result, Symbol, Static };
1643   struct EffectLocation {
1644     /// The effect applied.
1645     SideEffect effect;
1646 
1647     /// The index if the kind is not static.
1648     unsigned index : 30;
1649 
1650     /// The kind of the location.
1651     unsigned kind : 2;
1652   };
1653 
1654   StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
1655   auto resolveDecorators = [&](Operator::var_decorator_range decorators,
1656                                unsigned index, unsigned kind) {
1657     for (auto decorator : decorators)
1658       if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) {
1659         opClass.addTrait(effect->getInterfaceTrait());
1660         interfaceEffects[effect->getBaseEffectName()].push_back(
1661             EffectLocation{*effect, index, kind});
1662       }
1663   };
1664 
1665   // Collect effects that were specified via:
1666   /// Traits.
1667   for (const auto &trait : op.getTraits()) {
1668     const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
1669     if (!opTrait)
1670       continue;
1671     auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
1672     for (auto decorator : opTrait->getEffects())
1673       effects.push_back(EffectLocation{cast<SideEffect>(decorator),
1674                                        /*index=*/0, EffectKind::Static});
1675   }
1676   /// Attributes and Operands.
1677   for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
1678     Argument arg = op.getArg(i);
1679     if (arg.is<NamedTypeConstraint *>()) {
1680       resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
1681       ++operandIt;
1682       continue;
1683     }
1684     const NamedAttribute *attr = arg.get<NamedAttribute *>();
1685     if (attr->attr.getBaseAttr().isSymbolRefAttr())
1686       resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
1687   }
1688   /// Results.
1689   for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
1690     resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
1691 
1692   // The code used to add an effect instance.
1693   // {0}: The effect class.
1694   // {1}: Optional value or symbol reference.
1695   // {1}: The resource class.
1696   const char *addEffectCode =
1697       "  effects.emplace_back({0}::get(), {1}{2}::get());\n";
1698 
1699   for (auto &it : interfaceEffects) {
1700     // Generate the 'getEffects' method.
1701     std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::"
1702                                      "SideEffects::EffectInstance<{0}>> &",
1703                                      it.first())
1704                            .str();
1705     auto *getEffects =
1706         opClass.addMethodAndPrune("void", "getEffects", type, "effects");
1707     auto &body = getEffects->body();
1708 
1709     // Add effect instances for each of the locations marked on the operation.
1710     for (auto &location : it.second) {
1711       StringRef effect = location.effect.getName();
1712       StringRef resource = location.effect.getResource();
1713       if (location.kind == EffectKind::Static) {
1714         // A static instance has no attached value.
1715         body << llvm::formatv(addEffectCode, effect, "", resource).str();
1716       } else if (location.kind == EffectKind::Symbol) {
1717         // A symbol reference requires adding the proper attribute.
1718         const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
1719         if (attr->attr.isOptional()) {
1720           body << "  if (auto symbolRef = " << attr->name << "Attr())\n  "
1721                << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
1722                       .str();
1723         } else {
1724           body << llvm::formatv(addEffectCode, effect, attr->name + "(), ",
1725                                 resource)
1726                       .str();
1727         }
1728       } else {
1729         // Otherwise this is an operand/result, so we need to attach the Value.
1730         body << "  for (::mlir::Value value : getODS"
1731              << (location.kind == EffectKind::Operand ? "Operands" : "Results")
1732              << "(" << location.index << "))\n  "
1733              << llvm::formatv(addEffectCode, effect, "value, ", resource).str();
1734       }
1735     }
1736   }
1737 }
1738 
genTypeInterfaceMethods()1739 void OpEmitter::genTypeInterfaceMethods() {
1740   if (!op.allResultTypesKnown())
1741     return;
1742   // Generate 'inferReturnTypes' method declaration using the interface method
1743   // declared in 'InferTypeOpInterface' op interface.
1744   const auto *trait = dyn_cast<InterfaceOpTrait>(
1745       op.getTrait("::mlir::InferTypeOpInterface::Trait"));
1746   auto interface = trait->getOpInterface();
1747   OpMethod *method = [&]() -> OpMethod * {
1748     for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
1749       if (interfaceMethod.getName() == "inferReturnTypes") {
1750         return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
1751       }
1752     }
1753     assert(0 && "unable to find inferReturnTypes interface method");
1754     return nullptr;
1755   }();
1756   auto &body = method->body();
1757   body << "  inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
1758 
1759   FmtContext fctx;
1760   fctx.withBuilder("odsBuilder");
1761   body << "  ::mlir::Builder odsBuilder(context);\n";
1762 
1763   auto emitType =
1764       [&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & {
1765     if (type.isArg()) {
1766       auto argIndex = type.getArg();
1767       assert(!op.getArg(argIndex).is<NamedAttribute *>());
1768       auto arg = op.getArgToOperandOrAttribute(argIndex);
1769       if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
1770         return body << "operands[" << arg.operandOrAttributeIndex()
1771                     << "].getType()";
1772       return body << "attributes[" << arg.operandOrAttributeIndex()
1773                   << "].getType()";
1774     } else {
1775       return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
1776     }
1777   };
1778 
1779   for (int i = 0, e = op.getNumResults(); i != e; ++i) {
1780     body << "  inferredReturnTypes[" << i << "] = ";
1781     auto types = op.getSameTypeAsResult(i);
1782     emitType(types[0]) << ";\n";
1783     if (types.size() == 1)
1784       continue;
1785     // TODO: We could verify equality here, but skipping that for verification.
1786   }
1787   body << "  return ::mlir::success();";
1788 }
1789 
genParser()1790 void OpEmitter::genParser() {
1791   if (!hasStringAttribute(def, "parser") ||
1792       hasStringAttribute(def, "assemblyFormat"))
1793     return;
1794 
1795   SmallVector<OpMethodParameter, 2> paramList;
1796   paramList.emplace_back("::mlir::OpAsmParser &", "parser");
1797   paramList.emplace_back("::mlir::OperationState &", "result");
1798   auto *method =
1799       opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
1800                                 OpMethod::MP_Static, std::move(paramList));
1801 
1802   FmtContext fctx;
1803   fctx.addSubst("cppClass", opClass.getClassName());
1804   auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
1805   method->body() << "  " << tgfmt(parser, &fctx);
1806 }
1807 
genPrinter()1808 void OpEmitter::genPrinter() {
1809   if (hasStringAttribute(def, "assemblyFormat"))
1810     return;
1811 
1812   auto valueInit = def.getValueInit("printer");
1813   StringInit *stringInit = dyn_cast<StringInit>(valueInit);
1814   if (!stringInit)
1815     return;
1816 
1817   auto *method =
1818       opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &", "p");
1819   FmtContext fctx;
1820   fctx.addSubst("cppClass", opClass.getClassName());
1821   auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
1822   method->body() << "  " << tgfmt(printer, &fctx);
1823 }
1824 
genVerifier()1825 void OpEmitter::genVerifier() {
1826   auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify");
1827   auto &body = method->body();
1828   body << "  if (failed(" << op.getAdaptorName()
1829        << "(*this).verify(this->getLoc()))) "
1830        << "return ::mlir::failure();\n";
1831 
1832   auto *valueInit = def.getValueInit("verifier");
1833   StringInit *stringInit = dyn_cast<StringInit>(valueInit);
1834   bool hasCustomVerify = stringInit && !stringInit->getValue().empty();
1835   populateSubstitutions(op, "this->getAttr", "this->getODSOperands",
1836                         "this->getODSResults", verifyCtx);
1837 
1838   genAttributeVerifier(op, "this->getAttr", "emitOpError(",
1839                        /*emitVerificationRequiringOp=*/true, verifyCtx, body);
1840   genOperandResultVerifier(body, op.getOperands(), "operand");
1841   genOperandResultVerifier(body, op.getResults(), "result");
1842 
1843   for (auto &trait : op.getTraits()) {
1844     if (auto *t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
1845       body << tgfmt("  if (!($0))\n    "
1846                     "return emitOpError(\"failed to verify that $1\");\n",
1847                     &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
1848                     t->getDescription());
1849     }
1850   }
1851 
1852   genRegionVerifier(body);
1853   genSuccessorVerifier(body);
1854 
1855   if (hasCustomVerify) {
1856     FmtContext fctx;
1857     fctx.addSubst("cppClass", opClass.getClassName());
1858     auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
1859     body << "  " << tgfmt(printer, &fctx);
1860   } else {
1861     body << "  return ::mlir::success();\n";
1862   }
1863 }
1864 
genOperandResultVerifier(OpMethodBody & body,Operator::value_range values,StringRef valueKind)1865 void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
1866                                          Operator::value_range values,
1867                                          StringRef valueKind) {
1868   FmtContext fctx;
1869 
1870   body << "  {\n";
1871   body << "    unsigned index = 0; (void)index;\n";
1872 
1873   for (auto staticValue : llvm::enumerate(values)) {
1874     bool hasPredicate = staticValue.value().hasPredicate();
1875     bool isOptional = staticValue.value().isOptional();
1876     if (!hasPredicate && !isOptional)
1877       continue;
1878     body << formatv("    auto valueGroup{2} = getODS{0}{1}s({2});\n",
1879                     // Capitalize the first letter to match the function name
1880                     valueKind.substr(0, 1).upper(), valueKind.substr(1),
1881                     staticValue.index());
1882 
1883     // If the constraint is optional check that the value group has at most 1
1884     // value.
1885     if (isOptional) {
1886       body << formatv("    if (valueGroup{0}.size() > 1)\n"
1887                       "      return emitOpError(\"{1} group starting at #\") "
1888                       "<< index << \" requires 0 or 1 element, but found \" << "
1889                       "valueGroup{0}.size();\n",
1890                       staticValue.index(), valueKind);
1891     }
1892 
1893     // Otherwise, if there is no predicate there is nothing left to do.
1894     if (!hasPredicate)
1895       continue;
1896 
1897     // Emit a loop to check all the dynamic values in the pack.
1898     body << "    for (::mlir::Value v : valueGroup" << staticValue.index()
1899          << ") {\n";
1900 
1901     auto constraint = staticValue.value().constraint;
1902     body << "      (void)v;\n"
1903          << "      if (!("
1904          << tgfmt(constraint.getConditionTemplate(),
1905                   &fctx.withSelf("v.getType()"))
1906          << ")) {\n"
1907          << formatv("        return emitOpError(\"{0} #\") << index "
1908                     "<< \" must be {1}, but got \" << v.getType();\n",
1909                     valueKind, constraint.getDescription())
1910          << "      }\n" // if
1911          << "      ++index;\n"
1912          << "    }\n"; // for
1913   }
1914 
1915   body << "  }\n";
1916 }
1917 
genRegionVerifier(OpMethodBody & body)1918 void OpEmitter::genRegionVerifier(OpMethodBody &body) {
1919   // If we have no regions, there is nothing more to do.
1920   unsigned numRegions = op.getNumRegions();
1921   if (numRegions == 0)
1922     return;
1923 
1924   body << "{\n";
1925   body << "    unsigned index = 0; (void)index;\n";
1926 
1927   for (unsigned i = 0; i < numRegions; ++i) {
1928     const auto &region = op.getRegion(i);
1929     if (region.constraint.getPredicate().isNull())
1930       continue;
1931 
1932     body << "    for (::mlir::Region &region : ";
1933     body << formatv(region.isVariadic()
1934                         ? "{0}()"
1935                         : "::mlir::MutableArrayRef<::mlir::Region>(this->"
1936                           "getOperation()->getRegion({1}))",
1937                     region.name, i);
1938     body << ") {\n";
1939     auto constraint = tgfmt(region.constraint.getConditionTemplate(),
1940                             &verifyCtx.withSelf("region"))
1941                           .str();
1942 
1943     body << formatv("      (void)region;\n"
1944                     "      if (!({0})) {\n        "
1945                     "return emitOpError(\"region #\") << index << \" {1}"
1946                     "failed to "
1947                     "verify constraint: {2}\";\n      }\n",
1948                     constraint,
1949                     region.name.empty() ? "" : "('" + region.name + "') ",
1950                     region.constraint.getDescription())
1951          << "      ++index;\n"
1952          << "    }\n";
1953   }
1954   body << "  }\n";
1955 }
1956 
genSuccessorVerifier(OpMethodBody & body)1957 void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
1958   // If we have no successors, there is nothing more to do.
1959   unsigned numSuccessors = op.getNumSuccessors();
1960   if (numSuccessors == 0)
1961     return;
1962 
1963   body << "{\n";
1964   body << "    unsigned index = 0; (void)index;\n";
1965 
1966   for (unsigned i = 0; i < numSuccessors; ++i) {
1967     const auto &successor = op.getSuccessor(i);
1968     if (successor.constraint.getPredicate().isNull())
1969       continue;
1970 
1971     if (successor.isVariadic()) {
1972       body << formatv("    for (::mlir::Block *successor : {0}()) {\n",
1973                       successor.name);
1974     } else {
1975       body << "    {\n";
1976       body << formatv("      ::mlir::Block *successor = {0}();\n",
1977                       successor.name);
1978     }
1979     auto constraint = tgfmt(successor.constraint.getConditionTemplate(),
1980                             &verifyCtx.withSelf("successor"))
1981                           .str();
1982 
1983     body << formatv("      (void)successor;\n"
1984                     "      if (!({0})) {\n        "
1985                     "return emitOpError(\"successor #\") << index << \"('{1}') "
1986                     "failed to "
1987                     "verify constraint: {2}\";\n      }\n",
1988                     constraint, successor.name,
1989                     successor.constraint.getDescription())
1990          << "      ++index;\n"
1991          << "    }\n";
1992   }
1993   body << "  }\n";
1994 }
1995 
1996 /// Add a size count trait to the given operation class.
addSizeCountTrait(OpClass & opClass,StringRef traitKind,int numTotal,int numVariadic)1997 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
1998                               int numTotal, int numVariadic) {
1999   if (numVariadic != 0) {
2000     if (numTotal == numVariadic)
2001       opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s");
2002     else
2003       opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
2004                        Twine(numTotal - numVariadic) + ">::Impl");
2005     return;
2006   }
2007   switch (numTotal) {
2008   case 0:
2009     opClass.addTrait("::mlir::OpTrait::Zero" + traitKind);
2010     break;
2011   case 1:
2012     opClass.addTrait("::mlir::OpTrait::One" + traitKind);
2013     break;
2014   default:
2015     opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
2016                      ">::Impl");
2017     break;
2018   }
2019 }
2020 
genTraits()2021 void OpEmitter::genTraits() {
2022   // Add region size trait.
2023   unsigned numRegions = op.getNumRegions();
2024   unsigned numVariadicRegions = op.getNumVariadicRegions();
2025   addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
2026 
2027   // Add result size trait.
2028   int numResults = op.getNumResults();
2029   int numVariadicResults = op.getNumVariableLengthResults();
2030   addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
2031 
2032   // Add successor size trait.
2033   unsigned numSuccessors = op.getNumSuccessors();
2034   unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
2035   addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
2036 
2037   // Add variadic size trait and normal op traits.
2038   int numOperands = op.getNumOperands();
2039   int numVariadicOperands = op.getNumVariableLengthOperands();
2040 
2041   // Add operand size trait.
2042   if (numVariadicOperands != 0) {
2043     if (numOperands == numVariadicOperands)
2044       opClass.addTrait("::mlir::OpTrait::VariadicOperands");
2045     else
2046       opClass.addTrait("::mlir::OpTrait::AtLeastNOperands<" +
2047                        Twine(numOperands - numVariadicOperands) + ">::Impl");
2048   } else {
2049     switch (numOperands) {
2050     case 0:
2051       opClass.addTrait("::mlir::OpTrait::ZeroOperands");
2052       break;
2053     case 1:
2054       opClass.addTrait("::mlir::OpTrait::OneOperand");
2055       break;
2056     default:
2057       opClass.addTrait("::mlir::OpTrait::NOperands<" + Twine(numOperands) +
2058                        ">::Impl");
2059       break;
2060     }
2061   }
2062 
2063   // Add the native and interface traits.
2064   for (const auto &trait : op.getTraits()) {
2065     if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
2066       opClass.addTrait(opTrait->getTrait());
2067     else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
2068       opClass.addTrait(opTrait->getTrait());
2069   }
2070 }
2071 
genOpNameGetter()2072 void OpEmitter::genOpNameGetter() {
2073   auto *method = opClass.addMethodAndPrune(
2074       "::llvm::StringRef", "getOperationName", OpMethod::MP_Static);
2075   method->body() << "  return \"" << op.getOperationName() << "\";\n";
2076 }
2077 
genOpAsmInterface()2078 void OpEmitter::genOpAsmInterface() {
2079   // If the user only has one results or specifically added the Asm trait,
2080   // then don't generate it for them. We specifically only handle multi result
2081   // operations, because the name of a single result in the common case is not
2082   // interesting(generally 'result'/'output'/etc.).
2083   // TODO: We could also add a flag to allow operations to opt in to this
2084   // generation, even if they only have a single operation.
2085   int numResults = op.getNumResults();
2086   if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
2087     return;
2088 
2089   SmallVector<StringRef, 4> resultNames(numResults);
2090   for (int i = 0; i != numResults; ++i)
2091     resultNames[i] = op.getResultName(i);
2092 
2093   // Don't add the trait if none of the results have a valid name.
2094   if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
2095     return;
2096   opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
2097 
2098   // Generate the right accessor for the number of results.
2099   auto *method = opClass.addMethodAndPrune(
2100       "void", "getAsmResultNames", "::mlir::OpAsmSetValueNameFn", "setNameFn");
2101   auto &body = method->body();
2102   for (int i = 0; i != numResults; ++i) {
2103     body << "  auto resultGroup" << i << " = getODSResults(" << i << ");\n"
2104          << "  if (!llvm::empty(resultGroup" << i << "))\n"
2105          << "    setNameFn(*resultGroup" << i << ".begin(), \""
2106          << resultNames[i] << "\");\n";
2107   }
2108 }
2109 
2110 //===----------------------------------------------------------------------===//
2111 // OpOperandAdaptor emitter
2112 //===----------------------------------------------------------------------===//
2113 
2114 namespace {
2115 // Helper class to emit Op operand adaptors to an output stream.  Operand
2116 // adaptors are wrappers around ArrayRef<Value> that provide named operand
2117 // getters identical to those defined in the Op.
2118 class OpOperandAdaptorEmitter {
2119 public:
2120   static void emitDecl(const Operator &op, raw_ostream &os);
2121   static void emitDef(const Operator &op, raw_ostream &os);
2122 
2123 private:
2124   explicit OpOperandAdaptorEmitter(const Operator &op);
2125 
2126   // Add verification function. This generates a verify method for the adaptor
2127   // which verifies all the op-independent attribute constraints.
2128   void addVerification();
2129 
2130   const Operator &op;
2131   Class adaptor;
2132 };
2133 } // end namespace
2134 
OpOperandAdaptorEmitter(const Operator & op)2135 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
2136     : op(op), adaptor(op.getAdaptorName()) {
2137   adaptor.newField("::mlir::ValueRange", "odsOperands");
2138   adaptor.newField("::mlir::DictionaryAttr", "odsAttrs");
2139   const auto *attrSizedOperands =
2140       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
2141   {
2142     SmallVector<OpMethodParameter, 2> paramList;
2143     paramList.emplace_back("::mlir::ValueRange", "values");
2144     paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
2145                            attrSizedOperands ? "" : "nullptr");
2146     auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList));
2147 
2148     constructor->addMemberInitializer("odsOperands", "values");
2149     constructor->addMemberInitializer("odsAttrs", "attrs");
2150   }
2151 
2152   {
2153     auto *constructor = adaptor.addConstructorAndPrune(
2154         llvm::formatv("{0}&", op.getCppClassName()).str(), "op");
2155     constructor->addMemberInitializer("odsOperands", "op->getOperands()");
2156     constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
2157   }
2158 
2159   std::string sizeAttrInit =
2160       formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
2161   generateNamedOperandGetters(op, adaptor, sizeAttrInit,
2162                               /*rangeType=*/"::mlir::ValueRange",
2163                               /*rangeBeginCall=*/"odsOperands.begin()",
2164                               /*rangeSizeCall=*/"odsOperands.size()",
2165                               /*getOperandCallPattern=*/"odsOperands[{0}]");
2166 
2167   FmtContext fctx;
2168   fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
2169 
2170   auto emitAttr = [&](StringRef name, Attribute attr) {
2171     auto &body = adaptor.addMethodAndPrune(attr.getStorageType(), name)->body();
2172     body << "  assert(odsAttrs && \"no attributes when constructing adapter\");"
2173          << "\n  " << attr.getStorageType() << " attr = "
2174          << "odsAttrs.get(\"" << name << "\").";
2175     if (attr.hasDefaultValue() || attr.isOptional())
2176       body << "dyn_cast_or_null<";
2177     else
2178       body << "cast<";
2179     body << attr.getStorageType() << ">();\n";
2180 
2181     if (attr.hasDefaultValue()) {
2182       // Use the default value if attribute is not set.
2183       // TODO: this is inefficient, we are recreating the attribute for every
2184       // call. This should be set instead.
2185       std::string defaultValue = std::string(
2186           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
2187       body << "  if (!attr)\n    attr = " << defaultValue << ";\n";
2188     }
2189     body << "  return attr;\n";
2190   };
2191 
2192   for (auto &namedAttr : op.getAttributes()) {
2193     const auto &name = namedAttr.name;
2194     const auto &attr = namedAttr.attr;
2195     if (!attr.isDerivedAttr())
2196       emitAttr(name, attr);
2197   }
2198 
2199   // Add verification function.
2200   addVerification();
2201 }
2202 
addVerification()2203 void OpOperandAdaptorEmitter::addVerification() {
2204   auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify",
2205                                            "::mlir::Location", "loc");
2206   auto &body = method->body();
2207 
2208   const char *checkAttrSizedValueSegmentsCode = R"(
2209   {
2210     auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
2211     auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
2212     if (numElements != {1})
2213       return emitError(loc, "'{0}' attribute for specifying {2} segments "
2214                        "must have {1} elements");
2215   }
2216   )";
2217 
2218   // Verify a few traits first so that we can use
2219   // getODSOperands()/getODSResults() in the rest of the verifier.
2220   for (auto &trait : op.getTraits()) {
2221     if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
2222       if (t->getTrait() == "::mlir::OpTrait::AttrSizedOperandSegments") {
2223         body << formatv(checkAttrSizedValueSegmentsCode,
2224                         "operand_segment_sizes", op.getNumOperands(),
2225                         "operand");
2226       } else if (t->getTrait() == "::mlir::OpTrait::AttrSizedResultSegments") {
2227         body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
2228                         op.getNumResults(), "result");
2229       }
2230     }
2231   }
2232 
2233   FmtContext verifyCtx;
2234   populateSubstitutions(op, "odsAttrs.get", "getODSOperands",
2235                         "<no results should be genarated>", verifyCtx);
2236   genAttributeVerifier(op, "odsAttrs.get",
2237                        Twine("emitError(loc, \"'") + op.getOperationName() +
2238                            "' op \"",
2239                        /*emitVerificationRequiringOp*/ false, verifyCtx, body);
2240 
2241   body << "  return ::mlir::success();";
2242 }
2243 
emitDecl(const Operator & op,raw_ostream & os)2244 void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
2245   OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os);
2246 }
2247 
emitDef(const Operator & op,raw_ostream & os)2248 void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
2249   OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os);
2250 }
2251 
2252 // Emits the opcode enum and op classes.
emitOpClasses(const std::vector<Record * > & defs,raw_ostream & os,bool emitDecl)2253 static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
2254                           bool emitDecl) {
2255   // First emit forward declaration for each class, this allows them to refer
2256   // to each others in traits for example.
2257   if (emitDecl) {
2258     os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
2259     os << "#undef GET_OP_FWD_DEFINES\n";
2260     for (auto *def : defs) {
2261       Operator op(*def);
2262       NamespaceEmitter emitter(os, op.getDialect());
2263       os << "class " << op.getCppClassName() << ";\n";
2264     }
2265     os << "#endif\n\n";
2266   }
2267 
2268   IfDefScope scope("GET_OP_CLASSES", os);
2269   for (auto *def : defs) {
2270     Operator op(*def);
2271     NamespaceEmitter emitter(os, op.getDialect());
2272     if (emitDecl) {
2273       os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
2274       OpOperandAdaptorEmitter::emitDecl(op, os);
2275       OpEmitter::emitDecl(op, os);
2276     } else {
2277       os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
2278       OpOperandAdaptorEmitter::emitDef(op, os);
2279       OpEmitter::emitDef(op, os);
2280     }
2281   }
2282 }
2283 
2284 // Emits a comma-separated list of the ops.
emitOpList(const std::vector<Record * > & defs,raw_ostream & os)2285 static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
2286   IfDefScope scope("GET_OP_LIST", os);
2287 
2288   interleave(
2289       // TODO: We are constructing the Operator wrapper instance just for
2290       // getting it's qualified class name here. Reduce the overhead by having a
2291       // lightweight version of Operator class just for that purpose.
2292       defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
2293       [&os]() { os << ",\n"; });
2294 }
2295 
getOperationName(const Record & def)2296 static std::string getOperationName(const Record &def) {
2297   auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
2298   auto opName = def.getValueAsString("opName");
2299   if (prefix.empty())
2300     return std::string(opName);
2301   return std::string(llvm::formatv("{0}.{1}", prefix, opName));
2302 }
2303 
2304 static std::vector<Record *>
getAllDerivedDefinitions(const RecordKeeper & recordKeeper,StringRef className)2305 getAllDerivedDefinitions(const RecordKeeper &recordKeeper,
2306                          StringRef className) {
2307   Record *classDef = recordKeeper.getClass(className);
2308   if (!classDef)
2309     PrintFatalError("ERROR: Couldn't find the `" + className + "' class!\n");
2310 
2311   llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
2312   std::vector<Record *> defs;
2313   for (const auto &def : recordKeeper.getDefs()) {
2314     if (!def.second->isSubClassOf(classDef))
2315       continue;
2316     // Include if no include filter or include filter matches.
2317     if (!opIncFilter.empty() &&
2318         !includeRegex.match(getOperationName(*def.second)))
2319       continue;
2320     // Unless there is an exclude filter and it matches.
2321     if (!opExcFilter.empty() &&
2322         excludeRegex.match(getOperationName(*def.second)))
2323       continue;
2324     defs.push_back(def.second.get());
2325   }
2326 
2327   return defs;
2328 }
2329 
emitOpDecls(const RecordKeeper & recordKeeper,raw_ostream & os)2330 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
2331   emitSourceFileHeader("Op Declarations", os);
2332 
2333   const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
2334   emitOpClasses(defs, os, /*emitDecl=*/true);
2335 
2336   return false;
2337 }
2338 
emitOpDefs(const RecordKeeper & recordKeeper,raw_ostream & os)2339 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
2340   emitSourceFileHeader("Op Definitions", os);
2341 
2342   const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
2343   emitOpList(defs, os);
2344   emitOpClasses(defs, os, /*emitDecl=*/false);
2345 
2346   return false;
2347 }
2348 
2349 static mlir::GenRegistration
2350     genOpDecls("gen-op-decls", "Generate op declarations",
__anonec93f51b1402(const RecordKeeper &records, raw_ostream &os) 2351                [](const RecordKeeper &records, raw_ostream &os) {
2352                  return emitOpDecls(records, os);
2353                });
2354 
2355 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
2356                                        [](const RecordKeeper &records,
__anonec93f51b1502(const RecordKeeper &records, raw_ostream &os) 2357                                           raw_ostream &os) {
2358                                          return emitOpDefs(records, os);
2359                                        });
2360