1 //===- CallGraph.h - CallGraph analysis for MLIR ----------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains an analysis for computing the multi-level callgraph from a 10 // given top-level operation. This nodes within this callgraph are defined by 11 // the `CallOpInterface` and `CallableOpInterface` operation interfaces defined 12 // in CallInterface.td. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #ifndef MLIR_ANALYSIS_CALLGRAPH_H 17 #define MLIR_ANALYSIS_CALLGRAPH_H 18 19 #include "mlir/Support/LLVM.h" 20 #include "llvm/ADT/GraphTraits.h" 21 #include "llvm/ADT/MapVector.h" 22 #include "llvm/ADT/PointerIntPair.h" 23 #include "llvm/ADT/SetVector.h" 24 25 namespace mlir { 26 class CallOpInterface; 27 struct CallInterfaceCallable; 28 class Operation; 29 class Region; 30 class SymbolTableCollection; 31 32 //===----------------------------------------------------------------------===// 33 // CallGraphNode 34 //===----------------------------------------------------------------------===// 35 36 /// This class represents a single callable in the callgraph. Aside from the 37 /// external node, each node represents a callable node in the graph and 38 /// contains a valid corresponding Region. The external node is a virtual node 39 /// used to represent external edges into, and out of, the callgraph. 40 class CallGraphNode { 41 public: 42 /// This class represents a directed edge between two nodes in the callgraph. 43 class Edge { 44 enum class Kind { 45 // An 'Abstract' edge represents an opaque, non-operation, reference 46 // between this node and the target. Edges of this type are only valid 47 // from the external node, as there is no valid connection to an operation 48 // in the module. 49 Abstract, 50 51 // A 'Call' edge represents a direct reference to the target node via a 52 // call-like operation within the callable region of this node. 53 Call, 54 55 // A 'Child' edge is used when the region of target node is defined inside 56 // of the callable region of this node. This means that the region of this 57 // node is an ancestor of the region for the target node. As such, this 58 // edge cannot be used on the 'external' node. 59 Child, 60 }; 61 62 public: 63 /// Returns true if this edge represents an `Abstract` edge. isAbstract()64 bool isAbstract() const { return targetAndKind.getInt() == Kind::Abstract; } 65 66 /// Returns true if this edge represents a `Call` edge. isCall()67 bool isCall() const { return targetAndKind.getInt() == Kind::Call; } 68 69 /// Returns true if this edge represents a `Child` edge. isChild()70 bool isChild() const { return targetAndKind.getInt() == Kind::Child; } 71 72 /// Returns the target node for this edge. getTarget()73 CallGraphNode *getTarget() const { return targetAndKind.getPointer(); } 74 75 bool operator==(const Edge &edge) const { 76 return targetAndKind == edge.targetAndKind; 77 } 78 79 private: Edge(CallGraphNode * node,Kind kind)80 Edge(CallGraphNode *node, Kind kind) : targetAndKind(node, kind) {} Edge(llvm::PointerIntPair<CallGraphNode *,2,Kind> targetAndKind)81 explicit Edge(llvm::PointerIntPair<CallGraphNode *, 2, Kind> targetAndKind) 82 : targetAndKind(targetAndKind) {} 83 84 /// The target node of this edge, as well as the edge kind. 85 llvm::PointerIntPair<CallGraphNode *, 2, Kind> targetAndKind; 86 87 // Provide access to the constructor and Kind. 88 friend class CallGraphNode; 89 }; 90 91 /// Returns true if this node is an external node. 92 bool isExternal() const; 93 94 /// Returns the callable region this node represents. This can only be called 95 /// on non-external nodes. 96 Region *getCallableRegion() const; 97 98 /// Adds an abstract reference edge to the given node. An abstract edge does 99 /// not come from any observable operations, so this is only valid on the 100 /// external node. 101 void addAbstractEdge(CallGraphNode *node); 102 103 /// Add an outgoing call edge from this node. 104 void addCallEdge(CallGraphNode *node); 105 106 /// Adds a reference edge to the given child node. 107 void addChildEdge(CallGraphNode *child); 108 109 /// Iterator over the outgoing edges of this node. 110 using iterator = SmallVectorImpl<Edge>::const_iterator; begin()111 iterator begin() const { return edges.begin(); } end()112 iterator end() const { return edges.end(); } 113 114 /// Returns true if this node has any child edges. 115 bool hasChildren() const; 116 117 private: 118 /// DenseMap info for callgraph edges. 119 struct EdgeKeyInfo { 120 using BaseInfo = 121 DenseMapInfo<llvm::PointerIntPair<CallGraphNode *, 2, Edge::Kind>>; 122 getEmptyKeyEdgeKeyInfo123 static Edge getEmptyKey() { return Edge(BaseInfo::getEmptyKey()); } getTombstoneKeyEdgeKeyInfo124 static Edge getTombstoneKey() { return Edge(BaseInfo::getTombstoneKey()); } getHashValueEdgeKeyInfo125 static unsigned getHashValue(const Edge &edge) { 126 return BaseInfo::getHashValue(edge.targetAndKind); 127 } isEqualEdgeKeyInfo128 static bool isEqual(const Edge &lhs, const Edge &rhs) { return lhs == rhs; } 129 }; 130 CallGraphNode(Region * callableRegion)131 CallGraphNode(Region *callableRegion) : callableRegion(callableRegion) {} 132 133 /// Add an edge to 'node' with the given kind. 134 void addEdge(CallGraphNode *node, Edge::Kind kind); 135 136 /// The callable region defines the boundary of the call graph node. This is 137 /// the region referenced by 'call' operations. This is at a per-region 138 /// boundary as operations may define multiple callable regions. 139 Region *callableRegion; 140 141 /// A set of out-going edges from this node to other nodes in the graph. 142 llvm::SetVector<Edge, SmallVector<Edge, 4>, 143 llvm::SmallDenseSet<Edge, 4, EdgeKeyInfo>> 144 edges; 145 146 // Provide access to private methods. 147 friend class CallGraph; 148 }; 149 150 //===----------------------------------------------------------------------===// 151 // CallGraph 152 //===----------------------------------------------------------------------===// 153 154 class CallGraph { 155 using NodeMapT = llvm::MapVector<Region *, std::unique_ptr<CallGraphNode>>; 156 157 /// This class represents an iterator over the internal call graph nodes. This 158 /// class unwraps the map iterator to access the raw node. 159 class NodeIterator final 160 : public llvm::mapped_iterator< 161 NodeMapT::const_iterator, 162 CallGraphNode *(*)(const NodeMapT::value_type &)> { unwrap(const NodeMapT::value_type & value)163 static CallGraphNode *unwrap(const NodeMapT::value_type &value) { 164 return value.second.get(); 165 } 166 167 public: 168 /// Initializes the result type iterator to the specified result iterator. NodeIterator(NodeMapT::const_iterator it)169 NodeIterator(NodeMapT::const_iterator it) 170 : llvm::mapped_iterator< 171 NodeMapT::const_iterator, 172 CallGraphNode *(*)(const NodeMapT::value_type &)>(it, &unwrap) {} 173 }; 174 175 public: 176 CallGraph(Operation *op); 177 178 /// Get or add a call graph node for the given region. `parentNode` 179 /// corresponds to the direct node in the callgraph that contains the parent 180 /// operation of `region`, or nullptr if there is no parent node. 181 CallGraphNode *getOrAddNode(Region *region, CallGraphNode *parentNode); 182 183 /// Lookup a call graph node for the given region, or nullptr if none is 184 /// registered. 185 CallGraphNode *lookupNode(Region *region) const; 186 187 /// Return the callgraph node representing the indirect-external callee. getExternalNode()188 CallGraphNode *getExternalNode() const { 189 return const_cast<CallGraphNode *>(&externalNode); 190 } 191 192 /// Resolve the callable for given callee to a node in the callgraph, or the 193 /// external node if a valid node was not resolved. The provided symbol table 194 /// is used when resolving calls that reference callables via a symbol 195 /// reference. 196 CallGraphNode *resolveCallable(CallOpInterface call, 197 SymbolTableCollection &symbolTable) const; 198 199 /// Erase the given node from the callgraph. 200 void eraseNode(CallGraphNode *node); 201 202 /// An iterator over the nodes of the graph. 203 using iterator = NodeIterator; begin()204 iterator begin() const { return nodes.begin(); } end()205 iterator end() const { return nodes.end(); } 206 207 /// Dump the graph in a human readable format. 208 void dump() const; 209 void print(raw_ostream &os) const; 210 211 private: 212 /// The set of nodes within the callgraph. 213 NodeMapT nodes; 214 215 /// A special node used to indicate an external edges. 216 CallGraphNode externalNode; 217 }; 218 219 } // end namespace mlir 220 221 namespace llvm { 222 // Provide graph traits for traversing call graphs using standard graph 223 // traversals. 224 template <> struct GraphTraits<const mlir::CallGraphNode *> { 225 using NodeRef = mlir::CallGraphNode *; 226 static NodeRef getEntryNode(NodeRef node) { return node; } 227 228 static NodeRef unwrap(const mlir::CallGraphNode::Edge &edge) { 229 return edge.getTarget(); 230 } 231 232 // ChildIteratorType/begin/end - Allow iteration over all nodes in the graph. 233 using ChildIteratorType = 234 mapped_iterator<mlir::CallGraphNode::iterator, decltype(&unwrap)>; 235 static ChildIteratorType child_begin(NodeRef node) { 236 return {node->begin(), &unwrap}; 237 } 238 static ChildIteratorType child_end(NodeRef node) { 239 return {node->end(), &unwrap}; 240 } 241 }; 242 243 template <> 244 struct GraphTraits<const mlir::CallGraph *> 245 : public GraphTraits<const mlir::CallGraphNode *> { 246 /// The entry node into the graph is the external node. 247 static NodeRef getEntryNode(const mlir::CallGraph *cg) { 248 return cg->getExternalNode(); 249 } 250 251 // nodes_iterator/begin/end - Allow iteration over all nodes in the graph 252 using nodes_iterator = mlir::CallGraph::iterator; 253 static nodes_iterator nodes_begin(mlir::CallGraph *cg) { return cg->begin(); } 254 static nodes_iterator nodes_end(mlir::CallGraph *cg) { return cg->end(); } 255 }; 256 } // end namespace llvm 257 258 #endif // MLIR_ANALYSIS_CALLGRAPH_H 259