• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- Inliner.cpp - Pass to inline function calls ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a basic inlining algorithm that operates bottom up over
10 // the Strongly Connect Components(SCCs) of the CallGraph. This enables a more
11 // incremental propagation of inlining decisions from the leafs to the roots of
12 // the callgraph.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "PassDetail.h"
17 #include "mlir/Analysis/CallGraph.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Interfaces/SideEffectInterfaces.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 #include "mlir/Transforms/InliningUtils.h"
22 #include "mlir/Transforms/Passes.h"
23 #include "llvm/ADT/SCCIterator.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/Parallel.h"
26 
27 #define DEBUG_TYPE "inlining"
28 
29 using namespace mlir;
30 
31 //===----------------------------------------------------------------------===//
32 // Symbol Use Tracking
33 //===----------------------------------------------------------------------===//
34 
35 /// Walk all of the used symbol callgraph nodes referenced with the given op.
walkReferencedSymbolNodes(Operation * op,CallGraph & cg,SymbolTableCollection & symbolTable,DenseMap<Attribute,CallGraphNode * > & resolvedRefs,function_ref<void (CallGraphNode *,Operation *)> callback)36 static void walkReferencedSymbolNodes(
37     Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
38     DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
39     function_ref<void(CallGraphNode *, Operation *)> callback) {
40   auto symbolUses = SymbolTable::getSymbolUses(op);
41   assert(symbolUses && "expected uses to be valid");
42 
43   Operation *symbolTableOp = op->getParentOp();
44   for (const SymbolTable::SymbolUse &use : *symbolUses) {
45     auto refIt = resolvedRefs.insert({use.getSymbolRef(), nullptr});
46     CallGraphNode *&node = refIt.first->second;
47 
48     // If this is the first instance of this reference, try to resolve a
49     // callgraph node for it.
50     if (refIt.second) {
51       auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
52                                                            use.getSymbolRef());
53       auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
54       if (!callableOp)
55         continue;
56       node = cg.lookupNode(callableOp.getCallableRegion());
57     }
58     if (node)
59       callback(node, use.getUser());
60   }
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // CGUseList
65 
namespace {
/// This struct tracks the uses of callgraph nodes that can be dropped when
/// use_empty. It directly tracks and manages a use-list for all of the
/// call-graph nodes. This is necessary because many callgraph nodes are
/// referenced by SymbolRefAttr, which has no mechanism akin to the SSA `Use`
/// class.
struct CGUseList {
  /// This struct tracks the uses of callgraph nodes within a specific
  /// operation.
  struct CGUser {
    /// Any nodes referenced in the top-level attribute list of this user. We
    /// use a set here because the number of references does not matter.
    DenseSet<CallGraphNode *> topLevelUses;

    /// Uses of nodes referenced by nested operations, keyed by the referenced
    /// node with the number of references as the value.
    DenseMap<CallGraphNode *, int> innerUses;
  };

  /// Build the use list for all callgraph nodes nested within 'op'.
  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);

  /// Drop uses of nodes referred to by the given call operation that resides
  /// within 'userNode'.
  void dropCallUses(CallGraphNode *userNode, Operation *callOp, CallGraph &cg);

  /// Remove the given node (and, recursively, its child nodes) from the use
  /// list.
  void eraseNode(CallGraphNode *node);

  /// Returns true if the given callgraph node has no uses and can be pruned.
  bool isDead(CallGraphNode *node) const;

  /// Returns true if the given callgraph node has a single use and can be
  /// discarded.
  bool hasOneUseAndDiscardable(CallGraphNode *node) const;

  /// Recompute the uses held by the given callgraph node.
  void recomputeUses(CallGraphNode *node, CallGraph &cg);

  /// Merge the uses of 'lhs' with the uses of the 'rhs' after inlining a copy
  /// of 'lhs' into 'rhs'.
  void mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs);

