1 //===- RegionUtils.cpp - Region-related transformation utilities ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Transforms/RegionUtils.h"
10 #include "mlir/IR/Block.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/IR/RegionGraphTraits.h"
13 #include "mlir/IR/Value.h"
14 #include "mlir/Interfaces/ControlFlowInterfaces.h"
15 #include "mlir/Interfaces/SideEffectInterfaces.h"
16
17 #include "llvm/ADT/DepthFirstIterator.h"
18 #include "llvm/ADT/PostOrderIterator.h"
19 #include "llvm/ADT/SmallSet.h"
20
21 using namespace mlir;
22
replaceAllUsesInRegionWith(Value orig,Value replacement,Region & region)23 void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
24 Region ®ion) {
25 for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
26 if (region.isAncestor(use.getOwner()->getParentRegion()))
27 use.set(replacement);
28 }
29 }
30
visitUsedValuesDefinedAbove(Region & region,Region & limit,function_ref<void (OpOperand *)> callback)31 void mlir::visitUsedValuesDefinedAbove(
32 Region ®ion, Region &limit, function_ref<void(OpOperand *)> callback) {
33 assert(limit.isAncestor(®ion) &&
34 "expected isolation limit to be an ancestor of the given region");
35
36 // Collect proper ancestors of `limit` upfront to avoid traversing the region
37 // tree for every value.
38 SmallPtrSet<Region *, 4> properAncestors;
39 for (auto *reg = limit.getParentRegion(); reg != nullptr;
40 reg = reg->getParentRegion()) {
41 properAncestors.insert(reg);
42 }
43
44 region.walk([callback, &properAncestors](Operation *op) {
45 for (OpOperand &operand : op->getOpOperands())
46 // Callback on values defined in a proper ancestor of region.
47 if (properAncestors.count(operand.get().getParentRegion()))
48 callback(&operand);
49 });
50 }
51
visitUsedValuesDefinedAbove(MutableArrayRef<Region> regions,function_ref<void (OpOperand *)> callback)52 void mlir::visitUsedValuesDefinedAbove(
53 MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
54 for (Region ®ion : regions)
55 visitUsedValuesDefinedAbove(region, region, callback);
56 }
57
getUsedValuesDefinedAbove(Region & region,Region & limit,llvm::SetVector<Value> & values)58 void mlir::getUsedValuesDefinedAbove(Region ®ion, Region &limit,
59 llvm::SetVector<Value> &values) {
60 visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
61 values.insert(operand->get());
62 });
63 }
64
getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,llvm::SetVector<Value> & values)65 void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
66 llvm::SetVector<Value> &values) {
67 for (Region ®ion : regions)
68 getUsedValuesDefinedAbove(region, region, values);
69 }
70
71 //===----------------------------------------------------------------------===//
72 // Unreachable Block Elimination
73 //===----------------------------------------------------------------------===//
74
/// Erase the unreachable blocks within the provided regions. Returns success
/// if any blocks were erased, failure otherwise. Nested regions of live
/// blocks are processed via an explicit worklist rather than recursion.
// TODO: We could likely merge this with the DCE algorithm below.
static LogicalResult eraseUnreachableBlocks(MutableArrayRef<Region> regions) {
  // Set of blocks found to be reachable within a given region. Reused (and
  // cleared) across regions to avoid reallocation.
  llvm::df_iterator_default_set<Block *, 16> reachable;
  // If any blocks were found to be dead.
  bool erasedDeadBlocks = false;

  SmallVector<Region *, 1> worklist;
  worklist.reserve(regions.size());
  for (Region &region : regions)
    worklist.push_back(&region);
  while (!worklist.empty()) {
    Region *region = worklist.pop_back_val();
    if (region->empty())
      continue;

    // If this is a single block region, just collect the nested regions.
    // (A lone block is trivially reachable, so no DFS is needed.)
    if (std::next(region->begin()) == region->end()) {
      for (Operation &op : region->front())
        for (Region &region : op.getRegions())
          worklist.push_back(&region);
      continue;
    }

    // Mark all reachable blocks. The DFS populates `reachable` as a side
    // effect of iteration; the loop body itself does nothing.
    reachable.clear();
    for (Block *block : depth_first_ext(&region->front(), reachable))
      (void)block /* Mark all reachable blocks */;

    // Collect all of the dead blocks and push the live regions onto the
    // worklist.
    for (Block &block : llvm::make_early_inc_range(*region)) {
      if (!reachable.count(&block)) {
        // Dead blocks may reference each other (e.g. in a dead cycle), so
        // drop uses of this block's values before erasing it.
        block.dropAllDefinedValueUses();
        block.erase();
        erasedDeadBlocks = true;
        continue;
      }

      // Walk any regions within this block.
      for (Operation &op : block)
        for (Region &region : op.getRegions())
          worklist.push_back(&region);
    }
  }

  return success(erasedDeadBlocks);
}
125
126 //===----------------------------------------------------------------------===//
127 // Dead Code Elimination
128 //===----------------------------------------------------------------------===//
129
namespace {
/// Data structure used to track which values have already been proved live.
///
/// Because Operation's can have multiple results, this data structure tracks
/// liveness for both Value's and Operation's to avoid having to look through
/// all Operation results when analyzing a use.
///
/// This data structure essentially tracks the dataflow lattice.
/// The set of values/ops proved live increases monotonically to a fixed-point.
class LiveMap {
public:
  /// Value methods.
  /// Returns true if `value` has already been proved live.
  bool wasProvenLive(Value value) { return liveValues.count(value); }
  /// Marks `value` as live; records whether this changed the lattice.
  void setProvedLive(Value value) {
    changed |= liveValues.insert(value).second;
  }

