1 //===- OpClass.cpp - Helper classes for Op C++ code emission --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/TableGen/OpClass.h"
10
11 #include "mlir/TableGen/Format.h"
12 #include "llvm/ADT/Sequence.h"
13 #include "llvm/ADT/Twine.h"
14 #include "llvm/Support/Debug.h"
15 #include "llvm/Support/raw_ostream.h"
16 #include <unordered_set>
17
18 #define DEBUG_TYPE "mlir-tblgen-opclass"
19
20 using namespace mlir;
21 using namespace mlir::tblgen;
22
23 namespace {
24
25 // Returns space to be emitted after the given C++ `type`. return "" if the
26 // ends with '&' or '*', or is empty, else returns " ".
getSpaceAfterType(StringRef type)27 StringRef getSpaceAfterType(StringRef type) {
28 return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " ";
29 }
30
31 } // namespace
32
33 //===----------------------------------------------------------------------===//
34 // OpMethodParameter definitions
35 //===----------------------------------------------------------------------===//
36
writeTo(raw_ostream & os,bool emitDefault) const37 void OpMethodParameter::writeTo(raw_ostream &os, bool emitDefault) const {
38 if (properties & PP_Optional)
39 os << "/*optional*/";
40 os << type << getSpaceAfterType(type) << name;
41 if (emitDefault && !defaultValue.empty())
42 os << " = " << defaultValue;
43 }
44
45 //===----------------------------------------------------------------------===//
46 // OpMethodParameters definitions
47 //===----------------------------------------------------------------------===//
48
49 // Factory methods to construct the correct type of `OpMethodParameters`
50 // object based on the arguments.
create()51 std::unique_ptr<OpMethodParameters> OpMethodParameters::create() {
52 return std::make_unique<OpMethodResolvedParameters>();
53 }
54
55 std::unique_ptr<OpMethodParameters>
create(StringRef params)56 OpMethodParameters::create(StringRef params) {
57 return std::make_unique<OpMethodUnresolvedParameters>(params);
58 }
59
60 std::unique_ptr<OpMethodParameters>
create(llvm::SmallVectorImpl<OpMethodParameter> && params)61 OpMethodParameters::create(llvm::SmallVectorImpl<OpMethodParameter> &¶ms) {
62 return std::make_unique<OpMethodResolvedParameters>(std::move(params));
63 }
64
65 std::unique_ptr<OpMethodParameters>
create(StringRef type,StringRef name,StringRef defaultValue)66 OpMethodParameters::create(StringRef type, StringRef name,
67 StringRef defaultValue) {
68 return std::make_unique<OpMethodResolvedParameters>(type, name, defaultValue);
69 }
70
71 //===----------------------------------------------------------------------===//
72 // OpMethodUnresolvedParameters definitions
73 //===----------------------------------------------------------------------===//
writeDeclTo(raw_ostream & os) const74 void OpMethodUnresolvedParameters::writeDeclTo(raw_ostream &os) const {
75 os << parameters;
76 }
77
writeDefTo(raw_ostream & os) const78 void OpMethodUnresolvedParameters::writeDefTo(raw_ostream &os) const {
79 // We need to remove the default values for parameters in method definition.
80 // TODO: We are using '=' and ',' as delimiters for parameter
81 // initializers. This is incorrect for initializer list with more than one
82 // element. Change to a more robust approach.
83 llvm::SmallVector<StringRef, 4> tokens;
84 StringRef params = parameters;
85 while (!params.empty()) {
86 std::pair<StringRef, StringRef> parts = params.split("=");
87 tokens.push_back(parts.first);
88 params = parts.second.split(',').second;
89 }
90 llvm::interleaveComma(tokens, os, [&](StringRef token) { os << token; });
91 }
92
93 //===----------------------------------------------------------------------===//
94 // OpMethodResolvedParameters definitions
95 //===----------------------------------------------------------------------===//
96
97 // Returns true if a method with these parameters makes a method with parameters
98 // `other` redundant. This should return true only if all possible calls to the
99 // other method can be replaced by calls to this method.
makesRedundant(const OpMethodResolvedParameters & other) const100 bool OpMethodResolvedParameters::makesRedundant(
101 const OpMethodResolvedParameters &other) const {
102 const size_t otherNumParams = other.getNumParameters();
103 const size_t thisNumParams = getNumParameters();
104
105 // All calls to the other method can be replaced this method only if this
106 // method has the same or more arguments number of arguments as the other, and
107 // the common arguments have the same type.
108 if (thisNumParams < otherNumParams)
109 return false;
110 for (int idx : llvm::seq<int>(0, otherNumParams))
111 if (parameters[idx].getType() != other.parameters[idx].getType())
112 return false;
113
114 // If all the common arguments have the same type, we can elide the other
115 // method if this method has the same number of arguments as other or the
116 // first argument after the common ones has a default value (and by C++
117 // requirement, all the later ones will also have a default value).
118 return thisNumParams == otherNumParams ||
119 parameters[otherNumParams].hasDefaultValue();
120 }
121
writeDeclTo(raw_ostream & os) const122 void OpMethodResolvedParameters::writeDeclTo(raw_ostream &os) const {
123 llvm::interleaveComma(parameters, os, [&](const OpMethodParameter ¶m) {
124 param.writeDeclTo(os);
125 });
126 }
127
writeDefTo(raw_ostream & os) const128 void OpMethodResolvedParameters::writeDefTo(raw_ostream &os) const {
129 llvm::interleaveComma(parameters, os, [&](const OpMethodParameter ¶m) {
130 param.writeDefTo(os);
131 });
132 }
133
134 //===----------------------------------------------------------------------===//
135 // OpMethodSignature definitions
136 //===----------------------------------------------------------------------===//
137
138 // Returns if a method with this signature makes a method with `other` signature
139 // redundant. Only supports resolved parameters.
makesRedundant(const OpMethodSignature & other) const140 bool OpMethodSignature::makesRedundant(const OpMethodSignature &other) const {
141 if (methodName != other.methodName)
142 return false;
143 auto *resolvedThis = dyn_cast<OpMethodResolvedParameters>(parameters.get());
144 auto *resolvedOther =
145 dyn_cast<OpMethodResolvedParameters>(other.parameters.get());
146 if (resolvedThis && resolvedOther)
147 return resolvedThis->makesRedundant(*resolvedOther);
148 return false;
149 }
150
writeDeclTo(raw_ostream & os) const151 void OpMethodSignature::writeDeclTo(raw_ostream &os) const {
152 os << returnType << getSpaceAfterType(returnType) << methodName << "(";
153 parameters->writeDeclTo(os);
154 os << ")";
155 }
156
writeDefTo(raw_ostream & os,StringRef namePrefix) const157 void OpMethodSignature::writeDefTo(raw_ostream &os,
158 StringRef namePrefix) const {
159 os << returnType << getSpaceAfterType(returnType) << namePrefix
160 << (namePrefix.empty() ? "" : "::") << methodName << "(";
161 parameters->writeDefTo(os);
162 os << ")";
163 }
164
165 //===----------------------------------------------------------------------===//
166 // OpMethodBody definitions
167 //===----------------------------------------------------------------------===//
168
OpMethodBody(bool declOnly)169 OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {}
170
operator <<(Twine content)171 OpMethodBody &OpMethodBody::operator<<(Twine content) {
172 if (isEffective)
173 body.append(content.str());
174 return *this;
175 }
176
operator <<(int content)177 OpMethodBody &OpMethodBody::operator<<(int content) {
178 if (isEffective)
179 body.append(std::to_string(content));
180 return *this;
181 }
182
operator <<(const FmtObjectBase & content)183 OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
184 if (isEffective)
185 body.append(content.str());
186 return *this;
187 }
188
writeTo(raw_ostream & os) const189 void OpMethodBody::writeTo(raw_ostream &os) const {
190 auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
191 os << bodyRef;
192 if (bodyRef.empty() || bodyRef.back() != '\n')
193 os << "\n";
194 }
195
196 //===----------------------------------------------------------------------===//
197 // OpMethod definitions
198 //===----------------------------------------------------------------------===//
199
writeDeclTo(raw_ostream & os) const200 void OpMethod::writeDeclTo(raw_ostream &os) const {
201 os.indent(2);
202 if (isStatic())
203 os << "static ";
204 methodSignature.writeDeclTo(os);
205 os << ";";
206 }
207
writeDefTo(raw_ostream & os,StringRef namePrefix) const208 void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
209 // Do not write definition if the method is decl only.
210 if (properties & MP_Declaration)
211 return;
212 methodSignature.writeDefTo(os, namePrefix);
213 os << " {\n";
214 methodBody.writeTo(os);
215 os << "}";
216 }
217
218 //===----------------------------------------------------------------------===//
219 // OpConstructor definitions
220 //===----------------------------------------------------------------------===//
221
addMemberInitializer(StringRef name,StringRef value)222 void OpConstructor::addMemberInitializer(StringRef name, StringRef value) {
223 memberInitializers.append(std::string(llvm::formatv(
224 "{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value)));
225 }
226
writeDefTo(raw_ostream & os,StringRef namePrefix) const227 void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const {
228 // Do not write definition if the method is decl only.
229 if (properties & MP_Declaration)
230 return;
231
232 methodSignature.writeDefTo(os, namePrefix);
233 os << " " << memberInitializers << " {\n";
234 methodBody.writeTo(os);
235 os << "}";
236 }
237
238 //===----------------------------------------------------------------------===//
239 // Class definitions
240 //===----------------------------------------------------------------------===//
241
Class(StringRef name)242 Class::Class(StringRef name) : className(name) {}
243
newField(StringRef type,StringRef name,StringRef defaultValue)244 void Class::newField(StringRef type, StringRef name, StringRef defaultValue) {
245 std::string varName = formatv("{0} {1}", type, name).str();
246 std::string field = defaultValue.empty()
247 ? varName
248 : formatv("{0} = {1}", varName, defaultValue).str();
249 fields.push_back(std::move(field));
250 }
writeDeclTo(raw_ostream & os) const251 void Class::writeDeclTo(raw_ostream &os) const {
252 bool hasPrivateMethod = false;
253 os << "class " << className << " {\n";
254 os << "public:\n";
255
256 forAllMethods([&](const OpMethod &method) {
257 if (!method.isPrivate()) {
258 method.writeDeclTo(os);
259 os << '\n';
260 } else {
261 hasPrivateMethod = true;
262 }
263 });
264
265 os << '\n';
266 os << "private:\n";
267 if (hasPrivateMethod) {
268 forAllMethods([&](const OpMethod &method) {
269 if (method.isPrivate()) {
270 method.writeDeclTo(os);
271 os << '\n';
272 }
273 });
274 os << '\n';
275 }
276
277 for (const auto &field : fields)
278 os.indent(2) << field << ";\n";
279 os << "};\n";
280 }
281
writeDefTo(raw_ostream & os) const282 void Class::writeDefTo(raw_ostream &os) const {
283 forAllMethods([&](const OpMethod &method) {
284 method.writeDefTo(os, className);
285 os << "\n\n";
286 });
287 }
288
289 //===----------------------------------------------------------------------===//
290 // OpClass definitions
291 //===----------------------------------------------------------------------===//
292
OpClass(StringRef name,StringRef extraClassDeclaration)293 OpClass::OpClass(StringRef name, StringRef extraClassDeclaration)
294 : Class(name), extraClassDeclaration(extraClassDeclaration) {}
295
addTrait(Twine trait)296 void OpClass::addTrait(Twine trait) {
297 auto traitStr = trait.str();
298 if (traitsSet.insert(traitStr).second)
299 traitsVec.push_back(std::move(traitStr));
300 }
301
writeDeclTo(raw_ostream & os) const302 void OpClass::writeDeclTo(raw_ostream &os) const {
303 os << "class " << className << " : public ::mlir::Op<" << className;
304 for (const auto &trait : traitsVec)
305 os << ", " << trait;
306 os << "> {\npublic:\n"
307 << " using Op::Op;\n"
308 << " using Op::print;\n"
309 << " using Adaptor = " << className << "Adaptor;\n";
310
311 bool hasPrivateMethod = false;
312 forAllMethods([&](const OpMethod &method) {
313 if (!method.isPrivate()) {
314 method.writeDeclTo(os);
315 os << "\n";
316 } else {
317 hasPrivateMethod = true;
318 }
319 });
320
321 // TODO: Add line control markers to make errors easier to debug.
322 if (!extraClassDeclaration.empty())
323 os << extraClassDeclaration << "\n";
324
325 if (hasPrivateMethod) {
326 os << "\nprivate:\n";
327 forAllMethods([&](const OpMethod &method) {
328 if (method.isPrivate()) {
329 method.writeDeclTo(os);
330 os << "\n";
331 }
332 });
333 }
334
335 os << "};\n";
336 }
337