private:
  /// Decrement the uses of discardable nodes referenced by the given user.
  void decrementDiscardableUses(CGUser &uses);

  /// A mapping between a discardable callgraph node (that is a symbol) and the
  /// number of uses for this node.
  DenseMap<CallGraphNode *, int> discardableSymNodeUses;

  /// A mapping between a callgraph node and the symbol callgraph nodes that it
  /// uses.
  DenseMap<CallGraphNode *, CGUser> nodeUses;

  /// A symbol table to use when resolving call lookups.
  SymbolTableCollection &symbolTable;
};
} // end anonymous namespace
123 
CGUseList(Operation * op,CallGraph & cg,SymbolTableCollection & symbolTable)124 CGUseList::CGUseList(Operation *op, CallGraph &cg,
125                      SymbolTableCollection &symbolTable)
126     : symbolTable(symbolTable) {
127   /// A set of callgraph nodes that are always known to be live during inlining.
128   DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
129 
130   // Walk each of the symbol tables looking for discardable callgraph nodes.
131   auto walkFn = [&](Operation *symbolTableOp, bool allUsesVisible) {
132     for (Operation &op : symbolTableOp->getRegion(0).getOps()) {
133       // If this is a callgraph operation, check to see if it is discardable.
134       if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
135         if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
136           SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
137           if (symbol && (allUsesVisible || symbol.isPrivate()) &&
138               symbol.canDiscardOnUseEmpty()) {
139             discardableSymNodeUses.try_emplace(node, 0);
140           }
141           continue;
142         }
143       }
144       // Otherwise, check for any referenced nodes. These will be always-live.
145       walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
146                                 [](CallGraphNode *, Operation *) {});
147     }
148   };
149   SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
150                                 walkFn);
151 
152   // Drop the use information for any discardable nodes that are always live.
153   for (auto &it : alwaysLiveNodes)
154     discardableSymNodeUses.erase(it.second);
155 
156   // Compute the uses for each of the callable nodes in the graph.
157   for (CallGraphNode *node : cg)
158     recomputeUses(node, cg);
159 }
160 
dropCallUses(CallGraphNode * userNode,Operation * callOp,CallGraph & cg)161 void CGUseList::dropCallUses(CallGraphNode *userNode, Operation *callOp,
162                              CallGraph &cg) {
163   auto &userRefs = nodeUses[userNode].innerUses;
164   auto walkFn = [&](CallGraphNode *node, Operation *user) {
165     auto parentIt = userRefs.find(node);
166     if (parentIt == userRefs.end())
167       return;
168     --parentIt->second;
169     --discardableSymNodeUses[node];
170   };
171   DenseMap<Attribute, CallGraphNode *> resolvedRefs;
172   walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
173 }
174 
eraseNode(CallGraphNode * node)175 void CGUseList::eraseNode(CallGraphNode *node) {
176   // Drop all child nodes.
177   for (auto &edge : *node)
178     if (edge.isChild())
179       eraseNode(edge.getTarget());
180 
181   // Drop the uses held by this node and erase it.
182   auto useIt = nodeUses.find(node);
183   assert(useIt != nodeUses.end() && "expected node to be valid");
184   decrementDiscardableUses(useIt->getSecond());
185   nodeUses.erase(useIt);
186   discardableSymNodeUses.erase(node);
187 }
188 
isDead(CallGraphNode * node) const189 bool CGUseList::isDead(CallGraphNode *node) const {
190   // If the parent operation isn't a symbol, simply check normal SSA deadness.
191   Operation *nodeOp = node->getCallableRegion()->getParentOp();
192   if (!isa<SymbolOpInterface>(nodeOp))
193     return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();
194 
195   // Otherwise, check the number of symbol uses.
196   auto symbolIt = discardableSymNodeUses.find(node);
197   return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 0;
198 }
199 
hasOneUseAndDiscardable(CallGraphNode * node) const200 bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
201   // If this isn't a symbol node, check for side-effects and SSA use count.
202   Operation *nodeOp = node->getCallableRegion()->getParentOp();
203   if (!isa<SymbolOpInterface>(nodeOp))
204     return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();
205 
206   // Otherwise, check the number of symbol uses.
207   auto symbolIt = discardableSymNodeUses.find(node);
208   return symbolIt != discardableSymNodeUses.end() && symbolIt->second == 1;
209 }
210 
void CGUseList::recomputeUses(CallGraphNode *node, CallGraph &cg) {
  Operation *parentOp = node->getCallableRegion()->getParentOp();
  CGUser &uses = nodeUses[node];
  // Remove the contributions of the stale use list before rebuilding it, so
  // the global discardable counts stay balanced.
  decrementDiscardableUses(uses);

  // Collect the new discardable uses within this node.
  uses = CGUser();
  DenseMap<Attribute, CallGraphNode *> resolvedRefs;
  auto walkFn = [&](CallGraphNode *refNode, Operation *user) {
    // Only nodes tracked as discardable need use counting.
    auto discardSymIt = discardableSymNodeUses.find(refNode);
    if (discardSymIt == discardableSymNodeUses.end())
      return;

    // References from nested operations are counted per occurrence; references
    // from the callable op itself are "top-level" and counted at most once.
    if (user != parentOp)
      ++uses.innerUses[refNode];
    else if (!uses.topLevelUses.insert(refNode).second)
      return; // Already recorded this top-level use; don't double count.
    ++discardSymIt->second;
  };
  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
}
232 
mergeUsesAfterInlining(CallGraphNode * lhs,CallGraphNode * rhs)233 void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
234   auto &lhsUses = nodeUses[lhs], &rhsUses = nodeUses[rhs];
235   for (auto &useIt : lhsUses.innerUses) {
236     rhsUses.innerUses[useIt.first] += useIt.second;
237     discardableSymNodeUses[useIt.first] += useIt.second;
238   }
239 }
240 
decrementDiscardableUses(CGUser & uses)241 void CGUseList::decrementDiscardableUses(CGUser &uses) {
242   for (CallGraphNode *node : uses.topLevelUses)
243     --discardableSymNodeUses[node];
244   for (auto &it : uses.innerUses)
245     discardableSymNodeUses[it.first] -= it.second;
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // CallGraph traversal
250 //===----------------------------------------------------------------------===//
251 
252 namespace {
253 /// This class represents a specific callgraph SCC.
254 class CallGraphSCC {
255 public:
CallGraphSCC(llvm::scc_iterator<const CallGraph * > & parentIterator)256   CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
257       : parentIterator(parentIterator) {}
258   /// Return a range over the nodes within this SCC.
begin()259   std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
end()260   std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
261 
262   /// Reset the nodes of this SCC with those provided.
reset(const std::vector<CallGraphNode * > & newNodes)263   void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
264 
265   /// Remove the given node from this SCC.
remove(CallGraphNode * node)266   void remove(CallGraphNode *node) {
267     auto it = llvm::find(nodes, node);
268     if (it != nodes.end()) {
269       nodes.erase(it);
270       parentIterator.ReplaceNode(node, nullptr);
271     }
272   }
273 
274 private:
275   std::vector<CallGraphNode *> nodes;
276   llvm::scc_iterator<const CallGraph *> &parentIterator;
277 };
278 } // end anonymous namespace
279 
280 /// Run a given transformation over the SCCs of the callgraph in a bottom up
281 /// traversal.
282 static void
runTransformOnCGSCCs(const CallGraph & cg,function_ref<void (CallGraphSCC &)> sccTransformer)283 runTransformOnCGSCCs(const CallGraph &cg,
284                      function_ref<void(CallGraphSCC &)> sccTransformer) {
285   llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
286   CallGraphSCC currentSCC(cgi);
287   while (!cgi.isAtEnd()) {
288     // Copy the current SCC and increment so that the transformer can modify the
289     // SCC without invalidating our iterator.
290     currentSCC.reset(*cgi);
291     ++cgi;
292     sccTransformer(currentSCC);
293   }
294 }
295 
namespace {
/// This struct represents a resolved call to a given callgraph node. Given that
/// the call does not actually contain a direct reference to the
/// Region(CallGraphNode) that it is dispatching to, we need to resolve them
/// explicitly.
struct ResolvedCall {
  ResolvedCall(CallOpInterface call, CallGraphNode *sourceNode,
               CallGraphNode *targetNode)
      : call(call), sourceNode(sourceNode), targetNode(targetNode) {}

  // The call operation being considered for inlining.
  CallOpInterface call;
  // The node containing the call, and the node the call dispatches to.
  CallGraphNode *sourceNode, *targetNode;
};
} // end anonymous namespace
309 
310 /// Collect all of the callable operations within the given range of blocks. If
311 /// `traverseNestedCGNodes` is true, this will also collect call operations
312 /// inside of nested callgraph nodes.
collectCallOps(iterator_range<Region::iterator> blocks,CallGraphNode * sourceNode,CallGraph & cg,SymbolTableCollection & symbolTable,SmallVectorImpl<ResolvedCall> & calls,bool traverseNestedCGNodes)313 static void collectCallOps(iterator_range<Region::iterator> blocks,
314                            CallGraphNode *sourceNode, CallGraph &cg,
315                            SymbolTableCollection &symbolTable,
316                            SmallVectorImpl<ResolvedCall> &calls,
317                            bool traverseNestedCGNodes) {
318   SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
319   auto addToWorklist = [&](CallGraphNode *node,
320                            iterator_range<Region::iterator> blocks) {
321     for (Block &block : blocks)
322       worklist.emplace_back(&block, node);
323   };
324 
325   addToWorklist(sourceNode, blocks);
326   while (!worklist.empty()) {
327     Block *block;
328     std::tie(block, sourceNode) = worklist.pop_back_val();
329 
330     for (Operation &op : *block) {
331       if (auto call = dyn_cast<CallOpInterface>(op)) {
332         // TODO: Support inlining nested call references.
333         CallInterfaceCallable callable = call.getCallableForCallee();
334         if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) {
335           if (!symRef.isa<FlatSymbolRefAttr>())
336             continue;
337         }
338 
339         CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
340         if (!targetNode->isExternal())
341           calls.emplace_back(call, sourceNode, targetNode);
342         continue;
343       }
344 
345       // If this is not a call, traverse the nested regions. If
346       // `traverseNestedCGNodes` is false, then don't traverse nested call graph
347       // regions.
348       for (auto &nestedRegion : op.getRegions()) {
349         CallGraphNode *nestedNode = cg.lookupNode(&nestedRegion);
350         if (traverseNestedCGNodes || !nestedNode)
351           addToWorklist(nestedNode ? nestedNode : sourceNode, nestedRegion);
352       }
353     }
354   }
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // Inliner
359 //===----------------------------------------------------------------------===//
360 namespace {
361 /// This class provides a specialization of the main inlining interface.
362 struct Inliner : public InlinerInterface {
Inliner__anon92f6fd4f0911::Inliner363   Inliner(MLIRContext *context, CallGraph &cg,
364           SymbolTableCollection &symbolTable)
365       : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
366 
367   /// Process a set of blocks that have been inlined. This callback is invoked
368   /// *before* inlined terminator operations have been processed.
369   void
processInlinedBlocks__anon92f6fd4f0911::Inliner370   processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final {
371     // Find the closest callgraph node from the first block.
372     CallGraphNode *node;
373     Region *region = inlinedBlocks.begin()->getParent();
374     while (!(node = cg.lookupNode(region))) {
375       region = region->getParentRegion();
376       assert(region && "expected valid parent node");
377     }
378 
379     collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
380                    /*traverseNestedCGNodes=*/true);
381   }
382 
383   /// Mark the given callgraph node for deletion.
markForDeletion__anon92f6fd4f0911::Inliner384   void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
385 
386   /// This method properly disposes of callables that became dead during
387   /// inlining. This should not be called while iterating over the SCCs.
eraseDeadCallables__anon92f6fd4f0911::Inliner388   void eraseDeadCallables() {
389     for (CallGraphNode *node : deadNodes)
390       node->getCallableRegion()->getParentOp()->erase();
391   }
392 
393   /// The set of callables known to be dead.
394   SmallPtrSet<CallGraphNode *, 8> deadNodes;
395 
396   /// The current set of call instructions to consider for inlining.
397   SmallVector<ResolvedCall, 8> calls;
398 
399   /// The callgraph being operated on.
400   CallGraph &cg;
401 
402   /// A symbol table to use when resolving call lookups.
403   SymbolTableCollection &symbolTable;
404 };
405 } // namespace
406 
407 /// Returns true if the given call should be inlined.
shouldInline(ResolvedCall & resolvedCall)408 static bool shouldInline(ResolvedCall &resolvedCall) {
409   // Don't allow inlining terminator calls. We currently don't support this
410   // case.
411   if (resolvedCall.call->isKnownTerminator())
412     return false;
413 
414   // Don't allow inlining if the target is an ancestor of the call. This
415   // prevents inlining recursively.
416   if (resolvedCall.targetNode->getCallableRegion()->isAncestor(
417           resolvedCall.call->getParentRegion()))
418     return false;
419 
420   // Otherwise, inline.
421   return true;
422 }
423 
424 /// Attempt to inline calls within the given scc. This function returns
425 /// success if any calls were inlined, failure otherwise.
static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
                                      CallGraphSCC &currentSCC) {
  CallGraph &cg = inliner.cg;
  auto &calls = inliner.calls;

  // A set of dead nodes to remove after inlining.
  SmallVector<CallGraphNode *, 1> deadNodes;

  // Collect all of the direct calls within the nodes of the current SCC. We
  // don't traverse nested callgraph nodes, because they are handled separately
  // likely within a different SCC.
  for (CallGraphNode *node : currentSCC) {
    if (node->isExternal())
      continue;

    // Don't collect calls if the node is already dead.
    if (useList.isDead(node)) {
      deadNodes.push_back(node);
    } else {
      collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
                     calls, /*traverseNestedCGNodes=*/false);
    }
  }

  // Try to inline each of the call operations. Don't cache the end iterator
  // here as more calls may be added during inlining.
  bool inlinedAnyCalls = false;
  for (unsigned i = 0; i != calls.size(); ++i) {
    // Copy by value: inlining may grow `calls` and invalidate references into
    // the vector.
    ResolvedCall it = calls[i];
    bool doInline = shouldInline(it);
    CallOpInterface call = it.call;
    LLVM_DEBUG({
      if (doInline)
        llvm::dbgs() << "* Inlining call: " << call << "\n";
      else
        llvm::dbgs() << "* Not inlining call: " << call << "\n";
    });
    if (!doInline)
      continue;
    Region *targetRegion = it.targetNode->getCallableRegion();

    // If this is the last call to the target node and the node is discardable,
    // then inline it in-place and delete the node if successful.
    bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);

    LogicalResult inlineResult = inlineCall(
        inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()),
        targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
    if (failed(inlineResult)) {
      LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
      continue;
    }
    inlinedAnyCalls = true;

    // If the inlining was successful, Merge the new uses into the source node.
    useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
    useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);

    // then erase the call.
    call.erase();

    // If we inlined in place, mark the node for deletion.
    if (inlineInPlace) {
      useList.eraseNode(it.targetNode);
      deadNodes.push_back(it.targetNode);
    }
  }

  // Remove nodes that became dead (or were consumed by in-place inlining) from
  // the SCC and schedule their callables for deletion after SCC iteration.
  for (CallGraphNode *node : deadNodes) {
    currentSCC.remove(node);
    inliner.markForDeletion(node);
  }
  calls.clear();
  return success(inlinedAnyCalls);
}
501 
502 /// Canonicalize the nodes within the given SCC with the given set of
503 /// canonicalization patterns.
canonicalizeSCC(CallGraph & cg,CGUseList & useList,CallGraphSCC & currentSCC,MLIRContext * context,const FrozenRewritePatternList & canonPatterns)504 static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
505                             CallGraphSCC &currentSCC, MLIRContext *context,
506                             const FrozenRewritePatternList &canonPatterns) {
507   // Collect the sets of nodes to canonicalize.
508   SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
509   for (auto *node : currentSCC) {
510     // Don't canonicalize the external node, it has no valid callable region.
511     if (node->isExternal())
512       continue;
513 
514     // Don't canonicalize nodes with children. Nodes with children
515     // require special handling as we may remove the node during
516     // canonicalization. In the future, we should be able to handle this
517     // case with proper node deletion tracking.
518     if (node->hasChildren())
519       continue;
520 
521     // We also won't apply canonicalizations for nodes that are not
522     // isolated. This avoids potentially mutating the regions of nodes defined
523     // above, this is also a stipulation of the 'applyPatternsAndFoldGreedily'
524     // driver.
525     auto *region = node->getCallableRegion();
526     if (!region->getParentOp()->isKnownIsolatedFromAbove())
527       continue;
528     nodesToCanonicalize.push_back(node);
529   }
530   if (nodesToCanonicalize.empty())
531     return;
532 
533   // Canonicalize each of the nodes within the SCC in parallel.
534   // NOTE: This is simple now, because we don't enable canonicalizing nodes
535   // within children. When we remove this restriction, this logic will need to
536   // be reworked.
537   if (context->isMultithreadingEnabled()) {
538     ParallelDiagnosticHandler canonicalizationHandler(context);
539     llvm::parallelForEachN(
540         /*Begin=*/0, /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
541           // Set the order for this thread so that diagnostics will be properly
542           // ordered.
543           canonicalizationHandler.setOrderIDForThread(index);
544 
545           // Apply the canonicalization patterns to this region.
546           auto *node = nodesToCanonicalize[index];
547           applyPatternsAndFoldGreedily(*node->getCallableRegion(),
548                                        canonPatterns);
549 
550           // Make sure to reset the order ID for the diagnostic handler, as this
551           // thread may be used in a different context.
552           canonicalizationHandler.eraseOrderIDForThread();
553         });
554   } else {
555     for (CallGraphNode *node : nodesToCanonicalize)
556       applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns);
557   }
558 
559   // Recompute the uses held by each of the nodes.
560   for (CallGraphNode *node : nodesToCanonicalize)
561     useList.recomputeUses(node, cg);
562 }
563 
564 //===----------------------------------------------------------------------===//
565 // InlinerPass
566 //===----------------------------------------------------------------------===//
567 
namespace {
/// The inliner pass. Inlines calls bottom-up over the callgraph SCCs,
/// interleaving canonicalization between rounds.
struct InlinerPass : public InlinerBase<InlinerPass> {
  void runOnOperation() override;