  /// Operation methods.
  /// Returns true if `op` has already been proved live.
  bool wasProvenLive(Operation *op) { return liveOps.count(op); }
  /// Marks `op` as live; records whether this changed the lattice.
  void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }

  /// Methods for tracking if we have reached a fixed-point.
  /// Clears the change flag at the start of a propagation round.
  void resetChanged() { changed = false; }
  /// Returns true if any value/op became live since the last reset.
  bool hasChanged() { return changed; }

private:
  /// Whether the last round of propagation proved anything new live.
  bool changed = false;
  /// The set of values proved live so far.
  DenseSet<Value> liveValues;
  /// The set of operations proved live so far.
  DenseSet<Operation *> liveOps;
};
} // namespace
161
/// Returns true if `use` can be proven dead even though its owner op may be
/// live: this is the case for a terminator's successor operand that only
/// feeds a block argument already known to be dead.
static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
  Operation *owner = use.getOwner();
  unsigned operandIndex = use.getOperandNumber();
  // This pass generally treats all uses of an op as live if the op itself is
  // considered live. However, for successor operands to terminators we need a
  // finer-grained notion where we deduce liveness for operands individually.
  // The reason for this is easiest to think about in terms of a classical phi
  // node based SSA IR, where each successor operand is really an operand to a
  // *separate* phi node, rather than all operands to the branch itself as with
  // the block argument representation that MLIR uses.
  //
  // And similarly, because each successor operand is really an operand to a phi
  // node, rather than to the terminator op itself, a terminator op can't e.g.
  // "print" the value of a successor operand.
  if (owner->isKnownTerminator()) {
    // Only terminators implementing BranchOpInterface can map an operand to a
    // specific successor block argument; anything else is conservatively
    // treated as live.
    if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
      if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
        return !liveMap.wasProvenLive(*arg);
    return false;
  }
  return false;
}
184
processValue(Value value,LiveMap & liveMap)185 static void processValue(Value value, LiveMap &liveMap) {
186 bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) {
187 if (isUseSpeciallyKnownDead(use, liveMap))
188 return false;
189 return liveMap.wasProvenLive(use.getOwner());
190 });
191 if (provedLive)
192 liveMap.setProvedLive(value);
193 }
194
isOpIntrinsicallyLive(Operation * op)195 static bool isOpIntrinsicallyLive(Operation *op) {
196 // This pass doesn't modify the CFG, so terminators are never deleted.
197 if (!op->isKnownNonTerminator())
198 return true;
199 // If the op has a side effect, we treat it as live.
200 // TODO: Properly handle region side effects.
201 return !MemoryEffectOpInterface::hasNoEffect(op) || op->getNumRegions() != 0;
202 }
203
204 static void propagateLiveness(Region ®ion, LiveMap &liveMap);
205
/// Propagate liveness for a terminator op: the terminator itself is always
/// live, and successor block arguments are conservatively marked live unless
/// the terminator exposes mutable successor operands for that edge (in which
/// case argument liveness is deduced per-operand elsewhere).
static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
  // Terminators are always live.
  liveMap.setProvedLive(op);

  // Check to see if we can reason about the successor operands and mutate them.
  BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
  if (!branchInterface) {
    // Without BranchOpInterface we cannot tell which operand feeds which
    // argument, so every successor argument must be treated as live.
    for (Block *successor : op->getSuccessors())
      for (BlockArgument arg : successor->getArguments())
        liveMap.setProvedLive(arg);
    return;
  }

  // If we can't reason about the operands to a successor, conservatively mark
  // all arguments as live.
  for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
    if (!branchInterface.getMutableSuccessorOperands(i))
      for (BlockArgument arg : op->getSuccessor(i)->getArguments())
        liveMap.setProvedLive(arg);
  }
}
227
/// Propagate liveness through a single operation: recurse into its regions,
/// then mark the op live if it is intrinsically live or if any of its results
/// was proven live.
static void propagateLiveness(Operation *op, LiveMap &liveMap) {
  // All Value's are either a block argument or an op result.
  // We call processValue on those cases.

  // Recurse on any regions the op has. Regions are visited before the op
  // itself so that uses inside the regions are seen before this def.
  for (Region &region : op->getRegions())
    propagateLiveness(region, liveMap);

  // Process terminator operations.
  if (op->isKnownTerminator())
    return propagateTerminatorLiveness(op, liveMap);

  // Process the op itself.
  if (isOpIntrinsicallyLive(op)) {
    liveMap.setProvedLive(op);
    return;
  }
  // First update per-result liveness from their uses, then derive the op's
  // liveness from its results: an op is live iff any result is live.
  for (Value value : op->getResults())
    processValue(value, liveMap);
  bool provedLive = llvm::any_of(op->getResults(), [&](Value value) {
    return liveMap.wasProvenLive(value);
  });
  if (provedLive)
    liveMap.setProvedLive(op);
}
253
propagateLiveness(Region & region,LiveMap & liveMap)254 static void propagateLiveness(Region ®ion, LiveMap &liveMap) {
255 if (region.empty())
256 return;
257
258 for (Block *block : llvm::post_order(®ion.front())) {
259 // We process block arguments after the ops in the block, to promote
260 // faster convergence to a fixed point (we try to visit uses before defs).
261 for (Operation &op : llvm::reverse(block->getOperations()))
262 propagateLiveness(&op, liveMap);
263 for (Value value : block->getArguments())
264 processValue(value, liveMap);
265 }
266 }
267
/// Erase the successor operands of `terminator` that feed block arguments not
/// proven live, so those arguments become use-free and can be deleted. No-op
/// for terminators that don't implement BranchOpInterface.
static void eraseTerminatorSuccessorOperands(Operation *terminator,
                                             LiveMap &liveMap) {
  BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
  if (!branchOp)
    return;

  for (unsigned succI = 0, succE = terminator->getNumSuccessors();
       succI < succE; succI++) {
    // Iterating successors in reverse is not strictly needed, since we
    // aren't erasing any successors. But it is slightly more efficient
    // since it will promote later operands of the terminator being erased
    // first, reducing the quadratic-ness.
    unsigned succ = succE - succI - 1;
    Optional<MutableOperandRange> succOperands =
        branchOp.getMutableSuccessorOperands(succ);
    // Skip successors whose operands cannot be mutated through the interface.
    if (!succOperands)
      continue;
    Block *successor = terminator->getSuccessor(succ);

    for (unsigned argI = 0, argE = succOperands->size(); argI < argE; ++argI) {
      // Iterating args in reverse is needed for correctness, to avoid
      // shifting later args when earlier args are erased.
      unsigned arg = argE - argI - 1;
      if (!liveMap.wasProvenLive(successor->getArgument(arg)))
        succOperands->erase(arg);
    }
  }
}
296
/// Erase every op and every non-entry block argument in `regions` that was
/// not proven live by `liveMap`, recursing into nested regions. Dead
/// successor operands are erased first so dead block arguments end up
/// use-free. Returns success if anything was erased, failure otherwise.
static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
                                    LiveMap &liveMap) {
  bool erasedAnything = false;
  for (Region &region : regions) {
    if (region.empty())
      continue;

    // We do the deletion in an order that deletes all uses before deleting
    // defs.
    // MLIR's SSA structural invariants guarantee that except for block
    // arguments, the use-def graph is acyclic, so this is possible with a
    // single walk of ops and then a final pass to clean up block arguments.
    //
    // To do this, we visit ops in an order that visits domtree children
    // before domtree parents. A CFG post-order (with reverse iteration with a
    // block) satisfies that without needing an explicit domtree calculation.
    for (Block *block : llvm::post_order(&region.front())) {
      // Trim dead successor operands before erasing ops, so block arguments
      // they fed lose their remaining uses.
      eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
      for (Operation &childOp :
           llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
        // Recurse into nested regions before deciding the fate of the op
        // itself (uses before defs).
        erasedAnything |=
            succeeded(deleteDeadness(childOp.getRegions(), liveMap));
        if (!liveMap.wasProvenLive(&childOp)) {
          erasedAnything = true;
          childOp.erase();
        }
      }
    }
    // Delete block arguments.
    // The entry block has an unknown contract with their enclosing block, so
    // skip it.
    for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
      // Iterate in reverse to avoid shifting later arguments when deleting
      // earlier arguments.
      for (unsigned i = 0, e = block.getNumArguments(); i < e; i++)
        if (!liveMap.wasProvenLive(block.getArgument(e - i - 1))) {
          block.eraseArgument(e - i - 1);
          erasedAnything = true;
        }
    }
  }
  return success(erasedAnything);
}
340
341 // This function performs a simple dead code elimination algorithm over the
342 // given regions.
343 //
344 // The overall goal is to prove that Values are dead, which allows deleting ops
345 // and block arguments.
346 //
347 // This uses an optimistic algorithm that assumes everything is dead until
348 // proved otherwise, allowing it to delete recursively dead cycles.
349 //
350 // This is a simple fixed-point dataflow analysis algorithm on a lattice
351 // {Dead,Alive}. Because liveness flows backward, we generally try to
352 // iterate everything backward to speed up convergence to the fixed-point. This
353 // allows for being able to delete recursively dead cycles of the use-def graph,
354 // including block arguments.
355 //
356 // This function returns success if any operations or arguments were deleted,
357 // failure otherwise.
runRegionDCE(MutableArrayRef<Region> regions)358 static LogicalResult runRegionDCE(MutableArrayRef<Region> regions) {
359 LiveMap liveMap;
360 do {
361 liveMap.resetChanged();
362
363 for (Region ®ion : regions)
364 propagateLiveness(region, liveMap);
365 } while (liveMap.hasChanged());
366
367 return deleteDeadness(regions, liveMap);
368 }
369
370 //===----------------------------------------------------------------------===//
371 // Block Merging
372 //===----------------------------------------------------------------------===//
373
374 //===----------------------------------------------------------------------===//
375 // BlockEquivalenceData
376
namespace {
/// This class contains the information for comparing the equivalencies of two
/// blocks. Blocks are considered equivalent if they contain the same operations
/// in the same order. The only allowed divergence is for operands that come
/// from sources outside of the parent block, i.e. the uses of values produced
/// within the block must be equivalent.
/// e.g.,
/// Equivalent:
///  ^bb1(%arg0: i32)
///    return %arg0, %foo : i32, i32
///  ^bb2(%arg1: i32)
///    return %arg1, %bar : i32, i32
/// Not Equivalent:
///  ^bb1(%arg0: i32)
///    return %foo, %arg0 : i32, i32
///  ^bb2(%arg1: i32)
///    return %arg1, %bar : i32, i32
struct BlockEquivalenceData {
  BlockEquivalenceData(Block *block);

