1 //===- OpClass.h - Helper classes for Op C++ code emission ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines several classes for Op C++ code emission. They are only 10 // expected to be used by MLIR TableGen backends. 11 // 12 // We emit the op declaration and definition into separate files: *Ops.h.inc 13 // and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and 14 // the latter for dialect *Ops.cpp. This way provides a cleaner interface. 15 // 16 // In order to do this split, we need to track method signature and 17 // implementation logic separately. Signature information is used for both 18 // declaration and definition, while implementation logic is only for 19 // definition. So we have the following classes for C++ code emission. 20 // 21 //===----------------------------------------------------------------------===// 22 23 #ifndef MLIR_TABLEGEN_OPCLASS_H_ 24 #define MLIR_TABLEGEN_OPCLASS_H_ 25 26 #include "mlir/Support/LLVM.h" 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/ADT/SmallVector.h" 29 #include "llvm/ADT/StringRef.h" 30 #include "llvm/ADT/StringSet.h" 31 #include "llvm/Support/raw_ostream.h" 32 33 #include <set> 34 #include <string> 35 36 namespace mlir { 37 namespace tblgen { 38 class FmtObjectBase; 39 40 // Class for holding a single parameter of an op's method for C++ code emission. 41 class OpMethodParameter { 42 public: 43 // Properties (qualifiers) for the parameter. 44 enum Property { 45 PP_None = 0x0, 46 PP_Optional = 0x1, 47 }; 48 49 OpMethodParameter(StringRef type, StringRef name, StringRef defaultValue = "", 50 Property properties = PP_None) type(type)51 : type(type), name(name), defaultValue(defaultValue), 52 properties(properties) {} 53 OpMethodParameter(StringRef type,StringRef name,Property property)54 OpMethodParameter(StringRef type, StringRef name, Property property) 55 : OpMethodParameter(type, name, "", property) {} 56 57 // Writes the parameter as a part of a method declaration to `os`. writeDeclTo(raw_ostream & os)58 void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); } 59 60 // Writes the parameter as a part of a method definition to `os` writeDefTo(raw_ostream & os)61 void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); } 62 getType()63 const std::string &getType() const { return type; } hasDefaultValue()64 bool hasDefaultValue() const { return !defaultValue.empty(); } 65 66 private: 67 void writeTo(raw_ostream &os, bool emitDefault) const; 68 69 std::string type; 70 std::string name; 71 std::string defaultValue; 72 Property properties; 73 }; 74 75 // Base class for holding parameters of an op's method for C++ code emission. 76 class OpMethodParameters { 77 public: 78 // Discriminator for LLVM-style RTTI. 79 enum ParamsKind { 80 // Separate type and name for each parameter is not known. 81 PK_Unresolved, 82 // Each parameter is resolved to a type and name. 83 PK_Resolved, 84 }; 85 OpMethodParameters(ParamsKind kind)86 OpMethodParameters(ParamsKind kind) : kind(kind) {} ~OpMethodParameters()87 virtual ~OpMethodParameters() {} 88 89 // LLVM-style RTTI support. getKind()90 ParamsKind getKind() const { return kind; } 91 92 // Writes the parameters as a part of a method declaration to `os`. 93 virtual void writeDeclTo(raw_ostream &os) const = 0; 94 95 // Writes the parameters as a part of a method definition to `os` 96 virtual void writeDefTo(raw_ostream &os) const = 0; 97 98 // Factory methods to create the correct type of `OpMethodParameters` 99 // object based on the arguments. 100 static std::unique_ptr<OpMethodParameters> create(); 101 102 static std::unique_ptr<OpMethodParameters> create(StringRef params); 103 104 static std::unique_ptr<OpMethodParameters> 105 create(llvm::SmallVectorImpl<OpMethodParameter> &¶ms); 106 107 static std::unique_ptr<OpMethodParameters> 108 create(StringRef type, StringRef name, StringRef defaultValue = ""); 109 110 private: 111 const ParamsKind kind; 112 }; 113 114 // Class for holding unresolved parameters. 115 class OpMethodUnresolvedParameters : public OpMethodParameters { 116 public: OpMethodUnresolvedParameters(StringRef params)117 OpMethodUnresolvedParameters(StringRef params) 118 : OpMethodParameters(PK_Unresolved), parameters(params) {} 119 120 // write the parameters as a part of a method declaration to the given `os`. 121 void writeDeclTo(raw_ostream &os) const override; 122 123 // write the parameters as a part of a method definition to the given `os` 124 void writeDefTo(raw_ostream &os) const override; 125 126 // LLVM-style RTTI support. classof(const OpMethodParameters * params)127 static bool classof(const OpMethodParameters *params) { 128 return params->getKind() == PK_Unresolved; 129 } 130 131 private: 132 std::string parameters; 133 }; 134 135 // Class for holding resolved parameters. 136 class OpMethodResolvedParameters : public OpMethodParameters { 137 public: OpMethodResolvedParameters()138 OpMethodResolvedParameters() : OpMethodParameters(PK_Resolved) {} 139 OpMethodResolvedParameters(llvm::SmallVectorImpl<OpMethodParameter> && params)140 OpMethodResolvedParameters(llvm::SmallVectorImpl<OpMethodParameter> &¶ms) 141 : OpMethodParameters(PK_Resolved) { 142 for (OpMethodParameter ¶m : params) 143 parameters.emplace_back(std::move(param)); 144 } 145 OpMethodResolvedParameters(StringRef type,StringRef name,StringRef defaultValue)146 OpMethodResolvedParameters(StringRef type, StringRef name, 147 StringRef defaultValue) 148 : OpMethodParameters(PK_Resolved) { 149 parameters.emplace_back(type, name, defaultValue); 150 } 151 152 // Returns the number of parameters. getNumParameters()153 size_t getNumParameters() const { return parameters.size(); } 154 155 // Returns if this method makes the `other` method redundant. Note that this 156 // is more than just finding conflicting methods. This method determines if 157 // the 2 set of parameters are conflicting and if so, returns true if this 158 // method has a more general set of parameters that can replace all possible 159 // calls to the `other` method. 160 bool makesRedundant(const OpMethodResolvedParameters &other) const; 161 162 // write the parameters as a part of a method declaration to the given `os`. 163 void writeDeclTo(raw_ostream &os) const override; 164 165 // write the parameters as a part of a method definition to the given `os` 166 void writeDefTo(raw_ostream &os) const override; 167 168 // LLVM-style RTTI support. classof(const OpMethodParameters * params)169 static bool classof(const OpMethodParameters *params) { 170 return params->getKind() == PK_Resolved; 171 } 172 173 private: 174 llvm::SmallVector<OpMethodParameter, 4> parameters; 175 }; 176 177 // Class for holding the signature of an op's method for C++ code emission 178 class OpMethodSignature { 179 public: 180 template <typename... Args> OpMethodSignature(StringRef retType,StringRef name,Args &&...args)181 OpMethodSignature(StringRef retType, StringRef name, Args &&...args) 182 : returnType(retType), methodName(name), 183 parameters(OpMethodParameters::create(std::forward<Args>(args)...)) {} 184 OpMethodSignature(OpMethodSignature &&) = default; 185 186 // Returns if a method with this signature makes a method with `other` 187 // signature redundant. Only supports resolved parameters. 188 bool makesRedundant(const OpMethodSignature &other) const; 189 190 // Returns the number of parameters (for resolved parameters). getNumParameters()191 size_t getNumParameters() const { 192 return cast<OpMethodResolvedParameters>(parameters.get()) 193 ->getNumParameters(); 194 } 195 196 // Returns the name of the method. getName()197 StringRef getName() const { return methodName; } 198 199 // Writes the signature as a method declaration to the given `os`. 200 void writeDeclTo(raw_ostream &os) const; 201 202 // Writes the signature as the start of a method definition to the given `os`. 203 // `namePrefix` is the prefix to be prepended to the method name (typically 204 // namespaces for qualifying the method definition). 205 void writeDefTo(raw_ostream &os, StringRef namePrefix) const; 206 207 private: 208 std::string returnType; 209 std::string methodName; 210 std::unique_ptr<OpMethodParameters> parameters; 211 }; 212 213 // Class for holding the body of an op's method for C++ code emission 214 class OpMethodBody { 215 public: 216 explicit OpMethodBody(bool declOnly); 217 218 OpMethodBody &operator<<(Twine content); 219 OpMethodBody &operator<<(int content); 220 OpMethodBody &operator<<(const FmtObjectBase &content); 221 222 void writeTo(raw_ostream &os) const; 223 224 private: 225 // Whether this class should record method body. 226 bool isEffective; 227 std::string body; 228 }; 229 230 // Class for holding an op's method for C++ code emission 231 class OpMethod { 232 public: 233 // Properties (qualifiers) of class methods. Bitfield is used here to help 234 // querying properties. 235 enum Property { 236 MP_None = 0x0, 237 MP_Static = 0x1, 238 MP_Constructor = 0x2, 239 MP_Private = 0x4, 240 MP_Declaration = 0x8, 241 MP_StaticDeclaration = MP_Static | MP_Declaration, 242 }; 243 244 template <typename... Args> OpMethod(StringRef retType,StringRef name,Property property,unsigned id,Args &&...args)245 OpMethod(StringRef retType, StringRef name, Property property, unsigned id, 246 Args &&...args) 247 : properties(property), 248 methodSignature(retType, name, std::forward<Args>(args)...), 249 methodBody(properties & MP_Declaration), id(id) {} 250 251 OpMethod(OpMethod &&) = default; 252 253 virtual ~OpMethod() = default; 254 body()255 OpMethodBody &body() { return methodBody; } 256 257 // Returns true if this is a static method. isStatic()258 bool isStatic() const { return properties & MP_Static; } 259 260 // Returns true if this is a private method. isPrivate()261 bool isPrivate() const { return properties & MP_Private; } 262 263 // Returns the name of this method. getName()264 StringRef getName() const { return methodSignature.getName(); } 265 266 // Returns the ID for this method getID()267 unsigned getID() const { return id; } 268 269 // Returns if this method makes the `other` method redundant. makesRedundant(const OpMethod & other)270 bool makesRedundant(const OpMethod &other) const { 271 return methodSignature.makesRedundant(other.methodSignature); 272 } 273 274 // Writes the method as a declaration to the given `os`. 275 virtual void writeDeclTo(raw_ostream &os) const; 276 277 // Writes the method as a definition to the given `os`. `namePrefix` is the 278 // prefix to be prepended to the method name (typically namespaces for 279 // qualifying the method definition). 280 virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const; 281 282 protected: 283 Property properties; 284 OpMethodSignature methodSignature; 285 OpMethodBody methodBody; 286 const unsigned id; 287 }; 288 289 // Class for holding an op's constructor method for C++ code emission. 290 class OpConstructor : public OpMethod { 291 public: 292 template <typename... Args> OpConstructor(StringRef className,Property property,unsigned id,Args &&...args)293 OpConstructor(StringRef className, Property property, unsigned id, 294 Args &&...args) 295 : OpMethod("", className, property, id, std::forward<Args>(args)...) {} 296 297 // Add member initializer to constructor initializing `name` with `value`. 298 void addMemberInitializer(StringRef name, StringRef value); 299 300 // Writes the method as a definition to the given `os`. `namePrefix` is the 301 // prefix to be prepended to the method name (typically namespaces for 302 // qualifying the method definition). 303 void writeDefTo(raw_ostream &os, StringRef namePrefix) const override; 304 305 private: 306 // Member initializers. 307 std::string memberInitializers; 308 }; 309 310 // A class used to emit C++ classes from Tablegen. Contains a list of public 311 // methods and a list of private fields to be emitted. 312 class Class { 313 public: 314 explicit Class(StringRef name); 315 316 // Adds a new method to this class and prune redundant methods. Returns null 317 // if the method was not added (because an existing method would make it 318 // redundant), else returns a pointer to the added method. Note that this call 319 // may also delete existing methods that are made redundant by a method to the 320 // class. 321 template <typename... Args> addMethodAndPrune(StringRef retType,StringRef name,OpMethod::Property properties,Args &&...args)322 OpMethod *addMethodAndPrune(StringRef retType, StringRef name, 323 OpMethod::Property properties, Args &&...args) { 324 auto newMethod = std::make_unique<OpMethod>( 325 retType, name, properties, nextMethodID++, std::forward<Args>(args)...); 326 return addMethodAndPrune(methods, std::move(newMethod)); 327 } 328 329 template <typename... Args> addMethodAndPrune(StringRef retType,StringRef name,Args &&...args)330 OpMethod *addMethodAndPrune(StringRef retType, StringRef name, 331 Args &&...args) { 332 return addMethodAndPrune(retType, name, OpMethod::MP_None, 333 std::forward<Args>(args)...); 334 } 335 336 template <typename... Args> addConstructorAndPrune(Args &&...args)337 OpConstructor *addConstructorAndPrune(Args &&...args) { 338 auto newConstructor = std::make_unique<OpConstructor>( 339 getClassName(), OpMethod::MP_Constructor, nextMethodID++, 340 std::forward<Args>(args)...); 341 return addMethodAndPrune(constructors, std::move(newConstructor)); 342 } 343 344 // Creates a new field in this class. 345 void newField(StringRef type, StringRef name, StringRef defaultValue = ""); 346 347 // Writes this op's class as a declaration to the given `os`. 348 void writeDeclTo(raw_ostream &os) const; 349 // Writes the method definitions in this op's class to the given `os`. 350 void writeDefTo(raw_ostream &os) const; 351 352 // Returns the C++ class name of the op. getClassName()353 StringRef getClassName() const { return className; } 354 355 protected: 356 // Get a list of all the methods to emit, filtering out hidden ones. forAllMethods(llvm::function_ref<void (const OpMethod &)> func)357 void forAllMethods(llvm::function_ref<void(const OpMethod &)> func) const { 358 using ConsRef = const std::unique_ptr<OpConstructor> &; 359 using MethodRef = const std::unique_ptr<OpMethod> &; 360 llvm::for_each(constructors, [&](ConsRef ptr) { func(*ptr); }); 361 llvm::for_each(methods, [&](MethodRef ptr) { func(*ptr); }); 362 } 363 364 // For deterministic code generation, keep methods sorted in the order in 365 // which they were generated. 366 template <typename MethodTy> 367 struct MethodCompare { operatorMethodCompare368 bool operator()(const std::unique_ptr<MethodTy> &x, 369 const std::unique_ptr<MethodTy> &y) const { 370 return x->getID() < y->getID(); 371 } 372 }; 373 374 template <typename MethodTy> 375 using MethodSet = 376 std::set<std::unique_ptr<MethodTy>, MethodCompare<MethodTy>>; 377 378 template <typename MethodTy> addMethodAndPrune(MethodSet<MethodTy> & set,std::unique_ptr<MethodTy> && newMethod)379 MethodTy *addMethodAndPrune(MethodSet<MethodTy> &set, 380 std::unique_ptr<MethodTy> &&newMethod) { 381 // Check if the new method will be made redundant by existing methods. 382 for (auto &method : set) 383 if (method->makesRedundant(*newMethod)) 384 return nullptr; 385 386 // We can add this a method to the set. Prune any existing methods that will 387 // be made redundant by adding this new method. Note that the redundant 388 // check between two methods is more than a conflict check. makesRedundant() 389 // below will check if the new method conflicts with an existing method and 390 // if so, returns true if the new method makes the existing method redundant 391 // because all calls to the existing method can be subsumed by the new 392 // method. So makesRedundant() does a combined job of finding conflicts and 393 // deciding which of the 2 conflicting methods survive. 394 // 395 // Note: llvm::erase_if does not work with sets of std::unique_ptr, so doing 396 // it manually here. 397 for (auto it = set.begin(), end = set.end(); it != end;) { 398 if (newMethod->makesRedundant(*(it->get()))) 399 it = set.erase(it); 400 else 401 ++it; 402 } 403 404 MethodTy *ret = newMethod.get(); 405 set.insert(std::move(newMethod)); 406 return ret; 407 } 408 409 std::string className; 410 MethodSet<OpConstructor> constructors; 411 MethodSet<OpMethod> methods; 412 unsigned nextMethodID = 0; 413 SmallVector<std::string, 4> fields; 414 }; 415 416 // Class for holding an op for C++ code emission 417 class OpClass : public Class { 418 public: 419 explicit OpClass(StringRef name, StringRef extraClassDeclaration = ""); 420 421 // Adds an op trait. 422 void addTrait(Twine trait); 423 424 // Writes this op's class as a declaration to the given `os`. Redefines 425 // Class::writeDeclTo to also emit traits and extra class declarations. 426 void writeDeclTo(raw_ostream &os) const; 427 428 private: 429 StringRef extraClassDeclaration; 430 SmallVector<std::string, 4> traitsVec; 431 StringSet<> traitsSet; 432 }; 433 434 } // namespace tblgen 435 } // namespace mlir 436 437 #endif // MLIR_TABLEGEN_OPCLASS_H_ 438