  /// Attempt to inline calls within the given scc, and run canonicalizations
  /// with the given patterns, until a fixed point is reached. This allows for
  /// the inlining of newly devirtualized calls.
  void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC,
                 MLIRContext *context,
                 const FrozenRewritePatternList &canonPatterns);
};
} // end anonymous namespace
580 
runOnOperation()581 void InlinerPass::runOnOperation() {
582   CallGraph &cg = getAnalysis<CallGraph>();
583   auto *context = &getContext();
584 
585   // The inliner should only be run on operations that define a symbol table,
586   // as the callgraph will need to resolve references.
587   Operation *op = getOperation();
588   if (!op->hasTrait<OpTrait::SymbolTable>()) {
589     op->emitOpError() << " was scheduled to run under the inliner, but does "
590                          "not define a symbol table";
591     return signalPassFailure();
592   }
593 
594   // Collect a set of canonicalization patterns to use when simplifying
595   // callable regions within an SCC.
596   OwningRewritePatternList canonPatterns;
597   for (auto *op : context->getRegisteredOperations())
598     op->getCanonicalizationPatterns(canonPatterns, context);
599   FrozenRewritePatternList frozenCanonPatterns(std::move(canonPatterns));
600 
601   // Run the inline transform in post-order over the SCCs in the callgraph.
602   SymbolTableCollection symbolTable;
603   Inliner inliner(context, cg, symbolTable);
604   CGUseList useList(getOperation(), cg, symbolTable);
605   runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
606     inlineSCC(inliner, useList, scc, context, frozenCanonPatterns);
607   });
608 
609   // After inlining, make sure to erase any callables proven to be dead.
610   inliner.eraseDeadCallables();
611 }
612 
inlineSCC(Inliner & inliner,CGUseList & useList,CallGraphSCC & currentSCC,MLIRContext * context,const FrozenRewritePatternList & canonPatterns)613 void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
614                             CallGraphSCC &currentSCC, MLIRContext *context,
615                             const FrozenRewritePatternList &canonPatterns) {
616   // If we successfully inlined any calls, run some simplifications on the
617   // nodes of the scc. Continue attempting to inline until we reach a fixed
618   // point, or a maximum iteration count. We canonicalize here as it may
619   // devirtualize new calls, as well as give us a better cost model.
620   unsigned iterationCount = 0;
621   while (succeeded(inlineCallsInSCC(inliner, useList, currentSCC))) {
622     // If we aren't allowing simplifications or the max iteration count was
623     // reached, then bail out early.
624     if (disableCanonicalization || ++iterationCount >= maxInliningIterations)
625       break;
626     canonicalizeSCC(inliner.cg, useList, currentSCC, context, canonPatterns);
627   }
628 }
629 
/// Create an instance of the inliner pass.
std::unique_ptr<Pass> mlir::createInlinerPass() {
  return std::make_unique<InlinerPass>();
}
633