  /// Return the order index for the given value that is within the block of
  /// this data.
  unsigned getOrderOf(Value value) const;

  /// The block this data refers to.
  Block *block;
  /// A hash value for this block, combined from per-op hashes that ignore
  /// operands. Equal hashes are a necessary but not sufficient condition for
  /// equivalence; a full comparison is still required.
  llvm::hash_code hash;
  /// A map of result producing operations to their relative orders within this
  /// block. The order of an operation is the number of defined values that are
  /// produced within the block before this operation.
  DenseMap<Operation *, unsigned> opOrderIndex;
};
} // end anonymous namespace
411
BlockEquivalenceData(Block * block)412 BlockEquivalenceData::BlockEquivalenceData(Block *block)
413 : block(block), hash(0) {
414 unsigned orderIt = block->getNumArguments();
415 for (Operation &op : *block) {
416 if (unsigned numResults = op.getNumResults()) {
417 opOrderIndex.try_emplace(&op, orderIt);
418 orderIt += numResults;
419 }
420 auto opHash = OperationEquivalence::computeHash(
421 &op, OperationEquivalence::Flags::IgnoreOperands);
422 hash = llvm::hash_combine(hash, opHash);
423 }
424 }
425
getOrderOf(Value value) const426 unsigned BlockEquivalenceData::getOrderOf(Value value) const {
427 assert(value.getParentBlock() == block && "expected value of this block");
428
429 // Arguments use the argument number as the order index.
430 if (BlockArgument arg = value.dyn_cast<BlockArgument>())
431 return arg.getArgNumber();
432
433 // Otherwise, the result order is offset from the parent op's order.
434 OpResult result = value.cast<OpResult>();
435 auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
436 assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
437 return opOrderIt->second + result.getResultNumber();
438 }
439
440 //===----------------------------------------------------------------------===//
441 // BlockMergeCluster
442
namespace {
/// This class represents a cluster of blocks to be merged together.
class BlockMergeCluster {
public:
  BlockMergeCluster(BlockEquivalenceData &&leaderData)
      : leaderData(std::move(leaderData)) {}

