1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/IR/BlockAndValueMapping.h"
11
12 using namespace mlir;
13
14 //===----------------------------------------------------------------------===//
15 // PatternBenefit
16 //===----------------------------------------------------------------------===//
17
PatternBenefit(unsigned benefit)18 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
19 assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
20 "This pattern match benefit is too large to represent");
21 }
22
getBenefit() const23 unsigned short PatternBenefit::getBenefit() const {
24 assert(!isImpossibleToMatch() && "Pattern doesn't match");
25 return representation;
26 }
27
28 //===----------------------------------------------------------------------===//
29 // Pattern
30 //===----------------------------------------------------------------------===//
31
Pattern(StringRef rootName,PatternBenefit benefit,MLIRContext * context)32 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
33 MLIRContext *context)
34 : rootKind(OperationName(rootName, context)), benefit(benefit) {}
Pattern(PatternBenefit benefit,MatchAnyOpTypeTag tag)35 Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag)
36 : benefit(benefit) {}
Pattern(StringRef rootName,ArrayRef<StringRef> generatedNames,PatternBenefit benefit,MLIRContext * context)37 Pattern::Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
38 PatternBenefit benefit, MLIRContext *context)
39 : Pattern(rootName, benefit, context) {
40 generatedOps.reserve(generatedNames.size());
41 std::transform(generatedNames.begin(), generatedNames.end(),
42 std::back_inserter(generatedOps), [context](StringRef name) {
43 return OperationName(name, context);
44 });
45 }
Pattern(ArrayRef<StringRef> generatedNames,PatternBenefit benefit,MLIRContext * context,MatchAnyOpTypeTag tag)46 Pattern::Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
47 MLIRContext *context, MatchAnyOpTypeTag tag)
48 : Pattern(benefit, tag) {
49 generatedOps.reserve(generatedNames.size());
50 std::transform(generatedNames.begin(), generatedNames.end(),
51 std::back_inserter(generatedOps), [context](StringRef name) {
52 return OperationName(name, context);
53 });
54 }
55
56 //===----------------------------------------------------------------------===//
57 // RewritePattern
58 //===----------------------------------------------------------------------===//
59
rewrite(Operation * op,PatternRewriter & rewriter) const60 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
61 llvm_unreachable("need to implement either matchAndRewrite or one of the "
62 "rewrite functions!");
63 }
64
match(Operation * op) const65 LogicalResult RewritePattern::match(Operation *op) const {
66 llvm_unreachable("need to implement either match or matchAndRewrite!");
67 }
68
69 /// Out-of-line vtable anchor.
anchor()70 void RewritePattern::anchor() {}
71
72 //===----------------------------------------------------------------------===//
73 // PDLValue
74 //===----------------------------------------------------------------------===//
75
print(raw_ostream & os)76 void PDLValue::print(raw_ostream &os) {
77 if (!impl) {
78 os << "<Null-PDLValue>";
79 return;
80 }
81 if (Value val = impl.dyn_cast<Value>()) {
82 os << val;
83 return;
84 }
85 AttrOpTypeImplT aotImpl = impl.get<AttrOpTypeImplT>();
86 if (Attribute attr = aotImpl.dyn_cast<Attribute>())
87 os << attr;
88 else if (Operation *op = aotImpl.dyn_cast<Operation *>())
89 os << *op;
90 else
91 os << aotImpl.get<Type>();
92 }
93
94 //===----------------------------------------------------------------------===//
95 // PDLPatternModule
96 //===----------------------------------------------------------------------===//
97
mergeIn(PDLPatternModule && other)98 void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
99 // Ignore the other module if it has no patterns.
100 if (!other.pdlModule)
101 return;
102 // Steal the other state if we have no patterns.
103 if (!pdlModule) {
104 constraintFunctions = std::move(other.constraintFunctions);
105 createFunctions = std::move(other.createFunctions);
106 rewriteFunctions = std::move(other.rewriteFunctions);
107 pdlModule = std::move(other.pdlModule);
108 return;
109 }
110 // Steal the functions of the other module.
111 for (auto &it : constraintFunctions)
112 registerConstraintFunction(it.first(), std::move(it.second));
113 for (auto &it : createFunctions)
114 registerCreateFunction(it.first(), std::move(it.second));
115 for (auto &it : rewriteFunctions)
116 registerRewriteFunction(it.first(), std::move(it.second));
117
118 // Merge the pattern operations from the other module into this one.
119 Block *block = pdlModule->getBody();
120 block->getTerminator()->erase();
121 block->getOperations().splice(block->end(),
122 other.pdlModule->getBody()->getOperations());
123 }
124
125 //===----------------------------------------------------------------------===//
126 // Function Registry
127
registerConstraintFunction(StringRef name,PDLConstraintFunction constraintFn)128 void PDLPatternModule::registerConstraintFunction(
129 StringRef name, PDLConstraintFunction constraintFn) {
130 auto it = constraintFunctions.try_emplace(name, std::move(constraintFn));
131 (void)it;
132 assert(it.second &&
133 "constraint with the given name has already been registered");
134 }
registerCreateFunction(StringRef name,PDLCreateFunction createFn)135 void PDLPatternModule::registerCreateFunction(StringRef name,
136 PDLCreateFunction createFn) {
137 auto it = createFunctions.try_emplace(name, std::move(createFn));
138 (void)it;
139 assert(it.second && "native create function with the given name has "
140 "already been registered");
141 }
registerRewriteFunction(StringRef name,PDLRewriteFunction rewriteFn)142 void PDLPatternModule::registerRewriteFunction(StringRef name,
143 PDLRewriteFunction rewriteFn) {
144 auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn));
145 (void)it;
146 assert(it.second && "native rewrite function with the given name has "
147 "already been registered");
148 }
149
150 //===----------------------------------------------------------------------===//
151 // PatternRewriter
152 //===----------------------------------------------------------------------===//
153
~PatternRewriter()154 PatternRewriter::~PatternRewriter() {
155 // Out of line to provide a vtable anchor for the class.
156 }
157
158 /// This method performs the final replacement for a pattern, where the
159 /// results of the operation are updated to use the specified list of SSA
160 /// values.
replaceOp(Operation * op,ValueRange newValues)161 void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
162 // Notify the rewriter subclass that we're about to replace this root.
163 notifyRootReplaced(op);
164
165 assert(op->getNumResults() == newValues.size() &&
166 "incorrect # of replacement values");
167 op->replaceAllUsesWith(newValues);
168
169 notifyOperationRemoved(op);
170 op->erase();
171 }
172
173 /// This method erases an operation that is known to have no uses. The uses of
174 /// the given operation *must* be known to be dead.
eraseOp(Operation * op)175 void PatternRewriter::eraseOp(Operation *op) {
176 assert(op->use_empty() && "expected 'op' to have no uses");
177 notifyOperationRemoved(op);
178 op->erase();
179 }
180
eraseBlock(Block * block)181 void PatternRewriter::eraseBlock(Block *block) {
182 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
183 assert(op.use_empty() && "expected 'op' to have no uses");
184 eraseOp(&op);
185 }
186 block->erase();
187 }
188
189 /// Merge the operations of block 'source' into the end of block 'dest'.
190 /// 'source's predecessors must be empty or only contain 'dest`.
191 /// 'argValues' is used to replace the block arguments of 'source' after
192 /// merging.
mergeBlocks(Block * source,Block * dest,ValueRange argValues)193 void PatternRewriter::mergeBlocks(Block *source, Block *dest,
194 ValueRange argValues) {
195 assert(llvm::all_of(source->getPredecessors(),
196 [dest](Block *succ) { return succ == dest; }) &&
197 "expected 'source' to have no predecessors or only 'dest'");
198 assert(argValues.size() == source->getNumArguments() &&
199 "incorrect # of argument replacement values");
200
201 // Replace all of the successor arguments with the provided values.
202 for (auto it : llvm::zip(source->getArguments(), argValues))
203 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
204
205 // Splice the operations of the 'source' block into the 'dest' block and erase
206 // it.
207 dest->getOperations().splice(dest->end(), source->getOperations());
208 source->dropAllUses();
209 source->erase();
210 }
211
212 // Merge the operations of block 'source' before the operation 'op'. Source
213 // block should not have existing predecessors or successors.
mergeBlockBefore(Block * source,Operation * op,ValueRange argValues)214 void PatternRewriter::mergeBlockBefore(Block *source, Operation *op,
215 ValueRange argValues) {
216 assert(source->hasNoPredecessors() &&
217 "expected 'source' to have no predecessors");
218 assert(source->hasNoSuccessors() &&
219 "expected 'source' to have no successors");
220
221 // Split the block containing 'op' into two, one containing all operations
222 // before 'op' (prologue) and another (epilogue) containing 'op' and all
223 // operations after it.
224 Block *prologue = op->getBlock();
225 Block *epilogue = splitBlock(prologue, op->getIterator());
226
227 // Merge the source block at the end of the prologue.
228 mergeBlocks(source, prologue, argValues);
229
230 // Merge the epilogue at the end the prologue.
231 mergeBlocks(epilogue, prologue);
232 }
233
234 /// Split the operations starting at "before" (inclusive) out of the given
235 /// block into a new block, and return it.
splitBlock(Block * block,Block::iterator before)236 Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) {
237 return block->splitBlock(before);
238 }
239
240 /// 'op' and 'newOp' are known to have the same number of results, replace the
241 /// uses of op with uses of newOp
replaceOpWithResultsOfAnotherOp(Operation * op,Operation * newOp)242 void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op,
243 Operation *newOp) {
244 assert(op->getNumResults() == newOp->getNumResults() &&
245 "replacement op doesn't match results of original op");
246 if (op->getNumResults() == 1)
247 return replaceOp(op, newOp->getResult(0));
248 return replaceOp(op, newOp->getResults());
249 }
250
251 /// Move the blocks that belong to "region" before the given position in
252 /// another region. The two regions must be different. The caller is in
253 /// charge to update create the operation transferring the control flow to the
254 /// region and pass it the correct block arguments.
inlineRegionBefore(Region & region,Region & parent,Region::iterator before)255 void PatternRewriter::inlineRegionBefore(Region ®ion, Region &parent,
256 Region::iterator before) {
257 parent.getBlocks().splice(before, region.getBlocks());
258 }
inlineRegionBefore(Region & region,Block * before)259 void PatternRewriter::inlineRegionBefore(Region ®ion, Block *before) {
260 inlineRegionBefore(region, *before->getParent(), before->getIterator());
261 }
262
263 /// Clone the blocks that belong to "region" before the given position in
264 /// another region "parent". The two regions must be different. The caller is
265 /// responsible for creating or updating the operation transferring flow of
266 /// control to the region and passing it the correct block arguments.
cloneRegionBefore(Region & region,Region & parent,Region::iterator before,BlockAndValueMapping & mapping)267 void PatternRewriter::cloneRegionBefore(Region ®ion, Region &parent,
268 Region::iterator before,
269 BlockAndValueMapping &mapping) {
270 region.cloneInto(&parent, before, mapping);
271 }
cloneRegionBefore(Region & region,Region & parent,Region::iterator before)272 void PatternRewriter::cloneRegionBefore(Region ®ion, Region &parent,
273 Region::iterator before) {
274 BlockAndValueMapping mapping;
275 cloneRegionBefore(region, parent, before, mapping);
276 }
cloneRegionBefore(Region & region,Block * before)277 void PatternRewriter::cloneRegionBefore(Region ®ion, Block *before) {
278 cloneRegionBefore(region, *before->getParent(), before->getIterator());
279 }
280
281