• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- Pass.cpp - MLIR pass registration generator ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // PassGen uses the description of passes to generate base classes for passes
10 // and command line registration.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Pass.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
21 
22 using namespace mlir;
23 using namespace mlir::tblgen;
24 
25 static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");
26 static llvm::cl::opt<std::string>
27     groupName("name", llvm::cl::desc("The name of this group of passes"),
28               llvm::cl::cat(passGenCat));
29 
30 //===----------------------------------------------------------------------===//
31 // GEN: Pass base class generation
32 //===----------------------------------------------------------------------===//
33 
/// The code snippet used to generate the start of a pass base class.
///
/// This is an `llvm::formatv` format string: every literal `{` that must
/// appear in the generated C++ is written as `{{`, so that only `{N}`
/// sequences are treated as replacements.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
/// {3}: The dependent dialects registration.
const char *const passDeclBegin = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//

template <typename DerivedT>
class {0}Base : public {1} {{
public:
  {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
  {0}Base(const {0}Base &) : {1}(::mlir::TypeID::get<DerivedT>()) {{}

  /// Returns the command-line argument attached to this pass.
  ::llvm::StringRef getArgument() const override {{ return "{2}"; }

  /// Returns the derived pass name.
  ::llvm::StringRef getName() const override {{ return "{0}"; }

  /// Support isa/dyn_cast functionality for the derived pass class.
  static bool classof(const ::mlir::Pass *pass) {{
    return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
  }

  /// A clone method to create a copy of this pass.
  std::unique_ptr<::mlir::Pass> clonePass() const override {{
    return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
  }

  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {{
    {3}
  }

protected:
)";
74 
/// Format string producing one `registry.insert<...>()` statement. A copy is
/// rendered for every dependent dialect of a pass, and the concatenation is
/// substituted into the generated `getDependentDialects` body.
///
/// {0}: The dialect type passed as the `insert` template argument.
const char *const dialectRegistrationTemplate =
    "\n"
    "  registry.insert<{0}>();\n";
80 
81 /// Emit the declarations for each of the pass options.
emitPassOptionDecls(const Pass & pass,raw_ostream & os)82 static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
83   for (const PassOption &opt : pass.getOptions()) {
84     os.indent(2) << "::mlir::Pass::"
85                  << (opt.isListOption() ? "ListOption" : "Option");
86 
87     os << llvm::formatv("<{0}> {1}{{*this, \"{2}\", ::llvm::cl::desc(\"{3}\")",
88                         opt.getType(), opt.getCppVariableName(),
89                         opt.getArgument(), opt.getDescription());
90     if (Optional<StringRef> defaultVal = opt.getDefaultValue())
91       os << ", ::llvm::cl::init(" << defaultVal << ")";
92     if (Optional<StringRef> additionalFlags = opt.getAdditionalFlags())
93       os << ", " << *additionalFlags;
94     os << "};\n";
95   }
96 }
97 
98 /// Emit the declarations for each of the pass statistics.
emitPassStatisticDecls(const Pass & pass,raw_ostream & os)99 static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
100   for (const PassStatistic &stat : pass.getStatistics()) {
101     os << llvm::formatv(
102         "  ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
103         stat.getCppVariableName(), stat.getName(), stat.getDescription());
104   }
105 }
106 
emitPassDecl(const Pass & pass,raw_ostream & os)107 static void emitPassDecl(const Pass &pass, raw_ostream &os) {
108   StringRef defName = pass.getDef()->getName();
109   std::string dependentDialectRegistrations;
110   {
111     llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
112     for (StringRef dependentDialect : pass.getDependentDialects())
113       dialectsOs << llvm::formatv(dialectRegistrationTemplate,
114                                   dependentDialect);
115   }
116   os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
117                       pass.getArgument(), dependentDialectRegistrations);
118   emitPassOptionDecls(pass, os);
119   emitPassStatisticDecls(pass, os);
120   os << "};\n";
121 }
122 
123 /// Emit the code for registering each of the given passes with the global
124 /// PassRegistry.
emitPassDecls(ArrayRef<Pass> passes,raw_ostream & os)125 static void emitPassDecls(ArrayRef<Pass> passes, raw_ostream &os) {
126   os << "#ifdef GEN_PASS_CLASSES\n";
127   for (const Pass &pass : passes)
128     emitPassDecl(pass, os);
129   os << "#undef GEN_PASS_CLASSES\n";
130   os << "#endif // GEN_PASS_CLASSES\n";
131 }
132 
133 //===----------------------------------------------------------------------===//
134 // GEN: Pass registration generation
135 //===----------------------------------------------------------------------===//
136 
/// The code snippet used to generate the `register<Pass>Pass()` function that
/// registers a single pass with the global pass registry.
///
/// {0}: The def name of the pass record.
/// {1}: The argument of the pass.
/// {2}: The summary of the pass.
/// {3}: The code for constructing the pass.
const char *const passRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}Pass() {{
  ::mlir::registerPass("{1}", "{2}", []() -> std::unique_ptr<::mlir::Pass> {{
    return {3};
  });
}
)";
154 
/// The code snippet used to generate the start of the `register<Group>Passes()`
/// function that registers every pass of the group. The snippet leaves the
/// function body open; the caller emits the per-pass registration calls and
/// the closing brace.
///
/// {0}: The name of the pass group.
const char *const passGroupRegistrationCode = R"(
//===----------------------------------------------------------------------===//
// {0} Registration
//===----------------------------------------------------------------------===//

inline void register{0}Passes() {{
)";
163 
164 /// Emit the code for registering each of the given passes with the global
165 /// PassRegistry.
emitRegistration(ArrayRef<Pass> passes,raw_ostream & os)166 static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) {
167   os << "#ifdef GEN_PASS_REGISTRATION\n";
168   for (const Pass &pass : passes) {
169     os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
170                         pass.getArgument(), pass.getSummary(),
171                         pass.getConstructor());
172   }
173 
174   os << llvm::formatv(passGroupRegistrationCode, groupName);
175   for (const Pass &pass : passes)
176     os << "  register" << pass.getDef()->getName() << "Pass();\n";
177   os << "}\n";
178   os << "#undef GEN_PASS_REGISTRATION\n";
179   os << "#endif // GEN_PASS_REGISTRATION\n";
180 }
181 
182 //===----------------------------------------------------------------------===//
183 // GEN: Registration hooks
184 //===----------------------------------------------------------------------===//
185 
emitDecls(const llvm::RecordKeeper & recordKeeper,raw_ostream & os)186 static void emitDecls(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
187   os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
188   std::vector<Pass> passes;
189   for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase"))
190     passes.push_back(Pass(def));
191 
192   emitPassDecls(passes, os);
193   emitRegistration(passes, os);
194 }
195 
196 static mlir::GenRegistration
197     genRegister("gen-pass-decls", "Generate operation documentation",
__anon10694ef70102(const llvm::RecordKeeper &records, raw_ostream &os) 198                 [](const llvm::RecordKeeper &records, raw_ostream &os) {
199                   emitDecls(records, os);
200                   return false;
201                 });
202