  /// Attempt to add the given block to this cluster. Returns success if the
  /// block was merged, failure otherwise.
  LogicalResult addToCluster(BlockEquivalenceData &blockData);

  /// Try to merge all of the blocks within this cluster into the leader block.
  LogicalResult merge();

private:
  /// The equivalence data for the leader of the cluster.
  BlockEquivalenceData leaderData;

  /// The set of blocks that can be merged into the leader.
  llvm::SmallSetVector<Block *, 1> blocksToMerge;

  /// A set of operand+index pairs that correspond to operands that need to be
  /// replaced by arguments when the cluster gets merged.
  /// Kept as an ordered std::set of (operation index, operand index) pairs:
  /// merge() advances block iterators by the difference between consecutive
  /// operation indices, which requires increasing order.
  std::set<std::pair<int, int>> operandsToMerge;
};
} // end anonymous namespace
469
/// Attempt to add `blockData`'s block to this cluster. The block is accepted
/// only if it is operation-for-operation equivalent to the leader: same
/// hash, same argument types, equivalent ops, and operands that either match,
/// come from outside both blocks (recorded for later argument insertion), or
/// occupy the same logical position within their respective blocks.
LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
  // Cheap rejections first: hash and argument type mismatches.
  if (leaderData.hash != blockData.hash)
    return failure();
  Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
  if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
    return failure();

  // A set of operands that mismatch between the leader and the new block.
  SmallVector<std::pair<int, int>, 8> mismatchedOperands;
  auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
  auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
  for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
    // Check that the operations are equivalent.
    if (!OperationEquivalence::isEquivalentTo(
            &*lhsIt, &*rhsIt, OperationEquivalence::Flags::IgnoreOperands))
      return failure();

    // Compare the operands of the two operations. If the operand is within
    // the block, it must refer to the same operation.
    auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
    for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
      Value lhsOperand = lhsOperands[operand];
      Value rhsOperand = rhsOperands[operand];
      if (lhsOperand == rhsOperand)
        continue;
      // Check that the types of the operands match.
      if (lhsOperand.getType() != rhsOperand.getType())
        return failure();

      // Check that these uses are both external, or both internal.
      bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
      bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
      if (lhsIsInBlock != rhsIsInBlock)
        return failure();
      // Let the operands differ if they are defined in a different block. These
      // will become new arguments if the blocks get merged.
      if (!lhsIsInBlock) {
        mismatchedOperands.emplace_back(opI, operand);
        continue;
      }

      // Otherwise, these operands must have the same logical order within the
      // parent block.
      if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
        return failure();
    }

    // If the lhs or rhs has external uses, the blocks cannot be merged as the
    // merged version of this operation will not be either the lhs or rhs
    // alone (thus semantically incorrect), but some mix dependening on which
    // block preceeded this.
    // TODO allow merging of operations when one block does not dominate the
    // other
    if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
        lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
      return failure();
    }
  }
  // Make sure that the block sizes are equivalent.
  if (lhsIt != lhsE || rhsIt != rhsE)
    return failure();

  // If we get here, the blocks are equivalent and can be merged. Only now do
  // we commit the mismatched operands into the cluster-wide set.
  operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
  blocksToMerge.insert(blockData.block);
  return success();
}
537
/// Returns true if the successor operands of all predecessor terminators of
/// `block` can be updated, i.e. every predecessor's terminator implements
/// BranchOpInterface and exposes mutable successor operands for the edge
/// into `block`; returns false otherwise.
/// (Note: the original comment stated the inverted condition.)
static bool ableToUpdatePredOperands(Block *block) {
  for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
    auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator());
    if (!branch || !branch.getMutableSuccessorOperands(it.getSuccessorIndex()))
      return false;
  }
  return true;
}
548
/// Merge every block in the cluster into the leader block: mismatched
/// operands become new leader block arguments, every predecessor terminator
/// is extended with the corresponding values, and the merged blocks are
/// replaced by the leader and erased. Returns success if a merge happened.
LogicalResult BlockMergeCluster::merge() {
  // Don't consider clusters that don't have blocks to merge.
  if (blocksToMerge.empty())
    return failure();

  Block *leaderBlock = leaderData.block;
  if (!operandsToMerge.empty()) {
    // If the cluster has operands to merge, verify that the predecessor
    // terminators of each of the blocks can have their successor operands
    // updated.
    // TODO: We could try and sub-partition this cluster if only some blocks
    // cause the mismatch.
    if (!ableToUpdatePredOperands(leaderBlock) ||
        !llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
      return failure();

    // Collect the iterators for each of the blocks to merge. We will walk all
    // of the iterators at once to avoid operand index invalidation.
    SmallVector<Block::iterator, 2> blockIterators;
    blockIterators.reserve(blocksToMerge.size() + 1);
    blockIterators.push_back(leaderBlock->begin());
    for (Block *mergeBlock : blocksToMerge)
      blockIterators.push_back(mergeBlock->begin());

    // Update each of the predecessor terminators with the new arguments.
    // newArguments[0] corresponds to the leader; [i + 1] to blocksToMerge[i].
    SmallVector<SmallVector<Value, 8>, 2> newArguments(
        1 + blocksToMerge.size(),
        SmallVector<Value, 8>(operandsToMerge.size()));
    unsigned curOpIndex = 0;
    // operandsToMerge is ordered by (op index, operand index), so the offset
    // between consecutive entries is non-negative and the iterators only ever
    // advance forward.
    for (auto it : llvm::enumerate(operandsToMerge)) {
      unsigned nextOpOffset = it.value().first - curOpIndex;
      curOpIndex = it.value().first;

      // Process the operand for each of the block iterators.
      for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
        Block::iterator &blockIter = blockIterators[i];
        std::advance(blockIter, nextOpOffset);
        auto &operand = blockIter->getOpOperand(it.value().second);
        newArguments[i][it.index()] = operand.get();

        // Update the operand and insert an argument if this is the leader.
        if (i == 0)
          operand.set(leaderBlock->addArgument(operand.get().getType()));
      }
    }
    // Update the predecessors for each of the blocks.
    auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
      for (auto predIt = block->pred_begin(), predE = block->pred_end();
           predIt != predE; ++predIt) {
        auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
        unsigned succIndex = predIt.getSuccessorIndex();
        branch.getMutableSuccessorOperands(succIndex)->append(
            newArguments[clusterIndex]);
      }
    };
    updatePredecessors(leaderBlock, /*clusterIndex=*/0);
    for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
      updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
  }

  // Replace all uses of the merged blocks with the leader and erase them.
  for (Block *block : blocksToMerge) {
    block->replaceAllUsesWith(leaderBlock);
    block->erase();
  }
  return success();
}
616
/// Identify identical blocks within the given region and merge them, inserting
/// new block arguments as necessary. Returns success if any blocks were merged,
/// failure otherwise.
static LogicalResult mergeIdenticalBlocks(Region &region) {
  // Nothing to merge in an empty or single-block region.
  if (region.empty() || llvm::hasSingleElement(region))
    return failure();

  // Identify sets of blocks, other than the entry block, that branch to the
  // same successors. We will use these groups to create clusters of equivalent
  // blocks.
  DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors;
  for (Block &block : llvm::drop_begin(region, 1))
    matchingSuccessors[block.getSuccessors()].push_back(&block);

  bool mergedAnyBlocks = false;
  for (ArrayRef<Block *> blocks : llvm::make_second_range(matchingSuccessors)) {
    if (blocks.size() == 1)
      continue;

    // Greedily assign each block to the first cluster that accepts it; blocks
    // no existing cluster accepts become leaders of new clusters.
    SmallVector<BlockMergeCluster, 1> clusters;
    for (Block *block : blocks) {
      BlockEquivalenceData data(block);

      // Don't allow merging if this block has any regions.
      // TODO: Add support for regions if necessary.
      bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) {
        return llvm::any_of(op.getRegions(),
                            [](Region &region) { return !region.empty(); });
      });
      if (hasNonEmptyRegion)
        continue;

      // Try to add this block to an existing cluster.
      bool addedToCluster = false;
      for (auto &cluster : clusters)
        if ((addedToCluster = succeeded(cluster.addToCluster(data))))
          break;
      if (!addedToCluster)
        clusters.emplace_back(std::move(data));
    }
    for (auto &cluster : clusters)
      mergedAnyBlocks |= succeeded(cluster.merge());
  }

  return success(mergedAnyBlocks);
}
663
664 /// Identify identical blocks within the given regions and merge them, inserting
665 /// new block arguments as necessary.
mergeIdenticalBlocks(MutableArrayRef<Region> regions)666 static LogicalResult mergeIdenticalBlocks(MutableArrayRef<Region> regions) {
667 llvm::SmallSetVector<Region *, 1> worklist;
668 for (auto ®ion : regions)
669 worklist.insert(®ion);
670 bool anyChanged = false;
671 while (!worklist.empty()) {
672 Region *region = worklist.pop_back_val();
673 if (succeeded(mergeIdenticalBlocks(*region))) {
674 worklist.insert(region);
675 anyChanged = true;
676 }
677
678 // Add any nested regions to the worklist.
679 for (Block &block : *region)
680 for (auto &op : block)
681 for (auto &nestedRegion : op.getRegions())
682 worklist.insert(&nestedRegion);
683 }
684
685 return success(anyChanged);
686 }
687
688 //===----------------------------------------------------------------------===//
689 // Region Simplification
690 //===----------------------------------------------------------------------===//
691
692 /// Run a set of structural simplifications over the given regions. This
693 /// includes transformations like unreachable block elimination, dead argument
694 /// elimination, as well as some other DCE. This function returns success if any
695 /// of the regions were simplified, failure otherwise.
simplifyRegions(MutableArrayRef<Region> regions)696 LogicalResult mlir::simplifyRegions(MutableArrayRef<Region> regions) {
697 bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(regions));
698 bool eliminatedOpsOrArgs = succeeded(runRegionDCE(regions));
699 bool mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(regions));
700 return success(eliminatedBlocks || eliminatedOpsOrArgs ||
701 mergedIdenticalBlocks);
702 }
703