• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- CallGraph.h - CallGraph analysis for MLIR ----------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file contains an analysis for computing the multi-level callgraph from a
// given top-level operation. The nodes within this callgraph are defined by
// the `CallOpInterface` and `CallableOpInterface` operation interfaces defined
// in CallInterface.td.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef MLIR_ANALYSIS_CALLGRAPH_H
17 #define MLIR_ANALYSIS_CALLGRAPH_H
18 
19 #include "mlir/Support/LLVM.h"
20 #include "llvm/ADT/GraphTraits.h"
21 #include "llvm/ADT/MapVector.h"
22 #include "llvm/ADT/PointerIntPair.h"
23 #include "llvm/ADT/SetVector.h"
24 
25 namespace mlir {
26 class CallOpInterface;
27 struct CallInterfaceCallable;
28 class Operation;
29 class Region;
30 class SymbolTableCollection;
31 
32 //===----------------------------------------------------------------------===//
33 // CallGraphNode
34 //===----------------------------------------------------------------------===//
35 
36 /// This class represents a single callable in the callgraph. Aside from the
37 /// external node, each node represents a callable node in the graph and
38 /// contains a valid corresponding Region. The external node is a virtual node
39 /// used to represent external edges into, and out of, the callgraph.
40 class CallGraphNode {
41 public:
42   /// This class represents a directed edge between two nodes in the callgraph.
43   class Edge {
44     enum class Kind {
45       // An 'Abstract' edge represents an opaque, non-operation, reference
46       // between this node and the target. Edges of this type are only valid
47       // from the external node, as there is no valid connection to an operation
48       // in the module.
49       Abstract,
50 
51       // A 'Call' edge represents a direct reference to the target node via a
52       // call-like operation within the callable region of this node.
53       Call,
54 
55       // A 'Child' edge is used when the region of target node is defined inside
56       // of the callable region of this node. This means that the region of this
57       // node is an ancestor of the region for the target node. As such, this
58       // edge cannot be used on the 'external' node.
59       Child,
60     };
61 
62   public:
63     /// Returns true if this edge represents an `Abstract` edge.
isAbstract()64     bool isAbstract() const { return targetAndKind.getInt() == Kind::Abstract; }
65 
66     /// Returns true if this edge represents a `Call` edge.
isCall()67     bool isCall() const { return targetAndKind.getInt() == Kind::Call; }
68 
69     /// Returns true if this edge represents a `Child` edge.
isChild()70     bool isChild() const { return targetAndKind.getInt() == Kind::Child; }
71 
72     /// Returns the target node for this edge.
getTarget()73     CallGraphNode *getTarget() const { return targetAndKind.getPointer(); }
74 
75     bool operator==(const Edge &edge) const {
76       return targetAndKind == edge.targetAndKind;
77     }
78 
79   private:
Edge(CallGraphNode * node,Kind kind)80     Edge(CallGraphNode *node, Kind kind) : targetAndKind(node, kind) {}
Edge(llvm::PointerIntPair<CallGraphNode *,2,Kind> targetAndKind)81     explicit Edge(llvm::PointerIntPair<CallGraphNode *, 2, Kind> targetAndKind)
82         : targetAndKind(targetAndKind) {}
83 
84     /// The target node of this edge, as well as the edge kind.
85     llvm::PointerIntPair<CallGraphNode *, 2, Kind> targetAndKind;
86 
87     // Provide access to the constructor and Kind.
88     friend class CallGraphNode;
89   };
90 
91   /// Returns true if this node is an external node.
92   bool isExternal() const;
93 
94   /// Returns the callable region this node represents. This can only be called
95   /// on non-external nodes.
96   Region *getCallableRegion() const;
97 
98   /// Adds an abstract reference edge to the given node. An abstract edge does
99   /// not come from any observable operations, so this is only valid on the
100   /// external node.
101   void addAbstractEdge(CallGraphNode *node);
102 
103   /// Add an outgoing call edge from this node.
104   void addCallEdge(CallGraphNode *node);
105 
106   /// Adds a reference edge to the given child node.
107   void addChildEdge(CallGraphNode *child);
108 
109   /// Iterator over the outgoing edges of this node.
110   using iterator = SmallVectorImpl<Edge>::const_iterator;
begin()111   iterator begin() const { return edges.begin(); }
end()112   iterator end() const { return edges.end(); }
113 
114   /// Returns true if this node has any child edges.
115   bool hasChildren() const;
116 
117 private:
118   /// DenseMap info for callgraph edges.
119   struct EdgeKeyInfo {
120     using BaseInfo =
121         DenseMapInfo<llvm::PointerIntPair<CallGraphNode *, 2, Edge::Kind>>;
122 
getEmptyKeyEdgeKeyInfo123     static Edge getEmptyKey() { return Edge(BaseInfo::getEmptyKey()); }
getTombstoneKeyEdgeKeyInfo124     static Edge getTombstoneKey() { return Edge(BaseInfo::getTombstoneKey()); }
getHashValueEdgeKeyInfo125     static unsigned getHashValue(const Edge &edge) {
126       return BaseInfo::getHashValue(edge.targetAndKind);
127     }
isEqualEdgeKeyInfo128     static bool isEqual(const Edge &lhs, const Edge &rhs) { return lhs == rhs; }
129   };
130 
CallGraphNode(Region * callableRegion)131   CallGraphNode(Region *callableRegion) : callableRegion(callableRegion) {}
132 
133   /// Add an edge to 'node' with the given kind.
134   void addEdge(CallGraphNode *node, Edge::Kind kind);
135 
136   /// The callable region defines the boundary of the call graph node. This is
137   /// the region referenced by 'call' operations. This is at a per-region
138   /// boundary as operations may define multiple callable regions.
139   Region *callableRegion;
140 
141   /// A set of out-going edges from this node to other nodes in the graph.
142   llvm::SetVector<Edge, SmallVector<Edge, 4>,
143                   llvm::SmallDenseSet<Edge, 4, EdgeKeyInfo>>
144       edges;
145 
146   // Provide access to private methods.
147   friend class CallGraph;
148 };
149 
150 //===----------------------------------------------------------------------===//
151 // CallGraph
152 //===----------------------------------------------------------------------===//
153 
154 class CallGraph {
155   using NodeMapT = llvm::MapVector<Region *, std::unique_ptr<CallGraphNode>>;
156 
157   /// This class represents an iterator over the internal call graph nodes. This
158   /// class unwraps the map iterator to access the raw node.
159   class NodeIterator final
160       : public llvm::mapped_iterator<
161             NodeMapT::const_iterator,
162             CallGraphNode *(*)(const NodeMapT::value_type &)> {
unwrap(const NodeMapT::value_type & value)163     static CallGraphNode *unwrap(const NodeMapT::value_type &value) {
164       return value.second.get();
165     }
166 
167   public:
168     /// Initializes the result type iterator to the specified result iterator.
NodeIterator(NodeMapT::const_iterator it)169     NodeIterator(NodeMapT::const_iterator it)
170         : llvm::mapped_iterator<
171               NodeMapT::const_iterator,
172               CallGraphNode *(*)(const NodeMapT::value_type &)>(it, &unwrap) {}
173   };
174 
175 public:
176   CallGraph(Operation *op);
177 
178   /// Get or add a call graph node for the given region. `parentNode`
179   /// corresponds to the direct node in the callgraph that contains the parent
180   /// operation of `region`, or nullptr if there is no parent node.
181   CallGraphNode *getOrAddNode(Region *region, CallGraphNode *parentNode);
182 
183   /// Lookup a call graph node for the given region, or nullptr if none is
184   /// registered.
185   CallGraphNode *lookupNode(Region *region) const;
186 
187   /// Return the callgraph node representing the indirect-external callee.
getExternalNode()188   CallGraphNode *getExternalNode() const {
189     return const_cast<CallGraphNode *>(&externalNode);
190   }
191 
192   /// Resolve the callable for given callee to a node in the callgraph, or the
193   /// external node if a valid node was not resolved. The provided symbol table
194   /// is used when resolving calls that reference callables via a symbol
195   /// reference.
196   CallGraphNode *resolveCallable(CallOpInterface call,
197                                  SymbolTableCollection &symbolTable) const;
198 
199   /// Erase the given node from the callgraph.
200   void eraseNode(CallGraphNode *node);
201 
202   /// An iterator over the nodes of the graph.
203   using iterator = NodeIterator;
begin()204   iterator begin() const { return nodes.begin(); }
end()205   iterator end() const { return nodes.end(); }
206 
207   /// Dump the graph in a human readable format.
208   void dump() const;
209   void print(raw_ostream &os) const;
210 
211 private:
212   /// The set of nodes within the callgraph.
213   NodeMapT nodes;
214 
215   /// A special node used to indicate an external edges.
216   CallGraphNode externalNode;
217 };
218 
219 } // end namespace mlir
220 
221 namespace llvm {
222 // Provide graph traits for traversing call graphs using standard graph
223 // traversals.
224 template <> struct GraphTraits<const mlir::CallGraphNode *> {
225   using NodeRef = mlir::CallGraphNode *;
226   static NodeRef getEntryNode(NodeRef node) { return node; }
227 
228   static NodeRef unwrap(const mlir::CallGraphNode::Edge &edge) {
229     return edge.getTarget();
230   }
231 
232   // ChildIteratorType/begin/end - Allow iteration over all nodes in the graph.
233   using ChildIteratorType =
234       mapped_iterator<mlir::CallGraphNode::iterator, decltype(&unwrap)>;
235   static ChildIteratorType child_begin(NodeRef node) {
236     return {node->begin(), &unwrap};
237   }
238   static ChildIteratorType child_end(NodeRef node) {
239     return {node->end(), &unwrap};
240   }
241 };
242 
243 template <>
244 struct GraphTraits<const mlir::CallGraph *>
245     : public GraphTraits<const mlir::CallGraphNode *> {
246   /// The entry node into the graph is the external node.
247   static NodeRef getEntryNode(const mlir::CallGraph *cg) {
248     return cg->getExternalNode();
249   }
250 
251   // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
252   using nodes_iterator = mlir::CallGraph::iterator;
253   static nodes_iterator nodes_begin(mlir::CallGraph *cg) { return cg->begin(); }
254   static nodes_iterator nodes_end(mlir::CallGraph *cg) { return cg->end(); }
255 };
256 } // end namespace llvm
257 
258 #endif // MLIR_ANALYSIS_CALLGRAPH_H
259