1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an applicator that applies pattern rewrites based upon a
10 // user defined cost model.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Rewrite/PatternApplicator.h"
15 #include "ByteCode.h"
16 #include "llvm/Support/Debug.h"
17
18 using namespace mlir;
19 using namespace mlir::detail;
20
PatternApplicator(const FrozenRewritePatternList & frozenPatternList)21 PatternApplicator::PatternApplicator(
22 const FrozenRewritePatternList &frozenPatternList)
23 : frozenPatternList(frozenPatternList) {
24 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
25 mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
26 bytecode->initializeMutableState(*mutableByteCodeState);
27 }
28 }
~PatternApplicator()29 PatternApplicator::~PatternApplicator() {}
30
31 #define DEBUG_TYPE "pattern-match"
32
applyCostModel(CostModel model)33 void PatternApplicator::applyCostModel(CostModel model) {
34 // Apply the cost model to the bytecode patterns first, and then the native
35 // patterns.
36 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
37 for (auto it : llvm::enumerate(bytecode->getPatterns()))
38 mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
39 }
40
41 // Separate patterns by root kind to simplify lookup later on.
42 patterns.clear();
43 anyOpPatterns.clear();
44 for (const auto &pat : frozenPatternList.getNativePatterns()) {
45 // If the pattern is always impossible to match, just ignore it.
46 if (pat.getBenefit().isImpossibleToMatch()) {
47 LLVM_DEBUG({
48 llvm::dbgs()
49 << "Ignoring pattern '" << pat.getRootKind()
50 << "' because it is impossible to match (by pattern benefit)\n";
51 });
52 continue;
53 }
54 if (Optional<OperationName> opName = pat.getRootKind())
55 patterns[*opName].push_back(&pat);
56 else
57 anyOpPatterns.push_back(&pat);
58 }
59
60 // Sort the patterns using the provided cost model.
61 llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
62 auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) {
63 return benefits[lhs] > benefits[rhs];
64 };
65 auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) {
66 // Special case for one pattern in the list, which is the most common case.
67 if (list.size() == 1) {
68 if (model(*list.front()).isImpossibleToMatch()) {
69 LLVM_DEBUG({
70 llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
71 << "' because it is impossible to match or cannot lead "
72 "to legal IR (by cost model)\n";
73 });
74 list.clear();
75 }
76 return;
77 }
78
79 // Collect the dynamic benefits for the current pattern list.
80 benefits.clear();
81 for (const Pattern *pat : list)
82 benefits.try_emplace(pat, model(*pat));
83
84 // Sort patterns with highest benefit first, and remove those that are
85 // impossible to match.
86 std::stable_sort(list.begin(), list.end(), cmp);
87 while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
88 LLVM_DEBUG({
89 llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
90 << "' because it is impossible to match or cannot lead to "
91 "legal IR (by cost model)\n";
92 });
93 list.pop_back();
94 }
95 };
96 for (auto &it : patterns)
97 processPatternList(it.second);
98 processPatternList(anyOpPatterns);
99 }
100
walkAllPatterns(function_ref<void (const Pattern &)> walk)101 void PatternApplicator::walkAllPatterns(
102 function_ref<void(const Pattern &)> walk) {
103 for (const Pattern &it : frozenPatternList.getNativePatterns())
104 walk(it);
105 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
106 for (const Pattern &it : bytecode->getPatterns())
107 walk(it);
108 }
109 }
110
matchAndRewrite(Operation * op,PatternRewriter & rewriter,function_ref<bool (const Pattern &)> canApply,function_ref<void (const Pattern &)> onFailure,function_ref<LogicalResult (const Pattern &)> onSuccess)111 LogicalResult PatternApplicator::matchAndRewrite(
112 Operation *op, PatternRewriter &rewriter,
113 function_ref<bool(const Pattern &)> canApply,
114 function_ref<void(const Pattern &)> onFailure,
115 function_ref<LogicalResult(const Pattern &)> onSuccess) {
116 // Before checking native patterns, first match against the bytecode. This
117 // won't automatically perform any rewrites so there is no need to worry about
118 // conflicts.
119 SmallVector<PDLByteCode::MatchResult, 4> pdlMatches;
120 const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
121 if (bytecode)
122 bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);
123
124 // Check to see if there are patterns matching this specific operation type.
125 MutableArrayRef<const RewritePattern *> opPatterns;
126 auto patternIt = patterns.find(op->getName());
127 if (patternIt != patterns.end())
128 opPatterns = patternIt->second;
129
130 // Process the patterns for that match the specific operation type, and any
131 // operation type in an interleaved fashion.
132 auto opIt = opPatterns.begin(), opE = opPatterns.end();
133 auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
134 auto pdlIt = pdlMatches.begin(), pdlE = pdlMatches.end();
135 while (true) {
136 // Find the next pattern with the highest benefit.
137 const Pattern *bestPattern = nullptr;
138 const PDLByteCode::MatchResult *pdlMatch = nullptr;
139 /// Operation specific patterns.
140 if (opIt != opE)
141 bestPattern = *(opIt++);
142 /// Operation agnostic patterns.
143 if (anyIt != anyE &&
144 (!bestPattern || bestPattern->getBenefit() < (*anyIt)->getBenefit()))
145 bestPattern = *(anyIt++);
146 /// PDL patterns.
147 if (pdlIt != pdlE &&
148 (!bestPattern || bestPattern->getBenefit() < pdlIt->benefit)) {
149 pdlMatch = pdlIt;
150 bestPattern = (pdlIt++)->pattern;
151 }
152 if (!bestPattern)
153 break;
154
155 // Check that the pattern can be applied.
156 if (canApply && !canApply(*bestPattern))
157 continue;
158
159 // Try to match and rewrite this pattern. The patterns are sorted by
160 // benefit, so if we match we can immediately rewrite. For PDL patterns, the
161 // match has already been performed, we just need to rewrite.
162 rewriter.setInsertionPoint(op);
163 LogicalResult result = success();
164 if (pdlMatch) {
165 bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
166 } else {
167 result = static_cast<const RewritePattern *>(bestPattern)
168 ->matchAndRewrite(op, rewriter);
169 }
170 if (succeeded(result) && (!onSuccess || succeeded(onSuccess(*bestPattern))))
171 return success();
172
173 // Perform any necessary cleanups.
174 if (onFailure)
175 onFailure(*bestPattern);
176 }
177 return failure();
178 }
179