1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass performs a sparse conditional constant propagation
10 // in MLIR. It identifies values known to be constant, propagates that
11 // information throughout the IR, and replaces them. This is done with an
12 // optimistic dataflow analysis that assumes that all values are constant until
13 // proven otherwise.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "PassDetail.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/Interfaces/ControlFlowInterfaces.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/FoldUtils.h"
24 #include "mlir/Transforms/Passes.h"
25
26 using namespace mlir;
27
28 namespace {
29 /// This class represents a single lattice value. A lattive value corresponds to
30 /// the various different states that a value in the SCCP dataflow analysis can
31 /// take. See 'Kind' below for more details on the different states a value can
32 /// take.
33 class LatticeValue {
34 enum Kind {
35 /// A value with a yet to be determined value. This state may be changed to
36 /// anything.
37 Unknown,
38
39 /// A value that is known to be a constant. This state may be changed to
40 /// overdefined.
41 Constant,
42
43 /// A value that cannot statically be determined to be a constant. This
44 /// state cannot be changed.
45 Overdefined
46 };
47
48 public:
49 /// Initialize a lattice value with "Unknown".
LatticeValue()50 LatticeValue()
51 : constantAndTag(nullptr, Kind::Unknown), constantDialect(nullptr) {}
52 /// Initialize a lattice value with a constant.
LatticeValue(Attribute attr,Dialect * dialect)53 LatticeValue(Attribute attr, Dialect *dialect)
54 : constantAndTag(attr, Kind::Constant), constantDialect(dialect) {}
55
56 /// Returns true if this lattice value is unknown.
isUnknown() const57 bool isUnknown() const { return constantAndTag.getInt() == Kind::Unknown; }
58
59 /// Mark the lattice value as overdefined.
markOverdefined()60 void markOverdefined() {
61 constantAndTag.setPointerAndInt(nullptr, Kind::Overdefined);
62 constantDialect = nullptr;
63 }
64
65 /// Returns true if the lattice is overdefined.
isOverdefined() const66 bool isOverdefined() const {
67 return constantAndTag.getInt() == Kind::Overdefined;
68 }
69
70 /// Mark the lattice value as constant.
markConstant(Attribute value,Dialect * dialect)71 void markConstant(Attribute value, Dialect *dialect) {
72 constantAndTag.setPointerAndInt(value, Kind::Constant);
73 constantDialect = dialect;
74 }
75
76 /// If this lattice is constant, return the constant. Returns nullptr
77 /// otherwise.
getConstant() const78 Attribute getConstant() const { return constantAndTag.getPointer(); }
79
80 /// If this lattice is constant, return the dialect to use when materializing
81 /// the constant.
getConstantDialect() const82 Dialect *getConstantDialect() const {
83 assert(getConstant() && "expected valid constant");
84 return constantDialect;
85 }
86
87 /// Merge in the value of the 'rhs' lattice into this one. Returns true if the
88 /// lattice value changed.
meet(const LatticeValue & rhs)89 bool meet(const LatticeValue &rhs) {
90 // If we are already overdefined, or rhs is unknown, there is nothing to do.
91 if (isOverdefined() || rhs.isUnknown())
92 return false;
93 // If we are unknown, just take the value of rhs.
94 if (isUnknown()) {
95 constantAndTag = rhs.constantAndTag;
96 constantDialect = rhs.constantDialect;
97 return true;
98 }
99
100 // Otherwise, if this value doesn't match rhs go straight to overdefined.
101 if (constantAndTag != rhs.constantAndTag) {
102 markOverdefined();
103 return true;
104 }
105 return false;
106 }
107
108 private:
109 /// The attribute value if this is a constant and the tag for the element
110 /// kind.
111 llvm::PointerIntPair<Attribute, 2, Kind> constantAndTag;
112
113 /// The dialect the constant originated from. This is only valid if the
114 /// lattice is a constant. This is not used as part of the key, and is only
115 /// needed to materialize the held constant if necessary.
116 Dialect *constantDialect;
117 };
118
119 /// This class contains various state used when computing the lattice of a
120 /// callable operation.
121 class CallableLatticeState {
122 public:
123 /// Build a lattice state with a given callable region, and a specified number
124 /// of results to be initialized to the default lattice value (Unknown).
CallableLatticeState(Region * callableRegion,unsigned numResults)125 CallableLatticeState(Region *callableRegion, unsigned numResults)
126 : callableArguments(callableRegion->getArguments()),
127 resultLatticeValues(numResults) {}
128
129 /// Returns the arguments to the callable region.
getCallableArguments() const130 Block::BlockArgListType getCallableArguments() const {
131 return callableArguments;
132 }
133
134 /// Returns the lattice value for the results of the callable region.
getResultLatticeValues()135 MutableArrayRef<LatticeValue> getResultLatticeValues() {
136 return resultLatticeValues;
137 }
138
139 /// Add a call to this callable. This is only used if the callable defines a
140 /// symbol.
addSymbolCall(Operation * op)141 void addSymbolCall(Operation *op) { symbolCalls.push_back(op); }
142
143 /// Return the calls that reference this callable. This is only used
144 /// if the callable defines a symbol.
getSymbolCalls() const145 ArrayRef<Operation *> getSymbolCalls() const { return symbolCalls; }
146
147 private:
148 /// The arguments of the callable region.
149 Block::BlockArgListType callableArguments;
150
151 /// The lattice state for each of the results of this region. The return
152 /// values of the callable aren't SSA values, so we need to track them
153 /// separately.
154 SmallVector<LatticeValue, 4> resultLatticeValues;
155
156 /// The calls referencing this callable if this callable defines a symbol.
157 /// This removes the need to recompute symbol references during propagation.
158 /// Value based references are trivial to resolve, so they can be done
159 /// in-place.
160 SmallVector<Operation *, 4> symbolCalls;
161 };
162
163 /// This class represents the solver for the SCCP analysis. This class acts as
164 /// the propagation engine for computing which values form constants.
165 class SCCPSolver {
166 public:
167 /// Initialize the solver with the given top-level operation.
168 SCCPSolver(Operation *op);
169
170 /// Run the solver until it converges.
171 void solve();
172
173 /// Rewrite the given regions using the computing analysis. This replaces the
174 /// uses of all values that have been computed to be constant, and erases as
175 /// many newly dead operations.
176 void rewrite(MLIRContext *context, MutableArrayRef<Region> regions);
177
178 private:
179 /// Initialize the set of symbol defining callables that can have their
180 /// arguments and results tracked. 'op' is the top-level operation that SCCP
181 /// is operating on.
182 void initializeSymbolCallables(Operation *op);
183
184 /// Replace the given value with a constant if the corresponding lattice
185 /// represents a constant. Returns success if the value was replaced, failure
186 /// otherwise.
187 LogicalResult replaceWithConstant(OpBuilder &builder, OperationFolder &folder,
188 Value value);
189
190 /// Visit the users of the given IR that reside within executable blocks.
191 template <typename T>
visitUsers(T & value)192 void visitUsers(T &value) {
193 for (Operation *user : value.getUsers())
194 if (isBlockExecutable(user->getBlock()))
195 visitOperation(user);
196 }
197
198 /// Visit the given operation and compute any necessary lattice state.
199 void visitOperation(Operation *op);
200
201 /// Visit the given call operation and compute any necessary lattice state.
202 void visitCallOperation(CallOpInterface op);
203
204 /// Visit the given callable operation and compute any necessary lattice
205 /// state.
206 void visitCallableOperation(Operation *op);
207
208 /// Visit the given operation, which defines regions, and compute any
209 /// necessary lattice state. This also resolves the lattice state of both the
210 /// operation results and any nested regions.
211 void visitRegionOperation(Operation *op,
212 ArrayRef<Attribute> constantOperands);
213
214 /// Visit the given set of region successors, computing any necessary lattice
215 /// state. The provided function returns the input operands to the region at
216 /// the given index. If the index is 'None', the input operands correspond to
217 /// the parent operation results.
218 void visitRegionSuccessors(
219 Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
220 function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion);
221
222 /// Visit the given terminator operation and compute any necessary lattice
223 /// state.
224 void visitTerminatorOperation(Operation *op,
225 ArrayRef<Attribute> constantOperands);
226
227 /// Visit the given terminator operation that exits a callable region. These
228 /// are terminators with no CFG successors.
229 void visitCallableTerminatorOperation(Operation *callable,
230 Operation *terminator);
231
232 /// Visit the given block and compute any necessary lattice state.
233 void visitBlock(Block *block);
234
235 /// Visit argument #'i' of the given block and compute any necessary lattice
236 /// state.
237 void visitBlockArgument(Block *block, int i);
238
239 /// Mark the given block as executable. Returns false if the block was already
240 /// marked executable.
241 bool markBlockExecutable(Block *block);
242
243 /// Returns true if the given block is executable.
244 bool isBlockExecutable(Block *block) const;
245
246 /// Mark the edge between 'from' and 'to' as executable.
247 void markEdgeExecutable(Block *from, Block *to);
248
249 /// Return true if the edge between 'from' and 'to' is executable.
250 bool isEdgeExecutable(Block *from, Block *to) const;
251
252 /// Mark the given value as overdefined. This means that we cannot refine a
253 /// specific constant for this value.
254 void markOverdefined(Value value);
255
256 /// Mark all of the given values as overdefined.
257 template <typename ValuesT>
markAllOverdefined(ValuesT values)258 void markAllOverdefined(ValuesT values) {
259 for (auto value : values)
260 markOverdefined(value);
261 }
262 template <typename ValuesT>
markAllOverdefined(Operation * op,ValuesT values)263 void markAllOverdefined(Operation *op, ValuesT values) {
264 markAllOverdefined(values);
265 opWorklist.push_back(op);
266 }
267 template <typename ValuesT>
markAllOverdefinedAndVisitUsers(ValuesT values)268 void markAllOverdefinedAndVisitUsers(ValuesT values) {
269 for (auto value : values) {
270 auto &lattice = latticeValues[value];
271 if (!lattice.isOverdefined()) {
272 lattice.markOverdefined();
273 visitUsers(value);
274 }
275 }
276 }
277
278 /// Returns true if the given value was marked as overdefined.
279 bool isOverdefined(Value value) const;
280
281 /// Merge in the given lattice 'from' into the lattice 'to'. 'owner'
282 /// corresponds to the parent operation of 'to'.
283 void meet(Operation *owner, LatticeValue &to, const LatticeValue &from);
284
285 /// The lattice for each SSA value.
286 DenseMap<Value, LatticeValue> latticeValues;
287
288 /// The set of blocks that are known to execute, or are intrinsically live.
289 SmallPtrSet<Block *, 16> executableBlocks;
290
291 /// The set of control flow edges that are known to execute.
292 DenseSet<std::pair<Block *, Block *>> executableEdges;
293
294 /// A worklist containing blocks that need to be processed.
295 SmallVector<Block *, 64> blockWorklist;
296
297 /// A worklist of operations that need to be processed.
298 SmallVector<Operation *, 64> opWorklist;
299
300 /// The callable operations that have their argument/result state tracked.
301 DenseMap<Operation *, CallableLatticeState> callableLatticeState;
302
303 /// A map between a call operation and the resolved symbol callable. This
304 /// avoids re-resolving symbol references during propagation. Value based
305 /// callables are trivial to resolve, so they can be done in-place.
306 DenseMap<Operation *, Operation *> callToSymbolCallable;
307
308 /// A symbol table used for O(1) symbol lookups during simplification.
309 SymbolTableCollection symbolTable;
310 };
311 } // end anonymous namespace
312
SCCPSolver(Operation * op)313 SCCPSolver::SCCPSolver(Operation *op) {
314 /// Initialize the solver with the regions within this operation.
315 for (Region ®ion : op->getRegions()) {
316 if (region.empty())
317 continue;
318 Block *entryBlock = ®ion.front();
319
320 // Mark the entry block as executable.
321 markBlockExecutable(entryBlock);
322
323 // The values passed to these regions are invisible, so mark any arguments
324 // as overdefined.
325 markAllOverdefined(entryBlock->getArguments());
326 }
327 initializeSymbolCallables(op);
328 }
329
solve()330 void SCCPSolver::solve() {
331 while (!blockWorklist.empty() || !opWorklist.empty()) {
332 // Process any operations in the op worklist.
333 while (!opWorklist.empty())
334 visitUsers(*opWorklist.pop_back_val());
335
336 // Process any blocks in the block worklist.
337 while (!blockWorklist.empty())
338 visitBlock(blockWorklist.pop_back_val());
339 }
340 }
341
rewrite(MLIRContext * context,MutableArrayRef<Region> initialRegions)342 void SCCPSolver::rewrite(MLIRContext *context,
343 MutableArrayRef<Region> initialRegions) {
344 SmallVector<Block *, 8> worklist;
345 auto addToWorklist = [&](MutableArrayRef<Region> regions) {
346 for (Region ®ion : regions)
347 for (Block &block : region)
348 if (isBlockExecutable(&block))
349 worklist.push_back(&block);
350 };
351
352 // An operation folder used to create and unique constants.
353 OperationFolder folder(context);
354 OpBuilder builder(context);
355
356 addToWorklist(initialRegions);
357 while (!worklist.empty()) {
358 Block *block = worklist.pop_back_val();
359
360 // Replace any block arguments with constants.
361 builder.setInsertionPointToStart(block);
362 for (BlockArgument arg : block->getArguments())
363 replaceWithConstant(builder, folder, arg);
364
365 for (Operation &op : llvm::make_early_inc_range(*block)) {
366 builder.setInsertionPoint(&op);
367
368 // Replace any result with constants.
369 bool replacedAll = op.getNumResults() != 0;
370 for (Value res : op.getResults())
371 replacedAll &= succeeded(replaceWithConstant(builder, folder, res));
372
373 // If all of the results of the operation were replaced, try to erase
374 // the operation completely.
375 if (replacedAll && wouldOpBeTriviallyDead(&op)) {
376 assert(op.use_empty() && "expected all uses to be replaced");
377 op.erase();
378 continue;
379 }
380
381 // Add any the regions of this operation to the worklist.
382 addToWorklist(op.getRegions());
383 }
384 }
385 }
386
initializeSymbolCallables(Operation * op)387 void SCCPSolver::initializeSymbolCallables(Operation *op) {
388 // Initialize the set of symbol callables that can have their state tracked.
389 // This tracks which symbol callable operations we can propagate within and
390 // out of.
391 auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
392 Region &symbolTableRegion = symTable->getRegion(0);
393 Block *symbolTableBlock = &symbolTableRegion.front();
394 for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
395 // We won't be able to track external callables.
396 Region *callableRegion = callable.getCallableRegion();
397 if (!callableRegion)
398 continue;
399 // We only care about symbol defining callables here.
400 auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
401 if (!symbol)
402 continue;
403 callableLatticeState.try_emplace(callable, callableRegion,
404 callable.getCallableResults().size());
405
406 // If not all of the uses of this symbol are visible, we can't track the
407 // state of the arguments.
408 if (symbol.isPublic() || (!allUsesVisible && symbol.isNested()))
409 markAllOverdefined(callableRegion->getArguments());
410 }
411 if (callableLatticeState.empty())
412 return;
413
414 // After computing the valid callables, walk any symbol uses to check
415 // for non-call references. We won't be able to track the lattice state
416 // for arguments to these callables, as we can't guarantee that we can see
417 // all of its calls.
418 Optional<SymbolTable::UseRange> uses =
419 SymbolTable::getSymbolUses(&symbolTableRegion);
420 if (!uses) {
421 // If we couldn't gather the symbol uses, conservatively assume that
422 // we can't track information for any nested symbols.
423 op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); });
424 return;
425 }
426
427 for (const SymbolTable::SymbolUse &use : *uses) {
428 // If the use is a call, track it to avoid the need to recompute the
429 // reference later.
430 if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) {
431 Operation *symCallable = callOp.resolveCallable(&symbolTable);
432 auto callableLatticeIt = callableLatticeState.find(symCallable);
433 if (callableLatticeIt != callableLatticeState.end()) {
434 callToSymbolCallable.try_emplace(callOp, symCallable);
435
436 // We only need to record the call in the lattice if it produces any
437 // values.
438 if (callOp->getNumResults())
439 callableLatticeIt->second.addSymbolCall(callOp);
440 }
441 continue;
442 }
443 // This use isn't a call, so don't we know all of the callers.
444 auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef());
445 auto it = callableLatticeState.find(symbol);
446 if (it != callableLatticeState.end())
447 markAllOverdefined(it->second.getCallableArguments());
448 }
449 };
450 SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
451 walkFn);
452 }
453
replaceWithConstant(OpBuilder & builder,OperationFolder & folder,Value value)454 LogicalResult SCCPSolver::replaceWithConstant(OpBuilder &builder,
455 OperationFolder &folder,
456 Value value) {
457 auto it = latticeValues.find(value);
458 auto attr = it == latticeValues.end() ? nullptr : it->second.getConstant();
459 if (!attr)
460 return failure();
461
462 // Attempt to materialize a constant for the given value.
463 Dialect *dialect = it->second.getConstantDialect();
464 Value constant = folder.getOrCreateConstant(builder, dialect, attr,
465 value.getType(), value.getLoc());
466 if (!constant)
467 return failure();
468
469 value.replaceAllUsesWith(constant);
470 latticeValues.erase(it);
471 return success();
472 }
473
visitOperation(Operation * op)474 void SCCPSolver::visitOperation(Operation *op) {
475 // Collect all of the constant operands feeding into this operation. If any
476 // are not ready to be resolved, bail out and wait for them to resolve.
477 SmallVector<Attribute, 8> operandConstants;
478 operandConstants.reserve(op->getNumOperands());
479 for (Value operand : op->getOperands()) {
480 // Make sure all of the operands are resolved first.
481 auto &operandLattice = latticeValues[operand];
482 if (operandLattice.isUnknown())
483 return;
484 operandConstants.push_back(operandLattice.getConstant());
485 }
486
487 // If this is a terminator operation, process any control flow lattice state.
488 if (op->isKnownTerminator())
489 visitTerminatorOperation(op, operandConstants);
490
491 // Process call operations. The call visitor processes result values, so we
492 // can exit afterwards.
493 if (CallOpInterface call = dyn_cast<CallOpInterface>(op))
494 return visitCallOperation(call);
495
496 // Process callable operations. These are specially handled region operations
497 // that track dataflow via calls.
498 if (isa<CallableOpInterface>(op))
499 return visitCallableOperation(op);
500
501 // Process region holding operations. The region visitor processes result
502 // values, so we can exit afterwards.
503 if (op->getNumRegions())
504 return visitRegionOperation(op, operandConstants);
505
506 // If this op produces no results, it can't produce any constants.
507 if (op->getNumResults() == 0)
508 return;
509
510 // If all of the results of this operation are already overdefined, bail out
511 // early.
512 auto isOverdefinedFn = [&](Value value) { return isOverdefined(value); };
513 if (llvm::all_of(op->getResults(), isOverdefinedFn))
514 return;
515
516 // Save the original operands and attributes just in case the operation folds
517 // in-place. The constant passed in may not correspond to the real runtime
518 // value, so in-place updates are not allowed.
519 SmallVector<Value, 8> originalOperands(op->getOperands());
520 MutableDictionaryAttr originalAttrs = op->getMutableAttrDict();
521
522 // Simulate the result of folding this operation to a constant. If folding
523 // fails or was an in-place fold, mark the results as overdefined.
524 SmallVector<OpFoldResult, 8> foldResults;
525 foldResults.reserve(op->getNumResults());
526 if (failed(op->fold(operandConstants, foldResults)))
527 return markAllOverdefined(op, op->getResults());
528
529 // If the folding was in-place, mark the results as overdefined and reset the
530 // operation. We don't allow in-place folds as the desire here is for
531 // simulated execution, and not general folding.
532 if (foldResults.empty()) {
533 op->setOperands(originalOperands);
534 op->setAttrs(originalAttrs);
535 return markAllOverdefined(op, op->getResults());
536 }
537
538 // Merge the fold results into the lattice for this operation.
539 assert(foldResults.size() == op->getNumResults() && "invalid result size");
540 Dialect *opDialect = op->getDialect();
541 for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
542 LatticeValue &resultLattice = latticeValues[op->getResult(i)];
543
544 // Merge in the result of the fold, either a constant or a value.
545 OpFoldResult foldResult = foldResults[i];
546 if (Attribute foldAttr = foldResult.dyn_cast<Attribute>())
547 meet(op, resultLattice, LatticeValue(foldAttr, opDialect));
548 else
549 meet(op, resultLattice, latticeValues[foldResult.get<Value>()]);
550 }
551 }
552
visitCallableOperation(Operation * op)553 void SCCPSolver::visitCallableOperation(Operation *op) {
554 // Mark the regions as executable.
555 bool isTrackingLatticeState = callableLatticeState.count(op);
556 for (Region ®ion : op->getRegions()) {
557 if (region.empty())
558 continue;
559 Block *entryBlock = ®ion.front();
560 markBlockExecutable(entryBlock);
561
562 // If we aren't tracking lattice state for this callable, mark all of the
563 // region arguments as overdefined.
564 if (!isTrackingLatticeState)
565 markAllOverdefined(entryBlock->getArguments());
566 }
567
568 // TODO: Add support for non-symbol callables when necessary. If the callable
569 // has non-call uses we would mark overdefined, otherwise allow for
570 // propagating the return values out.
571 markAllOverdefined(op, op->getResults());
572 }
573
visitCallOperation(CallOpInterface op)574 void SCCPSolver::visitCallOperation(CallOpInterface op) {
575 ResultRange callResults = op->getResults();
576
577 // Resolve the callable operation for this call.
578 Operation *callableOp = nullptr;
579 if (Value callableValue = op.getCallableForCallee().dyn_cast<Value>())
580 callableOp = callableValue.getDefiningOp();
581 else
582 callableOp = callToSymbolCallable.lookup(op);
583
584 // The callable of this call can't be resolved, mark any results overdefined.
585 if (!callableOp)
586 return markAllOverdefined(op, callResults);
587
588 // If this callable is tracking state, merge the argument operands with the
589 // arguments of the callable.
590 auto callableLatticeIt = callableLatticeState.find(callableOp);
591 if (callableLatticeIt == callableLatticeState.end())
592 return markAllOverdefined(op, callResults);
593
594 OperandRange callOperands = op.getArgOperands();
595 auto callableArgs = callableLatticeIt->second.getCallableArguments();
596 for (auto it : llvm::zip(callOperands, callableArgs)) {
597 BlockArgument callableArg = std::get<1>(it);
598 if (latticeValues[callableArg].meet(latticeValues[std::get<0>(it)]))
599 visitUsers(callableArg);
600 }
601
602 // Merge in the lattice state for the callable results as well.
603 auto callableResults = callableLatticeIt->second.getResultLatticeValues();
604 for (auto it : llvm::zip(callResults, callableResults))
605 meet(/*owner=*/op, /*to=*/latticeValues[std::get<0>(it)],
606 /*from=*/std::get<1>(it));
607 }
608
visitRegionOperation(Operation * op,ArrayRef<Attribute> constantOperands)609 void SCCPSolver::visitRegionOperation(Operation *op,
610 ArrayRef<Attribute> constantOperands) {
611 // Check to see if we can reason about the internal control flow of this
612 // region operation.
613 auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
614 if (!regionInterface) {
615 // If we can't, conservatively mark all regions as executable.
616 for (Region ®ion : op->getRegions()) {
617 if (region.empty())
618 continue;
619 Block *entryBlock = ®ion.front();
620 markBlockExecutable(entryBlock);
621 markAllOverdefined(entryBlock->getArguments());
622 }
623
624 // Don't try to simulate the results of a region operation as we can't
625 // guarantee that folding will be out-of-place. We don't allow in-place
626 // folds as the desire here is for simulated execution, and not general
627 // folding.
628 return markAllOverdefined(op, op->getResults());
629 }
630
631 // Check to see which regions are executable.
632 SmallVector<RegionSuccessor, 1> successors;
633 regionInterface.getSuccessorRegions(/*index=*/llvm::None, constantOperands,
634 successors);
635
636 // If the interface identified that no region will be executed. Mark
637 // any results of this operation as overdefined, as we can't reason about
638 // them.
639 // TODO: If we had an interface to detect pass through operands, we could
640 // resolve some results based on the lattice state of the operands. We could
641 // also allow for the parent operation to have itself as a region successor.
642 if (successors.empty())
643 return markAllOverdefined(op, op->getResults());
644 return visitRegionSuccessors(op, successors, [&](Optional<unsigned> index) {
645 assert(index && "expected valid region index");
646 return regionInterface.getSuccessorEntryOperands(*index);
647 });
648 }
649
visitRegionSuccessors(Operation * parentOp,ArrayRef<RegionSuccessor> regionSuccessors,function_ref<OperandRange (Optional<unsigned>)> getInputsForRegion)650 void SCCPSolver::visitRegionSuccessors(
651 Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
652 function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) {
653 for (const RegionSuccessor &it : regionSuccessors) {
654 Region *region = it.getSuccessor();
655 ValueRange succArgs = it.getSuccessorInputs();
656
657 // Check to see if this is the parent operation.
658 if (!region) {
659 ResultRange results = parentOp->getResults();
660 if (llvm::all_of(results, [&](Value res) { return isOverdefined(res); }))
661 continue;
662
663 // Mark the results outside of the input range as overdefined.
664 if (succArgs.size() != results.size()) {
665 opWorklist.push_back(parentOp);
666 if (succArgs.empty())
667 return markAllOverdefined(results);
668
669 unsigned firstResIdx = succArgs[0].cast<OpResult>().getResultNumber();
670 markAllOverdefined(results.take_front(firstResIdx));
671 markAllOverdefined(results.drop_front(firstResIdx + succArgs.size()));
672 }
673
674 // Update the lattice for any operation results.
675 OperandRange operands = getInputsForRegion(/*index=*/llvm::None);
676 for (auto it : llvm::zip(succArgs, operands))
677 meet(parentOp, latticeValues[std::get<0>(it)],
678 latticeValues[std::get<1>(it)]);
679 return;
680 }
681 assert(!region->empty() && "expected region to be non-empty");
682 Block *entryBlock = ®ion->front();
683 markBlockExecutable(entryBlock);
684
685 // If all of the arguments are already overdefined, the arguments have
686 // already been fully resolved.
687 auto arguments = entryBlock->getArguments();
688 if (llvm::all_of(arguments, [&](Value arg) { return isOverdefined(arg); }))
689 continue;
690
691 // Mark any arguments that do not receive inputs as overdefined, we won't be
692 // able to discern if they are constant.
693 if (succArgs.size() != arguments.size()) {
694 if (succArgs.empty()) {
695 markAllOverdefined(arguments);
696 continue;
697 }
698
699 unsigned firstArgIdx = succArgs[0].cast<BlockArgument>().getArgNumber();
700 markAllOverdefinedAndVisitUsers(arguments.take_front(firstArgIdx));
701 markAllOverdefinedAndVisitUsers(
702 arguments.drop_front(firstArgIdx + succArgs.size()));
703 }
704
705 // Update the lattice for arguments that have inputs from the predecessor.
706 OperandRange succOperands = getInputsForRegion(region->getRegionNumber());
707 for (auto it : llvm::zip(succArgs, succOperands)) {
708 LatticeValue &argLattice = latticeValues[std::get<0>(it)];
709 if (argLattice.meet(latticeValues[std::get<1>(it)]))
710 visitUsers(std::get<0>(it));
711 }
712 }
713 }
714
visitTerminatorOperation(Operation * op,ArrayRef<Attribute> constantOperands)715 void SCCPSolver::visitTerminatorOperation(
716 Operation *op, ArrayRef<Attribute> constantOperands) {
717 // If this operation has no successors, we treat it as an exiting terminator.
718 if (op->getNumSuccessors() == 0) {
719 Region *parentRegion = op->getParentRegion();
720 Operation *parentOp = parentRegion->getParentOp();
721
722 // Check to see if this is a terminator for a callable region.
723 if (isa<CallableOpInterface>(parentOp))
724 return visitCallableTerminatorOperation(parentOp, op);
725
726 // Otherwise, check to see if the parent tracks region control flow.
727 auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp);
728 if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
729 return;
730
731 // Query the set of successors from the current region.
732 SmallVector<RegionSuccessor, 1> regionSuccessors;
733 regionInterface.getSuccessorRegions(parentRegion->getRegionNumber(),
734 constantOperands, regionSuccessors);
735 if (regionSuccessors.empty())
736 return;
737
738 // If this terminator is not "region-like", conservatively mark all of the
739 // successor values as overdefined.
740 if (!op->hasTrait<OpTrait::ReturnLike>()) {
741 for (auto &it : regionSuccessors)
742 markAllOverdefinedAndVisitUsers(it.getSuccessorInputs());
743 return;
744 }
745
746 // Otherwise, propagate the operand lattice states to each of the
747 // successors.
748 OperandRange operands = op->getOperands();
749 return visitRegionSuccessors(parentOp, regionSuccessors,
750 [&](Optional<unsigned>) { return operands; });
751 }
752
753 // Try to resolve to a specific successor with the constant operands.
754 if (auto branch = dyn_cast<BranchOpInterface>(op)) {
755 if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) {
756 markEdgeExecutable(op->getBlock(), singleSucc);
757 return;
758 }
759 }
760
761 // Otherwise, conservatively treat all edges as executable.
762 Block *block = op->getBlock();
763 for (Block *succ : op->getSuccessors())
764 markEdgeExecutable(block, succ);
765 }
766
visitCallableTerminatorOperation(Operation * callable,Operation * terminator)767 void SCCPSolver::visitCallableTerminatorOperation(Operation *callable,
768 Operation *terminator) {
769 // If there are no exiting values, we have nothing to track.
770 if (terminator->getNumOperands() == 0)
771 return;
772
773 // If this callable isn't tracking any lattice state there is nothing to do.
774 auto latticeIt = callableLatticeState.find(callable);
775 if (latticeIt == callableLatticeState.end())
776 return;
777 assert(callable->getNumResults() == 0 && "expected symbol callable");
778
779 // If this terminator is not "return-like", conservatively mark all of the
780 // call-site results as overdefined.
781 auto callableResultLattices = latticeIt->second.getResultLatticeValues();
782 if (!terminator->hasTrait<OpTrait::ReturnLike>()) {
783 for (auto &it : callableResultLattices)
784 it.markOverdefined();
785 for (Operation *call : latticeIt->second.getSymbolCalls())
786 markAllOverdefined(call, call->getResults());
787 return;
788 }
789
790 // Merge the terminator operands into the results.
791 bool anyChanged = false;
792 for (auto it : llvm::zip(terminator->getOperands(), callableResultLattices))
793 anyChanged |= std::get<1>(it).meet(latticeValues[std::get<0>(it)]);
794 if (!anyChanged)
795 return;
796
797 // If any of the result lattices changed, update the callers.
798 for (Operation *call : latticeIt->second.getSymbolCalls())
799 for (auto it : llvm::zip(call->getResults(), callableResultLattices))
800 meet(call, latticeValues[std::get<0>(it)], std::get<1>(it));
801 }
802
visitBlock(Block * block)803 void SCCPSolver::visitBlock(Block *block) {
804 // If the block is not the entry block we need to compute the lattice state
805 // for the block arguments. Entry block argument lattices are computed
806 // elsewhere, such as when visiting the parent operation.
807 if (!block->isEntryBlock()) {
808 for (int i : llvm::seq<int>(0, block->getNumArguments()))
809 visitBlockArgument(block, i);
810 }
811
812 // Visit all of the operations within the block.
813 for (Operation &op : *block)
814 visitOperation(&op);
815 }
816
visitBlockArgument(Block * block,int i)817 void SCCPSolver::visitBlockArgument(Block *block, int i) {
818 BlockArgument arg = block->getArgument(i);
819 LatticeValue &argLattice = latticeValues[arg];
820 if (argLattice.isOverdefined())
821 return;
822
823 bool updatedLattice = false;
824 for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
825 Block *pred = *it;
826
827 // We only care about this predecessor if it is going to execute.
828 if (!isEdgeExecutable(pred, block))
829 continue;
830
831 // Try to get the operand forwarded by the predecessor. If we can't reason
832 // about the terminator of the predecessor, mark overdefined.
833 Optional<OperandRange> branchOperands;
834 if (auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator()))
835 branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex());
836 if (!branchOperands) {
837 updatedLattice = true;
838 argLattice.markOverdefined();
839 break;
840 }
841
842 // If the operand hasn't been resolved, it is unknown which can merge with
843 // anything.
844 auto operandLattice = latticeValues.find((*branchOperands)[i]);
845 if (operandLattice == latticeValues.end())
846 continue;
847
848 // Otherwise, meet the two lattice values.
849 updatedLattice |= argLattice.meet(operandLattice->second);
850 if (argLattice.isOverdefined())
851 break;
852 }
853
854 // If the lattice was updated, visit any executable users of the argument.
855 if (updatedLattice)
856 visitUsers(arg);
857 }
858
markBlockExecutable(Block * block)859 bool SCCPSolver::markBlockExecutable(Block *block) {
860 bool marked = executableBlocks.insert(block).second;
861 if (marked)
862 blockWorklist.push_back(block);
863 return marked;
864 }
865
isBlockExecutable(Block * block) const866 bool SCCPSolver::isBlockExecutable(Block *block) const {
867 return executableBlocks.count(block);
868 }
869
markEdgeExecutable(Block * from,Block * to)870 void SCCPSolver::markEdgeExecutable(Block *from, Block *to) {
871 if (!executableEdges.insert(std::make_pair(from, to)).second)
872 return;
873 // Mark the destination as executable, and reprocess its arguments if it was
874 // already executable.
875 if (!markBlockExecutable(to)) {
876 for (int i : llvm::seq<int>(0, to->getNumArguments()))
877 visitBlockArgument(to, i);
878 }
879 }
880
isEdgeExecutable(Block * from,Block * to) const881 bool SCCPSolver::isEdgeExecutable(Block *from, Block *to) const {
882 return executableEdges.count(std::make_pair(from, to));
883 }
884
markOverdefined(Value value)885 void SCCPSolver::markOverdefined(Value value) {
886 latticeValues[value].markOverdefined();
887 }
888
isOverdefined(Value value) const889 bool SCCPSolver::isOverdefined(Value value) const {
890 auto it = latticeValues.find(value);
891 return it != latticeValues.end() && it->second.isOverdefined();
892 }
893
meet(Operation * owner,LatticeValue & to,const LatticeValue & from)894 void SCCPSolver::meet(Operation *owner, LatticeValue &to,
895 const LatticeValue &from) {
896 if (to.meet(from))
897 opWorklist.push_back(owner);
898 }
899
900 //===----------------------------------------------------------------------===//
901 // SCCP Pass
902 //===----------------------------------------------------------------------===//
903
904 namespace {
905 struct SCCP : public SCCPBase<SCCP> {
906 void runOnOperation() override;
907 };
908 } // end anonymous namespace
909
runOnOperation()910 void SCCP::runOnOperation() {
911 Operation *op = getOperation();
912
913 // Solve for SCCP constraints within nested regions.
914 SCCPSolver solver(op);
915 solver.solve();
916
917 // Cleanup any operations using the solver analysis.
918 solver.rewrite(&getContext(), op->getRegions());
919 }
920
createSCCPPass()921 std::unique_ptr<Pass> mlir::createSCCPPass() {
922 return std::make_unique<SCCP>();
923 }
924