1 //===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_PATTERNMATCHER_H 10 #define MLIR_PATTERNMATCHER_H 11 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/BuiltinOps.h" 14 15 namespace mlir { 16 17 class PatternRewriter; 18 19 //===----------------------------------------------------------------------===// 20 // PatternBenefit class 21 //===----------------------------------------------------------------------===// 22 23 /// This class represents the benefit of a pattern match in a unitless scheme 24 /// that ranges from 0 (very little benefit) to 65K. The most common unit to 25 /// use here is the "number of operations matched" by the pattern. 26 /// 27 /// This also has a sentinel representation that can be used for patterns that 28 /// fail to match. 29 /// 30 class PatternBenefit { 31 enum { ImpossibleToMatchSentinel = 65535 }; 32 33 public: PatternBenefit()34 PatternBenefit() : representation(ImpossibleToMatchSentinel) {} 35 PatternBenefit(unsigned benefit); 36 PatternBenefit(const PatternBenefit &) = default; 37 PatternBenefit &operator=(const PatternBenefit &) = default; 38 impossibleToMatch()39 static PatternBenefit impossibleToMatch() { return PatternBenefit(); } isImpossibleToMatch()40 bool isImpossibleToMatch() const { return *this == impossibleToMatch(); } 41 42 /// If the corresponding pattern can match, return its benefit. If the 43 // corresponding pattern isImpossibleToMatch() then this aborts. 44 unsigned short getBenefit() const; 45 46 bool operator==(const PatternBenefit &rhs) const { 47 return representation == rhs.representation; 48 } 49 bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); } 50 bool operator<(const PatternBenefit &rhs) const { 51 return representation < rhs.representation; 52 } 53 bool operator>(const PatternBenefit &rhs) const { return rhs < *this; } 54 bool operator<=(const PatternBenefit &rhs) const { return !(*this > rhs); } 55 bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); } 56 57 private: 58 unsigned short representation; 59 }; 60 61 //===----------------------------------------------------------------------===// 62 // Pattern 63 //===----------------------------------------------------------------------===// 64 65 /// This class contains all of the data related to a pattern, but does not 66 /// contain any methods or logic for the actual matching. This class is solely 67 /// used to interface with the metadata of a pattern, such as the benefit or 68 /// root operation. 69 class Pattern { 70 public: 71 /// Return a list of operations that may be generated when rewriting an 72 /// operation instance with this pattern. getGeneratedOps()73 ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; } 74 75 /// Return the root node that this pattern matches. Patterns that can match 76 /// multiple root types return None. getRootKind()77 Optional<OperationName> getRootKind() const { return rootKind; } 78 79 /// Return the benefit (the inverse of "cost") of matching this pattern. The 80 /// benefit of a Pattern is always static - rewrites that may have dynamic 81 /// benefit can be instantiated multiple times (different Pattern instances) 82 /// for each benefit that they may return, and be guarded by different match 83 /// condition predicates. getBenefit()84 PatternBenefit getBenefit() const { return benefit; } 85 86 /// Returns true if this pattern is known to result in recursive application, 87 /// i.e. this pattern may generate IR that also matches this pattern, but is 88 /// known to bound the recursion. This signals to a rewrite driver that it is 89 /// safe to apply this pattern recursively to generated IR. hasBoundedRewriteRecursion()90 bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; } 91 92 protected: 93 /// This class acts as a special tag that makes the desire to match "any" 94 /// operation type explicit. This helps to avoid unnecessary usages of this 95 /// feature, and ensures that the user is making a conscious decision. 96 struct MatchAnyOpTypeTag {}; 97 98 /// Construct a pattern with a certain benefit that matches the operation 99 /// with the given root name. 100 Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context); 101 /// Construct a pattern with a certain benefit that matches any operation 102 /// type. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" 103 /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should 104 /// always be supplied here. 105 Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag); 106 /// Construct a pattern with a certain benefit that matches the operation with 107 /// the given root name. `generatedNames` contains the names of operations 108 /// that may be generated during a successful rewrite. 109 Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames, 110 PatternBenefit benefit, MLIRContext *context); 111 /// Construct a pattern that may match any operation type. `generatedNames` 112 /// contains the names of operations that may be generated during a successful 113 /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" 114 /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should 115 /// always be supplied here. 116 Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit, 117 MLIRContext *context, MatchAnyOpTypeTag tag); 118 119 /// Set the flag detailing if this pattern has bounded rewrite recursion or 120 /// not. 121 void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) { 122 hasBoundedRecursion = hasBoundedRecursionArg; 123 } 124 125 private: 126 /// A list of the potential operations that may be generated when rewriting 127 /// an op with this pattern. 128 SmallVector<OperationName, 2> generatedOps; 129 130 /// The root operation of the pattern. If the pattern matches a specific 131 /// operation, this contains the name of that operation. Contains None 132 /// otherwise. 133 Optional<OperationName> rootKind; 134 135 /// The expected benefit of matching this pattern. 136 const PatternBenefit benefit; 137 138 /// A boolean flag of whether this pattern has bounded recursion or not. 139 bool hasBoundedRecursion = false; 140 }; 141 142 //===----------------------------------------------------------------------===// 143 // RewritePattern 144 //===----------------------------------------------------------------------===// 145 146 /// RewritePattern is the common base class for all DAG to DAG replacements. 147 /// There are two possible usages of this class: 148 /// * Multi-step RewritePattern with "match" and "rewrite" 149 /// - By overloading the "match" and "rewrite" functions, the user can 150 /// separate the concerns of matching and rewriting. 151 /// * Single-step RewritePattern with "matchAndRewrite" 152 /// - By overloading the "matchAndRewrite" function, the user can perform 153 /// the rewrite in the same call as the match. 154 /// 155 class RewritePattern : public Pattern { 156 public: ~RewritePattern()157 virtual ~RewritePattern() {} 158 159 /// Rewrite the IR rooted at the specified operation with the result of 160 /// this pattern, generating any new operations with the specified 161 /// builder. If an unexpected error is encountered (an internal 162 /// compiler error), it is emitted through the normal MLIR diagnostic 163 /// hooks and the IR is left in a valid state. 164 virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; 165 166 /// Attempt to match against code rooted at the specified operation, 167 /// which is the same operation code as getRootKind(). 168 virtual LogicalResult match(Operation *op) const; 169 170 /// Attempt to match against code rooted at the specified operation, 171 /// which is the same operation code as getRootKind(). If successful, this 172 /// function will automatically perform the rewrite. matchAndRewrite(Operation * op,PatternRewriter & rewriter)173 virtual LogicalResult matchAndRewrite(Operation *op, 174 PatternRewriter &rewriter) const { 175 if (succeeded(match(op))) { 176 rewrite(op, rewriter); 177 return success(); 178 } 179 return failure(); 180 } 181 182 protected: 183 /// Inherit the base constructors from `Pattern`. 184 using Pattern::Pattern; 185 186 /// An anchor for the virtual table. 187 virtual void anchor(); 188 }; 189 190 /// OpRewritePattern is a wrapper around RewritePattern that allows for 191 /// matching and rewriting against an instance of a derived operation class as 192 /// opposed to a raw Operation. 193 template <typename SourceOp> 194 struct OpRewritePattern : public RewritePattern { 195 /// Patterns must specify the root operation name they match against, and can 196 /// also specify the benefit of the pattern matching. 197 OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) RewritePatternOpRewritePattern198 : RewritePattern(SourceOp::getOperationName(), benefit, context) {} 199 200 /// Wrappers around the RewritePattern methods that pass the derived op type. rewriteOpRewritePattern201 void rewrite(Operation *op, PatternRewriter &rewriter) const final { 202 rewrite(cast<SourceOp>(op), rewriter); 203 } matchOpRewritePattern204 LogicalResult match(Operation *op) const final { 205 return match(cast<SourceOp>(op)); 206 } matchAndRewriteOpRewritePattern207 LogicalResult matchAndRewrite(Operation *op, 208 PatternRewriter &rewriter) const final { 209 return matchAndRewrite(cast<SourceOp>(op), rewriter); 210 } 211 212 /// Rewrite and Match methods that operate on the SourceOp type. These must be 213 /// overridden by the derived pattern class. rewriteOpRewritePattern214 virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const { 215 llvm_unreachable("must override rewrite or matchAndRewrite"); 216 } matchOpRewritePattern217 virtual LogicalResult match(SourceOp op) const { 218 llvm_unreachable("must override match or matchAndRewrite"); 219 } matchAndRewriteOpRewritePattern220 virtual LogicalResult matchAndRewrite(SourceOp op, 221 PatternRewriter &rewriter) const { 222 if (succeeded(match(op))) { 223 rewrite(op, rewriter); 224 return success(); 225 } 226 return failure(); 227 } 228 }; 229 230 //===----------------------------------------------------------------------===// 231 // PDLPatternModule 232 //===----------------------------------------------------------------------===// 233 234 //===----------------------------------------------------------------------===// 235 // PDLValue 236 237 /// Storage type of byte-code interpreter values. These are passed to constraint 238 /// functions as arguments. 239 class PDLValue { 240 /// The internal implementation type when the value is an Attribute, 241 /// Operation*, or Type. See `impl` below for more details. 242 using AttrOpTypeImplT = llvm::PointerUnion<Attribute, Operation *, Type>; 243 244 public: PDLValue(const PDLValue & other)245 PDLValue(const PDLValue &other) : impl(other.impl) {} impl()246 PDLValue(std::nullptr_t = nullptr) : impl() {} PDLValue(Attribute value)247 PDLValue(Attribute value) : impl(value) {} PDLValue(Operation * value)248 PDLValue(Operation *value) : impl(value) {} PDLValue(Type value)249 PDLValue(Type value) : impl(value) {} PDLValue(Value value)250 PDLValue(Value value) : impl(value) {} 251 252 /// Returns true if the type of the held value is `T`. 253 template <typename T> isa()254 std::enable_if_t<std::is_same<T, Value>::value, bool> isa() const { 255 return impl.is<Value>(); 256 } 257 template <typename T> isa()258 std::enable_if_t<!std::is_same<T, Value>::value, bool> isa() const { 259 auto attrOpTypeImpl = impl.dyn_cast<AttrOpTypeImplT>(); 260 return attrOpTypeImpl && attrOpTypeImpl.is<T>(); 261 } 262 263 /// Attempt to dynamically cast this value to type `T`, returns null if this 264 /// value is not an instance of `T`. 265 template <typename T> dyn_cast()266 std::enable_if_t<std::is_same<T, Value>::value, T> dyn_cast() const { 267 return impl.dyn_cast<T>(); 268 } 269 template <typename T> dyn_cast()270 std::enable_if_t<!std::is_same<T, Value>::value, T> dyn_cast() const { 271 auto attrOpTypeImpl = impl.dyn_cast<AttrOpTypeImplT>(); 272 return attrOpTypeImpl && attrOpTypeImpl.dyn_cast<T>(); 273 } 274 275 /// Cast this value to type `T`, asserts if this value is not an instance of 276 /// `T`. 277 template <typename T> cast()278 std::enable_if_t<std::is_same<T, Value>::value, T> cast() const { 279 return impl.get<T>(); 280 } 281 template <typename T> cast()282 std::enable_if_t<!std::is_same<T, Value>::value, T> cast() const { 283 return impl.get<AttrOpTypeImplT>().get<T>(); 284 } 285 286 /// Get an opaque pointer to the value. getAsOpaquePointer()287 void *getAsOpaquePointer() { return impl.getOpaqueValue(); } 288 289 /// Print this value to the provided output stream. 290 void print(raw_ostream &os); 291 292 private: 293 /// The internal opaque representation of a PDLValue. We use a nested 294 /// PointerUnion structure here because `Value` only has 1 low bit 295 /// available, where as the remaining types all have 3. 296 llvm::PointerUnion<AttrOpTypeImplT, Value> impl; 297 }; 298 299 inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) { 300 value.print(os); 301 return os; 302 } 303 304 //===----------------------------------------------------------------------===// 305 // PDLPatternModule 306 307 /// A generic PDL pattern constraint function. This function applies a 308 /// constraint to a given set of opaque PDLValue entities. The second parameter 309 /// is a set of constant value parameters specified in Attribute form. Returns 310 /// success if the constraint successfully held, failure otherwise. 311 using PDLConstraintFunction = std::function<LogicalResult( 312 ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>; 313 /// A native PDL creation function. This function creates a new PDLValue given 314 /// a set of existing PDL values, a set of constant parameters specified in 315 /// Attribute form, and a PatternRewriter. Returns the newly created PDLValue. 316 using PDLCreateFunction = 317 std::function<PDLValue(ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>; 318 /// A native PDL rewrite function. This function rewrites the given root 319 /// operation using the provided PatternRewriter. This method is only invoked 320 /// when the corresponding match was successful. 321 using PDLRewriteFunction = std::function<void(Operation *, ArrayRef<PDLValue>, 322 ArrayAttr, PatternRewriter &)>; 323 /// A generic PDL pattern constraint function. This function applies a 324 /// constraint to a given opaque PDLValue entity. The second parameter is a set 325 /// of constant value parameters specified in Attribute form. Returns success if 326 /// the constraint successfully held, failure otherwise. 327 using PDLSingleEntityConstraintFunction = 328 std::function<LogicalResult(PDLValue, ArrayAttr, PatternRewriter &)>; 329 330 /// This class contains all of the necessary data for a set of PDL patterns, or 331 /// pattern rewrites specified in the form of the PDL dialect. This PDL module 332 /// contained by this pattern may contain any number of `pdl.pattern` 333 /// operations. 334 class PDLPatternModule { 335 public: 336 PDLPatternModule() = default; 337 338 /// Construct a PDL pattern with the given module. PDLPatternModule(OwningModuleRef pdlModule)339 PDLPatternModule(OwningModuleRef pdlModule) 340 : pdlModule(std::move(pdlModule)) {} 341 342 /// Merge the state in `other` into this pattern module. 343 void mergeIn(PDLPatternModule &&other); 344 345 /// Return the internal PDL module of this pattern. getModule()346 ModuleOp getModule() { return pdlModule.get(); } 347 348 //===--------------------------------------------------------------------===// 349 // Function Registry 350 351 /// Register a constraint function. 352 void registerConstraintFunction(StringRef name, 353 PDLConstraintFunction constraintFn); 354 /// Register a single entity constraint function. 355 template <typename SingleEntityFn> 356 std::enable_if_t<!llvm::is_invocable<SingleEntityFn, ArrayRef<PDLValue>, 357 ArrayAttr, PatternRewriter &>::value> registerConstraintFunction(StringRef name,SingleEntityFn && constraintFn)358 registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn) { 359 registerConstraintFunction( 360 name, [constraintFn = std::forward<SingleEntityFn>(constraintFn)]( 361 ArrayRef<PDLValue> values, ArrayAttr constantParams, 362 PatternRewriter &rewriter) { 363 assert(values.size() == 1 && 364 "expected values to have a single entity"); 365 return constraintFn(values[0], constantParams, rewriter); 366 }); 367 } 368 369 /// Register a creation function. 370 void registerCreateFunction(StringRef name, PDLCreateFunction createFn); 371 372 /// Register a rewrite function. 373 void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn); 374 375 /// Return the set of the registered constraint functions. getConstraintFunctions()376 const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const { 377 return constraintFunctions; 378 } takeConstraintFunctions()379 llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() { 380 return constraintFunctions; 381 } 382 /// Return the set of the registered create functions. getCreateFunctions()383 const llvm::StringMap<PDLCreateFunction> &getCreateFunctions() const { 384 return createFunctions; 385 } takeCreateFunctions()386 llvm::StringMap<PDLCreateFunction> takeCreateFunctions() { 387 return createFunctions; 388 } 389 /// Return the set of the registered rewrite functions. getRewriteFunctions()390 const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const { 391 return rewriteFunctions; 392 } takeRewriteFunctions()393 llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() { 394 return rewriteFunctions; 395 } 396 397 /// Clear out the patterns and functions within this module. clear()398 void clear() { 399 pdlModule = nullptr; 400 constraintFunctions.clear(); 401 createFunctions.clear(); 402 rewriteFunctions.clear(); 403 } 404 405 private: 406 /// The module containing the `pdl.pattern` operations. 407 OwningModuleRef pdlModule; 408 409 /// The external functions referenced from within the PDL module. 410 llvm::StringMap<PDLConstraintFunction> constraintFunctions; 411 llvm::StringMap<PDLCreateFunction> createFunctions; 412 llvm::StringMap<PDLRewriteFunction> rewriteFunctions; 413 }; 414 415 //===----------------------------------------------------------------------===// 416 // PatternRewriter 417 //===----------------------------------------------------------------------===// 418 419 /// This class coordinates the application of a pattern to the current function, 420 /// providing a way to create operations and keep track of what gets deleted. 421 /// 422 /// These class serves two purposes: 423 /// 1) it is the interface that patterns interact with to make mutations to the 424 /// IR they are being applied to. 425 /// 2) It is a base class that clients of the PatternMatcher use when they want 426 /// to apply patterns and observe their effects (e.g. to keep worklists or 427 /// other data structures up to date). 428 /// 429 class PatternRewriter : public OpBuilder, public OpBuilder::Listener { 430 public: 431 /// Move the blocks that belong to "region" before the given position in 432 /// another region "parent". The two regions must be different. The caller 433 /// is responsible for creating or updating the operation transferring flow 434 /// of control to the region and passing it the correct block arguments. 435 virtual void inlineRegionBefore(Region ®ion, Region &parent, 436 Region::iterator before); 437 void inlineRegionBefore(Region ®ion, Block *before); 438 439 /// Clone the blocks that belong to "region" before the given position in 440 /// another region "parent". The two regions must be different. The caller is 441 /// responsible for creating or updating the operation transferring flow of 442 /// control to the region and passing it the correct block arguments. 443 virtual void cloneRegionBefore(Region ®ion, Region &parent, 444 Region::iterator before, 445 BlockAndValueMapping &mapping); 446 void cloneRegionBefore(Region ®ion, Region &parent, 447 Region::iterator before); 448 void cloneRegionBefore(Region ®ion, Block *before); 449 450 /// This method performs the final replacement for a pattern, where the 451 /// results of the operation are updated to use the specified list of SSA 452 /// values. 453 virtual void replaceOp(Operation *op, ValueRange newValues); 454 455 /// Replaces the result op with a new op that is created without verification. 456 /// The result values of the two ops must be the same types. 457 template <typename OpTy, typename... Args> replaceOpWithNewOp(Operation * op,Args &&...args)458 void replaceOpWithNewOp(Operation *op, Args &&... args) { 459 auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...); 460 replaceOpWithResultsOfAnotherOp(op, newOp.getOperation()); 461 } 462 463 /// This method erases an operation that is known to have no uses. 464 virtual void eraseOp(Operation *op); 465 466 /// This method erases all operations in a block. 467 virtual void eraseBlock(Block *block); 468 469 /// Merge the operations of block 'source' into the end of block 'dest'. 470 /// 'source's predecessors must either be empty or only contain 'dest`. 471 /// 'argValues' is used to replace the block arguments of 'source' after 472 /// merging. 473 virtual void mergeBlocks(Block *source, Block *dest, 474 ValueRange argValues = llvm::None); 475 476 // Merge the operations of block 'source' before the operation 'op'. Source 477 // block should not have existing predecessors or successors. 478 void mergeBlockBefore(Block *source, Operation *op, 479 ValueRange argValues = llvm::None); 480 481 /// Split the operations starting at "before" (inclusive) out of the given 482 /// block into a new block, and return it. 483 virtual Block *splitBlock(Block *block, Block::iterator before); 484 485 /// This method is used to notify the rewriter that an in-place operation 486 /// modification is about to happen. A call to this function *must* be 487 /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`. 488 /// This is a minor efficiency win (it avoids creating a new operation and 489 /// removing the old one) but also often allows simpler code in the client. startRootUpdate(Operation * op)490 virtual void startRootUpdate(Operation *op) {} 491 492 /// This method is used to signal the end of a root update on the given 493 /// operation. This can only be called on operations that were provided to a 494 /// call to `startRootUpdate`. finalizeRootUpdate(Operation * op)495 virtual void finalizeRootUpdate(Operation *op) {} 496 497 /// This method cancels a pending root update. This can only be called on 498 /// operations that were provided to a call to `startRootUpdate`. cancelRootUpdate(Operation * op)499 virtual void cancelRootUpdate(Operation *op) {} 500 501 /// This method is a utility wrapper around a root update of an operation. It 502 /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given 503 /// callable. 504 template <typename CallableT> updateRootInPlace(Operation * root,CallableT && callable)505 void updateRootInPlace(Operation *root, CallableT &&callable) { 506 startRootUpdate(root); 507 callable(); 508 finalizeRootUpdate(root); 509 } 510 511 /// Notify the pattern rewriter that the pattern is failing to match the given 512 /// operation, and provide a callback to populate a diagnostic with the reason 513 /// why the failure occurred. This method allows for derived rewriters to 514 /// optionally hook into the reason why a pattern failed, and display it to 515 /// users. 516 template <typename CallbackT> 517 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult> notifyMatchFailure(Operation * op,CallbackT && reasonCallback)518 notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) { 519 #ifndef NDEBUG 520 return notifyMatchFailure(op, 521 function_ref<void(Diagnostic &)>(reasonCallback)); 522 #else 523 return failure(); 524 #endif 525 } notifyMatchFailure(Operation * op,const Twine & msg)526 LogicalResult notifyMatchFailure(Operation *op, const Twine &msg) { 527 return notifyMatchFailure(op, [&](Diagnostic &diag) { diag << msg; }); 528 } notifyMatchFailure(Operation * op,const char * msg)529 LogicalResult notifyMatchFailure(Operation *op, const char *msg) { 530 return notifyMatchFailure(op, Twine(msg)); 531 } 532 533 protected: 534 /// Initialize the builder with this rewriter as the listener. PatternRewriter(MLIRContext * ctx)535 explicit PatternRewriter(MLIRContext *ctx) 536 : OpBuilder(ctx, /*listener=*/this) {} 537 ~PatternRewriter() override; 538 539 /// These are the callback methods that subclasses can choose to implement if 540 /// they would like to be notified about certain types of mutations. 541 542 /// Notify the pattern rewriter that the specified operation is about to be 543 /// replaced with another set of operations. This is called before the uses 544 /// of the operation have been changed. notifyRootReplaced(Operation * op)545 virtual void notifyRootReplaced(Operation *op) {} 546 547 /// This is called on an operation that a pattern match is removing, right 548 /// before the operation is deleted. At this point, the operation has zero 549 /// uses. notifyOperationRemoved(Operation * op)550 virtual void notifyOperationRemoved(Operation *op) {} 551 552 /// Notify the pattern rewriter that the pattern is failing to match the given 553 /// operation, and provide a callback to populate a diagnostic with the reason 554 /// why the failure occurred. This method allows for derived rewriters to 555 /// optionally hook into the reason why a pattern failed, and display it to 556 /// users. 557 virtual LogicalResult notifyMatchFailure(Operation * op,function_ref<void (Diagnostic &)> reasonCallback)558 notifyMatchFailure(Operation *op, 559 function_ref<void(Diagnostic &)> reasonCallback) { 560 return failure(); 561 } 562 563 private: 564 /// 'op' and 'newOp' are known to have the same number of results, replace the 565 /// uses of op with uses of newOp. 566 void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp); 567 }; 568 569 //===----------------------------------------------------------------------===// 570 // OwningRewritePatternList 571 //===----------------------------------------------------------------------===// 572 573 class OwningRewritePatternList { 574 using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>; 575 576 public: 577 OwningRewritePatternList() = default; 578 579 /// Construct a OwningRewritePatternList populated with the given pattern. OwningRewritePatternList(std::unique_ptr<RewritePattern> pattern)580 OwningRewritePatternList(std::unique_ptr<RewritePattern> pattern) { 581 nativePatterns.emplace_back(std::move(pattern)); 582 } OwningRewritePatternList(PDLPatternModule && pattern)583 OwningRewritePatternList(PDLPatternModule &&pattern) 584 : pdlPatterns(std::move(pattern)) {} 585 586 /// Return the native patterns held in this list. getNativePatterns()587 NativePatternListT &getNativePatterns() { return nativePatterns; } 588 589 /// Return the PDL patterns held in this list. getPDLPatterns()590 PDLPatternModule &getPDLPatterns() { return pdlPatterns; } 591 592 /// Clear out all of the held patterns in this list. clear()593 void clear() { 594 nativePatterns.clear(); 595 pdlPatterns.clear(); 596 } 597 598 //===--------------------------------------------------------------------===// 599 // Pattern Insertion 600 //===--------------------------------------------------------------------===// 601 602 /// Add an instance of each of the pattern types 'Ts' to the pattern list with 603 /// the given arguments. Return a reference to `this` for chaining insertions. 604 /// Note: ConstructorArg is necessary here to separate the two variadic lists. 605 template <typename... Ts, typename ConstructorArg, 606 typename... ConstructorArgs, 607 typename = std::enable_if_t<sizeof...(Ts) != 0>> insert(ConstructorArg && arg,ConstructorArgs &&...args)608 OwningRewritePatternList &insert(ConstructorArg &&arg, 609 ConstructorArgs &&...args) { 610 // The following expands a call to emplace_back for each of the pattern 611 // types 'Ts'. This magic is necessary due to a limitation in the places 612 // that a parameter pack can be expanded in c++11. 613 // FIXME: In c++17 this can be simplified by using 'fold expressions'. 614 (void)std::initializer_list<int>{0, (insertImpl<Ts>(arg, args...), 0)...}; 615 return *this; 616 } 617 618 /// Add an instance of each of the pattern types 'Ts'. Return a reference to 619 /// `this` for chaining insertions. insert()620 template <typename... Ts> OwningRewritePatternList &insert() { 621 (void)std::initializer_list<int>{0, (insertImpl<Ts>(), 0)...}; 622 return *this; 623 } 624 625 /// Add the given native pattern to the pattern list. Return a reference to 626 /// `this` for chaining insertions. insert(std::unique_ptr<RewritePattern> pattern)627 OwningRewritePatternList &insert(std::unique_ptr<RewritePattern> pattern) { 628 nativePatterns.emplace_back(std::move(pattern)); 629 return *this; 630 } 631 632 /// Add the given PDL pattern to the pattern list. Return a reference to 633 /// `this` for chaining insertions. insert(PDLPatternModule && pattern)634 OwningRewritePatternList &insert(PDLPatternModule &&pattern) { 635 pdlPatterns.mergeIn(std::move(pattern)); 636 return *this; 637 } 638 639 private: 640 /// Add an instance of the pattern type 'T'. Return a reference to `this` for 641 /// chaining insertions. 642 template <typename T, typename... Args> 643 std::enable_if_t<std::is_base_of<RewritePattern, T>::value> insertImpl(Args &&...args)644 insertImpl(Args &&...args) { 645 nativePatterns.emplace_back( 646 std::make_unique<T>(std::forward<Args>(args)...)); 647 } 648 template <typename T, typename... Args> 649 std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value> insertImpl(Args &&...args)650 insertImpl(Args &&...args) { 651 pdlPatterns.mergeIn(T(std::forward<Args>(args)...)); 652 } 653 654 NativePatternListT nativePatterns; 655 PDLPatternModule pdlPatterns; 656 }; 657 658 } // end namespace mlir 659 660 #endif // MLIR_PATTERN_MATCH_H 661