• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Call graph for an HLO module.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_
19 #define TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_
20 
21 #include <ostream>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_module.h"
28 
29 namespace xla {
30 
// Describes how a computation is invoked by the computation that calls it.
enum class CallContext {
  // Embedded call: the body of the called computation cannot allocate
  // buffers.
  kEmbedded = 0,

  // Control-flow call: the body of the called computation can allocate
  // buffers.
  kControlFlow = 1,

  // The computation is called from both an embedded and a control-flow
  // context.
  kBoth = 2,

  // Placeholder meaning "not yet determined", used only while the call graph
  // is being constructed; it is the top value of the context lattice. After
  // construction, no call site or call graph node should carry this value.
  kNone = 3
};
49 
50 std::string CallContextToString(CallContext context);
51 std::ostream& operator<<(std::ostream& out, const CallContext& context);
52 
53 CallContext GetInstructionCallContext(HloOpcode opcode);
54 
55 // Represents an HLO instruction which calls one or more computations.
56 class CallSite {
57  public:
CallSite(HloInstruction * instruction,absl::Span<HloComputation * const> called_computations,CallContext context)58   CallSite(HloInstruction* instruction,
59            absl::Span<HloComputation* const> called_computations,
60            CallContext context)
61       : instruction_(CHECK_NOTNULL(instruction)),
62         called_computations_(called_computations.begin(),
63                              called_computations.end()),
64         context_(context) {}
65 
66   // Returns the instruction associated with this call site.
instruction()67   HloInstruction* instruction() const { return instruction_; }
68 
69   // Returns the computations called at this call site.
called_computations()70   absl::Span<HloComputation* const> called_computations() const {
71     return called_computations_;
72   }
73 
74   // Returns the context in which computations are called at this call site.
context()75   CallContext context() const { return context_; }
76 
77   std::string ToString() const;
78 
79  private:
80   // The calling instruction.
81   HloInstruction* instruction_;
82 
83   // The computations called by this callsite.
84   const absl::InlinedVector<HloComputation*, 2> called_computations_;
85 
86   // The context in which the computations are called.
87   const CallContext context_;
88 };
89 
90 // A node in the call graph representing an HLO computation.
91 class CallGraphNode {
92  public:
93   CallGraphNode(HloComputation* computation);
94 
95   // Returns the computation represented by this call graph node.
computation()96   HloComputation* computation() const { return computation_; }
97 
98   // Returns the call sites in this computation. These are the instructions in
99   // this computation which call other computations.
callsites()100   absl::Span<const CallSite> callsites() const { return callsites_; }
101 
102   // Returns the callsite associated with the given instruction. If this
103   // instruction calls no computations nullptr is returned.
104   // Prerequisite: instruction is in the computation associated with this call
105   // graph node.
106   const CallSite* GetCallSite(const HloInstruction* instruction) const;
107 
108   // Returns the computations called by this computation.
callees()109   absl::Span<HloComputation* const> callees() const { return callees_; }
110 
111   // Returns the call sites in other computations which call this computation.
caller_callsites()112   absl::Span<const CallSite> caller_callsites() const {
113     return caller_callsites_;
114   }
115 
116   // Returns the computations which call this computation.
callers()117   absl::Span<HloComputation* const> callers() const { return callers_; }
118 
119   // Returns the context in which this computation is called.
context()120   CallContext context() const { return context_; }
121 
122   // Returns the depth of this node in the call graph. The depth is defined as
123   // the length of the longest call chain from a computation with no callers
124   // (usually the entry computation node) to this node.
depth()125   int depth() const { return depth_; }
126 
127   std::string ToString() const;
128 
129   CallGraphNode(const CallGraphNode&) = delete;
130   CallGraphNode& operator=(const CallGraphNode&) = delete;
131   CallGraphNode(CallGraphNode&&) = default;
132   CallGraphNode& operator=(CallGraphNode&&) = default;
133 
134  private:
135   // Only CallGraph can modify CallGraphNode.
136   friend class CallGraph;
137 
138   // Sets the context in which this computation is called.
set_context(CallContext value)139   void set_context(CallContext value) { context_ = value; }
140 
141   // Sets the depth of this node in the graph.
set_depth(int value)142   void set_depth(int value) { depth_ = value; }
143 
144   // Adds a callsite which calls this computation. Updates callers to include
145   // the calling computation.
146   void AddCallerCallSite(const CallSite& caller_callsite);
147 
148   // If instruction calls any computations adds a call site for this instruction
149   // to the call graph node. If the instruction calls no computations then no
150   // call site is added.
151   void AddCallSiteForInstruction(HloInstruction* instruction);
152 
153   // Computation represented by this call graph node.
154   HloComputation* computation_;
155 
156   // The computations called by this computation. The vector is used for a
157   // stable ordering and the set enables fast membership testing.
158   absl::InlinedVector<HloComputation*, 1> callees_;
159   absl::flat_hash_set<HloComputation*> callee_set_;
160 
161   // The computations which call this computation. The vector is used for a
162   // stable ordering and the set enables fast membership testing.
163   absl::InlinedVector<HloComputation*, 1> callers_;
164   absl::flat_hash_set<HloComputation*> caller_set_;
165 
166   // The call sites in this computation
167   absl::InlinedVector<CallSite, 1> callsites_;
168 
169   // The map from instruction to index in callsites_ for looking up the callsite
170   // (if any) associated with a particular instruction in this computation.
171   absl::flat_hash_map<const HloInstruction*, int64_t> callsite_instructions_;
172 
173   // The call sites in other computations which call this computation.
174   absl::InlinedVector<CallSite, 1> caller_callsites_;
175 
176   // The context in which this computation is called.
177   CallContext context_ = CallContext::kNone;
178 
179   // The depth of this node in the call graph.
180   int depth_ = 0;
181 };
182 
183 // The call graph for an HLO module. The graph includes a node for each
184 // computation in the module.
185 class CallGraph {
186  public:
187   using VisitorFunction = std::function<Status(const CallGraphNode&)>;
188 
189   // Builds and returns a call graph for the given HLO module.
190   static std::unique_ptr<CallGraph> Build(const HloModule* module);
191 
192   // Returns the node associated with the given computation.
193   const CallGraphNode& GetNode(const HloComputation* computation) const;
194   CallGraphNode& GetNode(const HloComputation* computation);
195 
196   // Returns the vector of all nodes in the call graph.
nodes()197   const std::vector<CallGraphNode>& nodes() const { return nodes_; }
198 
199   // Calls the given function on each node in the call graph. Nodes are visited
200   // in post order (callees before callers). If visit_unreachable_nodes is true
201   // then all nodes in the call graph are visited. Otherwise only those nodes
202   // reachable from the entry computation are visited.
203   Status VisitNodes(const VisitorFunction& visitor_func,
204                     bool visit_unreachable_nodes = true) const;
205 
206   // Returns true if 'a' dominates 'b' in the call graph. Computation 'a'
207   // dominates computation 'b' iff all callgraph paths in the caller-to-callee
208   // direction from a root computation to 'b' pass through computation
209   // 'a'. Trivially, a computation dominates itself.
210   bool Dominates(const HloComputation* a, const HloComputation* b) const;
211 
212   // Returns whether 'instruction' is contained in 'computation' either directly
213   // ('instruction->parent' is 'computation') or indirectly ('computation'
214   // dominates 'instruction->parent' in the call graph).
InstructionIsNestedIn(const HloInstruction * instruction,const HloComputation * computation)215   bool InstructionIsNestedIn(const HloInstruction* instruction,
216                              const HloComputation* computation) const {
217     return Dominates(computation, instruction->parent());
218   }
219 
220   // Returns the nearest call graph ancestors of instructions 'a' and 'b' for
221   // which the ancestors are in the same computation. An instruction is an call
222   // graph ancestor of 'a' if the instruction calls the computation containing
223   // 'a' either directly or transitively. Degeneratively an instruction is an
224   // ancestor of itself. nullptr is returned if there is no common ancestor or
225   // if the caller chain of 'a' or 'b' diverges (has multiple callers) before
226   // the nearest common ancestor.
227   //
228   // Example:
229   //
230   // Entry computation:
231   //   %x = Call(A, {Constant(42.0)})
232   //   %y = Call(B, {%x})
233   //
234   // Computation A:
235   //   %a = Negate(Param())
236   //
237   // Computation B:
238   //   %b = Exp(Param());
239   //
240   // If called with %a and %b, this function would return (%x, %y). %x is an
241   // ancestor of %a, and %y is an ancestor of %b, and %x and %y are in the same
242   // computation.
243   std::pair<HloInstruction*, HloInstruction*> NearestAncestorsInSameComputation(
244       HloInstruction* a, HloInstruction* b) const;
245 
246   // Returns whether the call graph is flattened. A call graph is flattened if
247   // every computation called in a sequential context (eg, kWhile or kCall) has
248   // zero or one callsite, and no computation is called from both a parallel and
249   // sequential context. The call graph of a module can be flattened with
250   // FlattenCallGraph.
251   bool IsFlattened() const;
252 
253   // Returns a vector of instructions calling the passed computation.
254   // (Often a vector of size 1.)
255   std::vector<HloInstruction*> GetComputationCallers(HloComputation* c) const;
256 
257   std::string ToString() const;
258 
259  private:
260   CallGraph(const HloModule* module);
261 
262   // Not copyable.
263   CallGraph(const CallGraph&) = delete;
264   CallGraph& operator=(const CallGraph&) = delete;
265 
266   // Sets the call contexts for every node in the graph.
267   void SetCallContexts();
268 
269   // Sets the call node depths for every node in the graph.
270   void SetNodeDepths();
271 
272   // Helper method for VisitNodes(). Traverses the call graph from 'node' in DFS
273   // post order (callee before caller) calling visitor_func on each node. Adds
274   // nodes to 'visited' as each node is visited. Skips nodes already in
275   // 'visited'.
276   Status VisitNodesInternal(
277       const VisitorFunction& visitor_func, const CallGraphNode& node,
278       absl::flat_hash_set<const CallGraphNode*>* visited) const;
279 
280   // Recursive helper for computing whether 'a' dominates 'b' in the call
281   // graph. 'b_ancestor' is the currently visited node (which starts at 'b'),
282   // and 'visited' is the set of computations which have been visited.
283   bool DominatesHelper(
284       const HloComputation* a, const HloComputation* b,
285       absl::flat_hash_set<const HloComputation*>* visited) const;
286 
287   // The HLO module represented by this call graph.
288   const HloModule* module_ = nullptr;
289 
290   // Vector of all nodes in the call graph.
291   std::vector<CallGraphNode> nodes_;
292 
293   // Map from HLO computation to the index of the corresponding call graph node
294   // in nodes_.
295   absl::flat_hash_map<const HloComputation*, int64_t> node_indices_;
296 };
297 
298 }  // namespace xla
299 
300 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_
301