1 //===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains interfaces and analyses for defining a nested callgraph.
10 //
11 //===----------------------------------------------------------------------===//
12
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Support/raw_ostream.h"
#include <vector>
20
21 using namespace mlir;
22
23 //===----------------------------------------------------------------------===//
24 // CallGraphNode
25 //===----------------------------------------------------------------------===//
26
27 /// Returns true if this node refers to the indirect/external node.
isExternal() const28 bool CallGraphNode::isExternal() const { return !callableRegion; }
29
30 /// Return the callable region this node represents. This can only be called
31 /// on non-external nodes.
getCallableRegion() const32 Region *CallGraphNode::getCallableRegion() const {
33 assert(!isExternal() && "the external node has no callable region");
34 return callableRegion;
35 }
36
37 /// Adds an reference edge to the given node. This is only valid on the
38 /// external node.
addAbstractEdge(CallGraphNode * node)39 void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
40 assert(isExternal() && "abstract edges are only valid on external nodes");
41 addEdge(node, Edge::Kind::Abstract);
42 }
43
44 /// Add an outgoing call edge from this node.
addCallEdge(CallGraphNode * node)45 void CallGraphNode::addCallEdge(CallGraphNode *node) {
46 addEdge(node, Edge::Kind::Call);
47 }
48
49 /// Adds a reference edge to the given child node.
addChildEdge(CallGraphNode * child)50 void CallGraphNode::addChildEdge(CallGraphNode *child) {
51 addEdge(child, Edge::Kind::Child);
52 }
53
54 /// Returns true if this node has any child edges.
hasChildren() const55 bool CallGraphNode::hasChildren() const {
56 return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
57 }
58
59 /// Add an edge to 'node' with the given kind.
addEdge(CallGraphNode * node,Edge::Kind kind)60 void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
61 edges.insert({node, kind});
62 }
63
64 //===----------------------------------------------------------------------===//
65 // CallGraph
66 //===----------------------------------------------------------------------===//
67
68 /// Recursively compute the callgraph edges for the given operation. Computed
69 /// edges are placed into the given callgraph object.
computeCallGraph(Operation * op,CallGraph & cg,SymbolTableCollection & symbolTable,CallGraphNode * parentNode,bool resolveCalls)70 static void computeCallGraph(Operation *op, CallGraph &cg,
71 SymbolTableCollection &symbolTable,
72 CallGraphNode *parentNode, bool resolveCalls) {
73 if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
74 // If there is no parent node, we ignore this operation. Even if this
75 // operation was a call, there would be no callgraph node to attribute it
76 // to.
77 if (resolveCalls && parentNode)
78 parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
79 return;
80 }
81
82 // Compute the callgraph nodes and edges for each of the nested operations.
83 if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
84 if (auto *callableRegion = callable.getCallableRegion())
85 parentNode = cg.getOrAddNode(callableRegion, parentNode);
86 else
87 return;
88 }
89
90 for (Region ®ion : op->getRegions())
91 for (Operation &nested : region.getOps())
92 computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls);
93 }
94
CallGraph(Operation * op)95 CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
96 // Make two passes over the graph, one to compute the callables and one to
97 // resolve the calls. We split these up as we may have nested callable objects
98 // that need to be reserved before the calls.
99 SymbolTableCollection symbolTable;
100 computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
101 /*resolveCalls=*/false);
102 computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
103 /*resolveCalls=*/true);
104 }
105
106 /// Get or add a call graph node for the given region.
getOrAddNode(Region * region,CallGraphNode * parentNode)107 CallGraphNode *CallGraph::getOrAddNode(Region *region,
108 CallGraphNode *parentNode) {
109 assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
110 "expected parent operation to be callable");
111 std::unique_ptr<CallGraphNode> &node = nodes[region];
112 if (!node) {
113 node.reset(new CallGraphNode(region));
114
115 // Add this node to the given parent node if necessary.
116 if (parentNode) {
117 parentNode->addChildEdge(node.get());
118 } else {
119 // Otherwise, connect all callable nodes to the external node, this allows
120 // for conservatively including all callable nodes within the graph.
121 // FIXME This isn't correct, this is only necessary for callable nodes
122 // that *could* be called from external sources. This requires extending
123 // the interface for callables to check if they may be referenced
124 // externally.
125 externalNode.addAbstractEdge(node.get());
126 }
127 }
128 return node.get();
129 }
130
131 /// Lookup a call graph node for the given region, or nullptr if none is
132 /// registered.
lookupNode(Region * region) const133 CallGraphNode *CallGraph::lookupNode(Region *region) const {
134 auto it = nodes.find(region);
135 return it == nodes.end() ? nullptr : it->second.get();
136 }
137
138 /// Resolve the callable for given callee to a node in the callgraph, or the
139 /// external node if a valid node was not resolved.
140 CallGraphNode *
resolveCallable(CallOpInterface call,SymbolTableCollection & symbolTable) const141 CallGraph::resolveCallable(CallOpInterface call,
142 SymbolTableCollection &symbolTable) const {
143 Operation *callable = call.resolveCallable(&symbolTable);
144 if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
145 if (auto *node = lookupNode(callableOp.getCallableRegion()))
146 return node;
147
148 // If we don't have a valid direct region, this is an external call.
149 return getExternalNode();
150 }
151
152 /// Erase the given node from the callgraph.
eraseNode(CallGraphNode * node)153 void CallGraph::eraseNode(CallGraphNode *node) {
154 // Erase any children of this node first.
155 if (node->hasChildren()) {
156 for (const CallGraphNode::Edge &edge : llvm::make_early_inc_range(*node))
157 if (edge.isChild())
158 eraseNode(edge.getTarget());
159 }
160 // Erase any edges to this node from any other nodes.
161 for (auto &it : nodes) {
162 it.second->edges.remove_if([node](const CallGraphNode::Edge &edge) {
163 return edge.getTarget() == node;
164 });
165 }
166 nodes.erase(node->getCallableRegion());
167 }
168
//===----------------------------------------------------------------------===//
// Printing
//===----------------------------------------------------------------------===//
171
172 /// Dump the graph in a human readable format.
dump() const173 void CallGraph::dump() const { print(llvm::errs()); }
print(raw_ostream & os) const174 void CallGraph::print(raw_ostream &os) const {
175 os << "// ---- CallGraph ----\n";
176
177 // Functor used to output the name for the given node.
178 auto emitNodeName = [&](const CallGraphNode *node) {
179 if (node->isExternal()) {
180 os << "<External-Node>";
181 return;
182 }
183
184 auto *callableRegion = node->getCallableRegion();
185 auto *parentOp = callableRegion->getParentOp();
186 os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
187 << callableRegion->getRegionNumber();
188 auto attrs = parentOp->getAttrDictionary();
189 if (!attrs.empty())
190 os << " : " << attrs;
191 };
192
193 for (auto &nodeIt : nodes) {
194 const CallGraphNode *node = nodeIt.second.get();
195
196 // Dump the header for this node.
197 os << "// - Node : ";
198 emitNodeName(node);
199 os << "\n";
200
201 // Emit each of the edges.
202 for (auto &edge : *node) {
203 os << "// -- ";
204 if (edge.isCall())
205 os << "Call";
206 else if (edge.isChild())
207 os << "Child";
208
209 os << "-Edge : ";
210 emitNodeName(edge.getTarget());
211 os << "\n";
212 }
213 os << "//\n";
214 }
215
216 os << "// -- SCCs --\n";
217
218 for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
219 os << "// - SCC : \n";
220 for (auto &node : scc) {
221 os << "// -- Node :";
222 emitNodeName(node);
223 os << "\n";
224 }
225 os << "\n";
226 }
227
228 os << "// -------------------\n";
229 }
230