/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Call graph for an HLO module.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_

#include <ostream>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"

namespace xla {

// The context in which a computation is called by another computation.
enum class CallContext {
  // In a parallel context the computation is applied to each element of the
  // array argument(s). kMap and kReduce instructions call computations in
  // parallel context.
  kParallel,

  // In a sequential context the computation is applied to the entire argument
  // shape(s). kCall and kWhile (body and condition) call computations in
  // sequential context.
  kSequential,

  // A computation is called from both a parallel and sequential context.
  kBoth,

  // During call graph construction kNone is used to indicate that the context
  // has not been determined. This is the top value for the context
  // lattice. After construction, no call sites or call graph nodes should have
  // this value.
  kNone
};

string CallContextToString(CallContext context);
std::ostream& operator<<(std::ostream& out, const CallContext& context);

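// Returns the context in which an instruction with the given opcode calls its
// computations.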
CallContext GetInstructionCallContext(HloOpcode opcode);

// Represents an HLO instruction which calls one or more computations.
class CallSite {
 public:
  CallSite(HloInstruction* instruction,
           const std::vector<HloComputation*>& called_computations,
           CallContext context)
      : instruction_(CHECK_NOTNULL(instruction)),
        called_computations_(called_computations),
        context_(context) {}

  // Returns the instruction associated with this call site.
  HloInstruction* instruction() const { return instruction_; }

  // Returns the computations called at this call site.
  const std::vector<HloComputation*>& called_computations() const {
    return called_computations_;
  }

  // Returns the context in which computations are called at this call site.
  CallContext context() const { return context_; }

  string ToString() const;

 private:
  // The calling instruction.
  HloInstruction* instruction_;

  // The computations called by this callsite.
  const std::vector<HloComputation*> called_computations_;

  // The context in which the computations are called.
  const CallContext context_;
};

// A node in the call graph representing an HLO computation.
class CallGraphNode {
 public:
  CallGraphNode(HloComputation* computation);

  // Returns the computation represented by this call graph node.
  HloComputation* computation() const { return computation_; }

  // Returns the call sites in this computation. These are the instructions in
  // this computation which call other computations.
  const std::vector<CallSite>& callsites() const { return callsites_; }

  // Returns the callsite associated with the given instruction. If this
  // instruction calls no computations, nullptr is returned.
  // Prerequisite: the instruction is in the computation associated with this
  // call graph node.
  const CallSite* GetCallSite(const HloInstruction* instruction) const;

  // Returns the computations called by this computation.
  const std::vector<HloComputation*>& callees() const { return callees_; }

  // Returns the call sites in other computations which call this computation.
  const std::vector<CallSite>& caller_callsites() const {
    return caller_callsites_;
  }

  // Returns the computations which call this computation.
  const std::vector<HloComputation*>& callers() const { return callers_; }

  // Returns the context in which this computation is called.
  CallContext context() const { return context_; }

  // Returns the depth of this node in the call graph. The depth is defined as
  // the length of the longest call chain from a computation with no callers
  // (usually the entry computation node) to this node.
  int depth() const { return depth_; }

  string ToString() const;

 private:
  // Only CallGraph can modify CallGraphNode.
  friend class CallGraph;

  // Sets the context in which this computation is called.
  void set_context(CallContext value) { context_ = value; }

  // Sets the depth of this node in the graph.
  void set_depth(int value) { depth_ = value; }

  // Adds a callsite which calls this computation. Updates callers to include
  // the calling computation.
  void AddCallerCallSite(const CallSite& caller_callsite);

  // If the instruction calls any computations, adds a call site for the
  // instruction to this call graph node. If the instruction calls no
  // computations then no call site is added.
  void AddCallSiteForInstruction(HloInstruction* instruction);

  // Computation represented by this call graph node.
  HloComputation* computation_;

  // The computations called by this computation. The vector is used for a
  // stable ordering and the set enables fast membership testing.
  std::vector<HloComputation*> callees_;
  absl::flat_hash_set<HloComputation*> callee_set_;

  // The computations which call this computation. The vector is used for a
  // stable ordering and the set enables fast membership testing.
  std::vector<HloComputation*> callers_;
  absl::flat_hash_set<HloComputation*> caller_set_;

  // The call sites in this computation
  std::vector<CallSite> callsites_;

  // The map from instruction to index in callsites_ for looking up the callsite
  // (if any) associated with a particular instruction in this computation.
  absl::flat_hash_map<const HloInstruction*, int64> callsite_instructions_;

  // The call sites in other computations which call this computation.
  std::vector<CallSite> caller_callsites_;

  // The context in which this computation is called.
  CallContext context_ = CallContext::kNone;

  // The depth of this node in the call graph.
  int depth_ = 0;
};
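
// Example (a sketch, not part of this header's API): the call sites recorded
// on a CallGraphNode 'node' can be inspected like this:
//
//   for (const CallSite& callsite : node.callsites()) {
//     VLOG(2) << "Instruction " << callsite.instruction()->name()
//             << " calls " << callsite.called_computations().size()
//             << " computation(s) in context " << callsite.context();
//   }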

// The call graph for an HLO module. The graph includes a node for each
// computation in the module.
class CallGraph {
 public:
  using VisitorFunction = std::function<Status(const CallGraphNode&)>;

  // Builds and returns a call graph for the given HLO module.
  static std::unique_ptr<CallGraph> Build(const HloModule* module);

  // Returns the node associated with the given computation.
  const CallGraphNode& GetNode(const HloComputation* computation) const;
  CallGraphNode& GetNode(const HloComputation* computation);

  // Returns the vector of all nodes in the call graph.
  const std::vector<CallGraphNode>& nodes() const { return nodes_; }

  // Calls the given function on each node in the call graph. Nodes are visited
  // in post order (callees before callers). If visit_unreachable_nodes is true
  // then all nodes in the call graph are visited. Otherwise only those nodes
  // reachable from the entry computation are visited.
  Status VisitNodes(const VisitorFunction& visitor_func,
                    bool visit_unreachable_nodes = true) const;

  // Returns true if 'a' dominates 'b' in the call graph. Computation 'a'
  // dominates computation 'b' iff all callgraph paths in the caller-to-callee
  // direction from a root computation to 'b' pass through computation
  // 'a'. Trivially, a computation dominates itself.
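  // For example, if the entry computation calls computation A, and A is the
  // only caller of computation B, then the entry computation dominates both A
  // and B, and A dominates B.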
  bool Dominates(const HloComputation* a, const HloComputation* b) const;

  // Returns whether 'instruction' is contained in 'computation' either directly
  // ('instruction->parent' is 'computation') or indirectly ('computation'
  // dominates 'instruction->parent' in the call graph).
  bool InstructionIsNestedIn(const HloInstruction* instruction,
                             const HloComputation* computation) const {
    return Dominates(computation, instruction->parent());
  }

  // Returns the nearest call graph ancestors of instructions 'a' and 'b' for
  // which the ancestors are in the same computation. An instruction is a call
  // graph ancestor of 'a' if the instruction calls the computation containing
  // 'a' either directly or transitively. Degenerately, an instruction is an
  // ancestor of itself. nullptr is returned if there is no common ancestor or
  // if the caller chain of 'a' or 'b' diverges (has multiple callers) before
  // the nearest common ancestor.
  //
  // Example:
  //
  // Entry computation:
  //   %x = Call(A, {Constant(42.0)})
  //   %y = Call(B, {%x})
  //
  // Computation A:
  //   %a = Negate(Param())
  //
  // Computation B:
  //   %b = Exp(Param());
  //
  // If called with %a and %b, this function would return (%x, %y). %x is an
  // ancestor of %a, and %y is an ancestor of %b, and %x and %y are in the same
  // computation.
  std::pair<HloInstruction*, HloInstruction*> NearestAncestorsInSameComputation(
      HloInstruction* a, HloInstruction* b) const;

  // Returns whether the call graph is flattened. A call graph is flattened if
  // every computation called in a sequential context (eg, kWhile or kCall) has
  // zero or one callsite, and no computation is called from both a parallel and
  // sequential context. The call graph of a module can be flattened with
  // FlattenCallGraph.
  bool IsFlattened() const;

  // Returns a vector of instructions calling the passed computation.
  // (Often a vector of size 1.)
  std::vector<HloInstruction*> GetComputationCallers(HloComputation* c);

  string ToString() const;

 private:
  CallGraph(const HloModule* module);

  // Not copyable.
  CallGraph(const CallGraph&) = delete;
  CallGraph& operator=(const CallGraph&) = delete;

  // Sets the call contexts for every node in the graph.
  void SetCallContexts();

  // Sets the call node depths for every node in the graph.
  void SetNodeDepths();

  // Helper method for VisitNodes(). Traverses the call graph from 'node' in DFS
  // post order (callee before caller) calling visitor_func on each node. Adds
  // nodes to 'visited' as each node is visited. Skips nodes already in
  // 'visited'.
  Status VisitNodesInternal(
      const VisitorFunction& visitor_func, const CallGraphNode& node,
      absl::flat_hash_set<const CallGraphNode*>* visited) const;

  // Recursive helper for computing whether 'a' dominates 'b' in the call
  // graph. 'b_ancestor' is the currently visited node (which starts at 'b'),
  // and 'visited' is the set of computations which have been visited.
  bool DominatesHelper(
      const HloComputation* a, const HloComputation* b,
      absl::flat_hash_set<const HloComputation*>* visited) const;

  // The HLO module represented by this call graph.
  const HloModule* module_ = nullptr;

  // Vector of all nodes in the call graph.
  std::vector<CallGraphNode> nodes_;

  // Map from HLO computation to the index of the corresponding call graph node
  // in nodes_.
  absl::flat_hash_map<const HloComputation*, int64> node_indices_;
};
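
// Example usage (a sketch; assumes 'module' points to a valid HloModule and
// that the enclosing function can return a Status):
//
//   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
//   TF_RETURN_IF_ERROR(call_graph->VisitNodes(
//       [](const CallGraphNode& node) {
//         VLOG(1) << node.ToString() << " at depth " << node.depth();
//         return Status::OK();
//       }));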

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_