1 //===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Pattern wrapper class to simplify using TableGen Record defining a MLIR 10 // Pattern. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_TABLEGEN_PATTERN_H_ 15 #define MLIR_TABLEGEN_PATTERN_H_ 16 17 #include "mlir/Support/LLVM.h" 18 #include "mlir/TableGen/Argument.h" 19 #include "mlir/TableGen/Operator.h" 20 #include "llvm/ADT/DenseMap.h" 21 #include "llvm/ADT/StringMap.h" 22 #include "llvm/ADT/StringSet.h" 23 24 #include <unordered_map> 25 26 namespace llvm { 27 class DagInit; 28 class Init; 29 class Record; 30 } // end namespace llvm 31 32 namespace mlir { 33 namespace tblgen { 34 35 // Mapping from TableGen Record to Operator wrapper object. 36 // 37 // We allocate each wrapper object in heap to make sure the pointer to it is 38 // valid throughout the lifetime of this map. This is important because this map 39 // is shared among multiple patterns to avoid creating the wrapper object for 40 // the same op again and again. But this map will continuously grow. 41 using RecordOperatorMap = 42 DenseMap<const llvm::Record *, std::unique_ptr<Operator>>; 43 44 class Pattern; 45 46 // Wrapper class providing helper methods for accessing TableGen DAG leaves 47 // used inside Patterns. This class is lightweight and designed to be used like 48 // values. 49 // 50 // A TableGen DAG construct is of the syntax 51 // `(operator, arg0, arg1, ...)`. 52 // 53 // This class provides getters to retrieve `arg*` as tblgen:: wrapper objects 54 // for handy helper methods. It only works on `arg*`s that are not nested DAG 55 // constructs. 56 class DagLeaf { 57 public: DagLeaf(const llvm::Init * def)58 explicit DagLeaf(const llvm::Init *def) : def(def) {} 59 60 // Returns true if this DAG leaf is not specified in the pattern. That is, it 61 // places no further constraints/transforms and just carries over the original 62 // value. 63 bool isUnspecified() const; 64 65 // Returns true if this DAG leaf is matching an operand. That is, it specifies 66 // a type constraint. 67 bool isOperandMatcher() const; 68 69 // Returns true if this DAG leaf is matching an attribute. That is, it 70 // specifies an attribute constraint. 71 bool isAttrMatcher() const; 72 73 // Returns true if this DAG leaf is wrapping native code call. 74 bool isNativeCodeCall() const; 75 76 // Returns true if this DAG leaf is specifying a constant attribute. 77 bool isConstantAttr() const; 78 79 // Returns true if this DAG leaf is specifying an enum attribute case. 80 bool isEnumAttrCase() const; 81 82 // Returns true if this DAG leaf is specifying a string attribute. 83 bool isStringAttr() const; 84 85 // Returns this DAG leaf as a constraint. Asserts if fails. 86 Constraint getAsConstraint() const; 87 88 // Returns this DAG leaf as an constant attribute. Asserts if fails. 89 ConstantAttr getAsConstantAttr() const; 90 91 // Returns this DAG leaf as an enum attribute case. 92 // Precondition: isEnumAttrCase() 93 EnumAttrCase getAsEnumAttrCase() const; 94 95 // Returns the matching condition template inside this DAG leaf. Assumes the 96 // leaf is an operand/attribute matcher and asserts otherwise. 97 std::string getConditionTemplate() const; 98 99 // Returns the native code call template inside this DAG leaf. 100 // Precondition: isNativeCodeCall() 101 StringRef getNativeCodeTemplate() const; 102 103 // Returns the string associated with the leaf. 104 // Precondition: isStringAttr() 105 std::string getStringAttr() const; 106 107 void print(raw_ostream &os) const; 108 109 private: 110 // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and 111 // also a subclass of the given `superclass`. 112 bool isSubClassOf(StringRef superclass) const; 113 114 const llvm::Init *def; 115 }; 116 117 // Wrapper class providing helper methods for accessing TableGen DAG constructs 118 // used inside Patterns. This class is lightweight and designed to be used like 119 // values. 120 // 121 // A TableGen DAG construct is of the syntax 122 // `(operator, arg0, arg1, ...)`. 123 // 124 // When used inside Patterns, `operator` corresponds to some dialect op, or 125 // a known list of verbs that defines special transformation actions. This 126 // `arg*` can be a nested DAG construct. This class provides getters to 127 // retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper 128 // methods. 129 // 130 // A null DagNode contains a nullptr and converts to false implicitly. 131 class DagNode { 132 public: DagNode(const llvm::DagInit * node)133 explicit DagNode(const llvm::DagInit *node) : node(node) {} 134 135 // Implicit bool converter that returns true if this DagNode is not a null 136 // DagNode. 137 operator bool() const { return node != nullptr; } 138 139 // Returns the symbol bound to this DAG node. 140 StringRef getSymbol() const; 141 142 // Returns the operator wrapper object corresponding to the dialect op matched 143 // by this DAG. The operator wrapper will be queried from the given `mapper` 144 // and created in it if not existing. 145 Operator &getDialectOp(RecordOperatorMap *mapper) const; 146 147 // Returns the number of operations recursively involved in the DAG tree 148 // rooted from this node. 149 int getNumOps() const; 150 151 // Returns the number of immediate arguments to this DAG node. 152 int getNumArgs() const; 153 154 // Returns true if the `index`-th argument is a nested DAG construct. 155 bool isNestedDagArg(unsigned index) const; 156 157 // Gets the `index`-th argument as a nested DAG construct if possible. Returns 158 // null DagNode otherwise. 159 DagNode getArgAsNestedDag(unsigned index) const; 160 161 // Gets the `index`-th argument as a DAG leaf. 162 DagLeaf getArgAsLeaf(unsigned index) const; 163 164 // Returns the specified name of the `index`-th argument. 165 StringRef getArgName(unsigned index) const; 166 167 // Returns true if this DAG construct means to replace with an existing SSA 168 // value. 169 bool isReplaceWithValue() const; 170 171 // Returns whether this DAG represents the location of an op creation. 172 bool isLocationDirective() const; 173 174 // Returns true if this DAG node is wrapping native code call. 175 bool isNativeCodeCall() const; 176 177 // Returns true if this DAG node is an operation. 178 bool isOperation() const; 179 180 // Returns the native code call template inside this DAG node. 181 // Precondition: isNativeCodeCall() 182 StringRef getNativeCodeTemplate() const; 183 184 void print(raw_ostream &os) const; 185 186 private: 187 const llvm::DagInit *node; // nullptr means null DagNode 188 }; 189 190 // A class for maintaining information for symbols bound in patterns and 191 // provides methods for resolving them according to specific use cases. 192 // 193 // Symbols can be bound to 194 // 195 // * Op arguments and op results in the source pattern and 196 // * Op results in result patterns. 197 // 198 // Symbols can be referenced in result patterns and additional constraints to 199 // the pattern. 200 // 201 // For example, in 202 // 203 // ``` 204 // def : Pattern< 205 // (SrcOp:$results1 $arg0, %arg1), 206 // [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>; 207 // ``` 208 // 209 // `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to 210 // `SrcOp`. `$results2` is bound to `ResOp1`. $result2 is referenced to build 211 // `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`. 212 // 213 // If a symbol binds to a multi-result op and it does not have the `__N` 214 // suffix, the symbol is expanded to represent all results generated by the 215 // multi-result op. If the symbol has a `__N` suffix, then it will expand to 216 // only the N-th *static* result as declared in ODS, and that can still 217 // corresponds to multiple *dynamic* values if the N-th *static* result is 218 // variadic. 219 // 220 // This class keeps track of such symbols and resolves them into their bound 221 // values in a suitable way. 222 class SymbolInfoMap { 223 public: SymbolInfoMap(ArrayRef<llvm::SMLoc> loc)224 explicit SymbolInfoMap(ArrayRef<llvm::SMLoc> loc) : loc(loc) {} 225 226 // Class for information regarding a symbol. 227 class SymbolInfo { 228 public: 229 // Returns a string for defining a variable named as `name` to store the 230 // value bound by this symbol. 231 std::string getVarDecl(StringRef name) const; 232 233 // Returns a variable name for the symbol named as `name`. 234 std::string getVarName(StringRef name) const; 235 236 private: 237 // Allow SymbolInfoMap to access private methods. 238 friend class SymbolInfoMap; 239 240 // What kind of entity this symbol represents: 241 // * Attr: op attribute 242 // * Operand: op operand 243 // * Result: op result 244 // * Value: a value not attached to an op (e.g., from NativeCodeCall) 245 enum class Kind : uint8_t { Attr, Operand, Result, Value }; 246 247 // Creates a SymbolInfo instance. `index` is only used for `Attr` and 248 // `Operand` so should be negative for `Result` and `Value` kind. 249 SymbolInfo(const Operator *op, Kind kind, Optional<int> index); 250 251 // Static methods for creating SymbolInfo. getAttr(const Operator * op,int index)252 static SymbolInfo getAttr(const Operator *op, int index) { 253 return SymbolInfo(op, Kind::Attr, index); 254 } getAttr()255 static SymbolInfo getAttr() { 256 return SymbolInfo(nullptr, Kind::Attr, llvm::None); 257 } getOperand(const Operator * op,int index)258 static SymbolInfo getOperand(const Operator *op, int index) { 259 return SymbolInfo(op, Kind::Operand, index); 260 } getResult(const Operator * op)261 static SymbolInfo getResult(const Operator *op) { 262 return SymbolInfo(op, Kind::Result, llvm::None); 263 } getValue()264 static SymbolInfo getValue() { 265 return SymbolInfo(nullptr, Kind::Value, llvm::None); 266 } 267 268 // Returns the number of static values this symbol corresponds to. 269 // A static value is an operand/result declared in ODS. Normally a symbol 270 // only represents one static value, but symbols bound to op results can 271 // represent more than one if the op is a multi-result op. 272 int getStaticValueCount() const; 273 274 // Returns a string containing the C++ expression for referencing this 275 // symbol as a value (if this symbol represents one static value) or a value 276 // range (if this symbol represents multiple static values). `name` is the 277 // name of the C++ variable that this symbol bounds to. `index` should only 278 // be used for indexing results. `fmt` is used to format each value. 279 // `separator` is used to separate values if this is a value range. 280 std::string getValueAndRangeUse(StringRef name, int index, const char *fmt, 281 const char *separator) const; 282 283 // Returns a string containing the C++ expression for referencing this 284 // symbol as a value range regardless of how many static values this symbol 285 // represents. `name` is the name of the C++ variable that this symbol 286 // bounds to. `index` should only be used for indexing results. `fmt` is 287 // used to format each value. `separator` is used to separate values in the 288 // range. 289 std::string getAllRangeUse(StringRef name, int index, const char *fmt, 290 const char *separator) const; 291 292 const Operator *op; // The op where the bound entity belongs 293 Kind kind; // The kind of the bound entity 294 // The argument index (for `Attr` and `Operand` only) 295 Optional<int> argIndex; 296 // Alternative name for the symbol. It is used in case the name 297 // is not unique. Applicable for `Operand` only. 298 Optional<std::string> alternativeName; 299 }; 300 301 using BaseT = std::unordered_multimap<std::string, SymbolInfo>; 302 303 // Iterators for accessing all symbols. 304 using iterator = BaseT::iterator; begin()305 iterator begin() { return symbolInfoMap.begin(); } end()306 iterator end() { return symbolInfoMap.end(); } 307 308 // Const iterators for accessing all symbols. 309 using const_iterator = BaseT::const_iterator; begin()310 const_iterator begin() const { return symbolInfoMap.begin(); } end()311 const_iterator end() const { return symbolInfoMap.end(); } 312 313 // Binds the given `symbol` to the `argIndex`-th argument to the given `op`. 314 // Returns false if `symbol` is already bound and symbols are not operands. 315 bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex); 316 317 // Binds the given `symbol` to the results the given `op`. Returns false if 318 // `symbol` is already bound. 319 bool bindOpResult(StringRef symbol, const Operator &op); 320 321 // Registers the given `symbol` as bound to a value. Returns false if `symbol` 322 // is already bound. 323 bool bindValue(StringRef symbol); 324 325 // Registers the given `symbol` as bound to an attr. Returns false if `symbol` 326 // is already bound. 327 bool bindAttr(StringRef symbol); 328 329 // Returns true if the given `symbol` is bound. 330 bool contains(StringRef symbol) const; 331 332 // Returns an iterator to the information of the given symbol named as `key`. 333 const_iterator find(StringRef key) const; 334 335 // Returns an iterator to the information of the given symbol named as `key`, 336 // with index `argIndex` for operator `op`. 337 const_iterator findBoundSymbol(StringRef key, const Operator &op, 338 int argIndex) const; 339 340 // Returns the bounds of a range that includes all the elements which 341 // bind to the `key`. 342 std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key); 343 344 // Returns number of times symbol named as `key` was used. 345 int count(StringRef key) const; 346 347 // Returns the number of static values of the given `symbol` corresponds to. 348 // A static value is an operand/result declared in ODS. Normally a symbol only 349 // represents one static value, but symbols bound to op results can represent 350 // more than one if the op is a multi-result op. 351 int getStaticValueCount(StringRef symbol) const; 352 353 // Returns a string containing the C++ expression for referencing this 354 // symbol as a value (if this symbol represents one static value) or a value 355 // range (if this symbol represents multiple static values). `fmt` is used to 356 // format each value. `separator` is used to separate values if `symbol` 357 // represents a value range. 358 std::string getValueAndRangeUse(StringRef symbol, const char *fmt = "{0}", 359 const char *separator = ", ") const; 360 361 // Returns a string containing the C++ expression for referencing this 362 // symbol as a value range regardless of how many static values this symbol 363 // represents. `fmt` is used to format each value. `separator` is used to 364 // separate values in the range. 365 std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}", 366 const char *separator = ", ") const; 367 368 // Assign alternative unique names to Operands that have equal names. 369 void assignUniqueAlternativeNames(); 370 371 // Splits the given `symbol` into a value pack name and an index. Returns the 372 // value pack name and writes the index to `index` on success. Returns 373 // `symbol` itself if it does not contain an index. 374 // 375 // We can use `name__N` to access the `N`-th value in the value pack bound to 376 // `name`. `name` is typically the results of an multi-result op. 377 static StringRef getValuePackName(StringRef symbol, int *index = nullptr); 378 379 private: 380 BaseT symbolInfoMap; 381 382 // Pattern instantiation location. This is intended to be used as parameter 383 // to PrintFatalError() to report errors. 384 ArrayRef<llvm::SMLoc> loc; 385 }; 386 387 // Wrapper class providing helper methods for accessing MLIR Pattern defined 388 // in TableGen. This class should closely reflect what is defined as class 389 // `Pattern` in TableGen. This class contains maps so it is not intended to be 390 // used as values. 391 class Pattern { 392 public: 393 explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper); 394 395 // Returns the source pattern to match. 396 DagNode getSourcePattern() const; 397 398 // Returns the number of result patterns generated by applying this rewrite 399 // rule. 400 int getNumResultPatterns() const; 401 402 // Returns the DAG tree root node of the `index`-th result pattern. 403 DagNode getResultPattern(unsigned index) const; 404 405 // Collects all symbols bound in the source pattern into `infoMap`. 406 void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap); 407 408 // Collects all symbols bound in result patterns into `infoMap`. 409 void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap); 410 411 // Returns the op that the root node of the source pattern matches. 412 const Operator &getSourceRootOp(); 413 414 // Returns the operator wrapper object corresponding to the given `node`'s DAG 415 // operator. 416 Operator &getDialectOp(DagNode node); 417 418 // Returns the constraints. 419 std::vector<AppliedConstraint> getConstraints() const; 420 421 // Returns the benefit score of the pattern. 422 int getBenefit() const; 423 424 using IdentifierLine = std::pair<StringRef, unsigned>; 425 426 // Returns the file location of the pattern (buffer identifier + line number 427 // pair). 428 std::vector<IdentifierLine> getLocation() const; 429 430 private: 431 // Helper function to verify variabld binding. 432 void verifyBind(bool result, StringRef symbolName); 433 434 // Recursively collects all bound symbols inside the DAG tree rooted 435 // at `tree` and updates the given `infoMap`. 436 void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, 437 bool isSrcPattern); 438 439 // The TableGen definition of this pattern. 440 const llvm::Record &def; 441 442 // All operators. 443 // TODO: we need a proper context manager, like MLIRContext, for managing the 444 // lifetime of shared entities. 445 RecordOperatorMap *recordOpMap; 446 }; 447 448 } // end namespace tblgen 449 } // end namespace mlir 450 451 #endif // MLIR_TABLEGEN_PATTERN_H_ 452