/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/call_graph.h"

#include <memory>
#include <queue>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"  // for absl::WrapUnique, used in CallGraph::Build.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace xla {

using absl::StrAppendFormat;
using absl::StrCat;

std::string CallContextToString(CallContext context) {
  switch (context) {
    case CallContext::kNone:
      return "kNone";
    case CallContext::kControlFlow:
      return "kControlFlow";
    case CallContext::kEmbedded:
      return "kEmbedded";
    case CallContext::kBoth:
      return "kBoth";
  }
}

std::ostream& operator<<(std::ostream& out, const CallContext& context) {
  out << CallContextToString(context);
  return out;
}

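// Returns the context in which the given opcode invokes its called
// computations: opcodes such as kWhile, kConditional, and kCall (along with
// the async ops) call computations as control flow; opcodes such as kMap,
// kReduce, and kFusion apply computations as embedded subroutines within a
// single instruction; all other opcodes call no computations (kNone).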
CallContext GetInstructionCallContext(HloOpcode opcode) {
  switch (opcode) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kWhile:
    case HloOpcode::kAsyncStart:
    case HloOpcode::kAsyncUpdate:
    case HloOpcode::kAsyncDone:
      return CallContext::kControlFlow;
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kMap:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kScatter:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kSort:
    case HloOpcode::kFusion:
    case HloOpcode::kCustomCall:
      return CallContext::kEmbedded;
    default:
      return CallContext::kNone;
  }
}

std::string CallSite::ToString() const {
  return StrCat(
      instruction()->name(), " calls in context ",
      CallContextToString(context()), ": ",
      absl::StrJoin(called_computations(), ", ",
                    [](std::string* out, const HloComputation* computation) {
                      out->append(computation->name());
                    }));
}

CallGraphNode::CallGraphNode(HloComputation* computation)
    : computation_(computation) {}

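// Returns the callsite whose calling instruction is 'instruction', or nullptr
// if 'instruction' does not call any computations from this node.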
const CallSite* CallGraphNode::GetCallSite(
    const HloInstruction* instruction) const {
  auto it = callsite_instructions_.find(instruction);
  if (it == callsite_instructions_.end()) {
    return nullptr;
  }
  return &callsites_[it->second];
}

std::string CallGraphNode::ToString() const { return computation_->name(); }

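// Records 'caller_callsite' as a callsite which calls this node's computation,
// and adds the calling computation to the deduplicated list of callers.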
void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
  caller_callsites_.push_back(caller_callsite);
  HloComputation* caller = caller_callsite.instruction()->parent();
  if (!ContainsKey(caller_set_, caller)) {
    callers_.push_back(caller);
    caller_set_.insert(caller);
  }
}

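// If 'instruction' calls any computations, adds a callsite for it to this node
// and records each newly seen called computation as a callee.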
void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
  CHECK_EQ(instruction->parent(), computation());
  const CallContext context = GetInstructionCallContext(instruction->opcode());
  if (!instruction->called_computations().empty()) {
    CHECK(context == CallContext::kControlFlow ||
          context == CallContext::kEmbedded);
    callsite_instructions_.insert({instruction, callsites_.size()});
    callsites_.push_back(
        CallSite(instruction, instruction->called_computations(), context));
    // Update callee computations to include any new computations called by
    // this instruction.
    for (auto* callee : callsites_.back().called_computations()) {
      if (!ContainsKey(callee_set_, callee)) {
        callees_.push_back(callee);
        callee_set_.insert(callee);
      }
    }
  }
}

CallGraph::CallGraph(const HloModule* module) : module_(module) {}

const CallGraphNode& CallGraph::GetNode(
    const HloComputation* computation) const {
  auto it = node_indices_.find(computation);
  CHECK(it != node_indices_.end());
  return nodes_[it->second];
}

CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
  auto it = node_indices_.find(computation);
  CHECK(it != node_indices_.end());
  return nodes_[it->second];
}

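// Recursive helper for Dominates(). Returns true if every callee->caller path
// from 'b' up to a root of the call graph passes through 'a'. 'visited' prunes
// computations whose paths to the roots have already been shown to pass
// through 'a'.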
bool CallGraph::DominatesHelper(
    const HloComputation* a, const HloComputation* b,
    absl::flat_hash_set<const HloComputation*>* visited) const {
  if (a == b || ContainsKey(*visited, b)) {
    // The call graph is guaranteed to be acyclic so any previously visited
    // node we encounter was already determined to be dominated.
    return true;
  }

  const CallGraphNode& b_node = GetNode(b);
  if (b_node.callers().empty()) {
    // We reached a root node without hitting 'a'. 'a' does not dominate 'b'.
    return false;
  }

  // Walk up the callers of 'b' until we hit 'a' or a root node (no callers).
  visited->insert(b);
  for (const HloComputation* b_caller : b_node.callers()) {
    if (!DominatesHelper(a, b_caller, visited)) {
      return false;
    }
  }
  return true;
}

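// Returns true if computation 'a' dominates computation 'b' in the call
// graph, that is, if every path from a root of the call graph to 'b' passes
// through 'a'. A computation trivially dominates itself.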
bool CallGraph::Dominates(const HloComputation* a,
                          const HloComputation* b) const {
  absl::flat_hash_set<const HloComputation*> visited;
  return DominatesHelper(a, b, &visited);
}

namespace {

// Returns the call context of a computation which is called from contexts 'a'
// and 'b'.
CallContext UnionContexts(CallContext a, CallContext b) {
  if (a == CallContext::kNone) {
    return b;
  } else if (b == CallContext::kNone) {
    return a;
  } else if (a == b) {
    return a;
  } else {
    // Contexts are different and neither is kNone, so the computation is
    // called in both control flow and embedded contexts.
    return CallContext::kBoth;
  }
}

}  // namespace

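// Propagates call contexts through the call graph with a breadth-first
// worklist traversal. Roots (computations with no callers) start in the
// kControlFlow context; a callee reached through an embedded callsite picks up
// kEmbedded, otherwise it inherits its caller's context. Contexts arriving
// from multiple callers are merged with UnionContexts, and a node is revisited
// whenever its context changes.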
void CallGraph::SetCallContexts() {
  std::queue<CallGraphNode*> worklist;

  // Initialize worklist with all roots of the call graph (computations without
  // callers).
  for (const HloComputation* computation : module_->computations()) {
    CallGraphNode& node = GetNode(computation);
    if (node.callers().empty()) {
      node.set_context(CallContext::kControlFlow);
      worklist.push(&node);
    }
  }

  while (!worklist.empty()) {
    CallGraphNode* node = worklist.front();
    worklist.pop();

    for (const CallSite& callsite : node->callsites()) {
      for (const HloComputation* callee : callsite.called_computations()) {
        CallGraphNode& callee_node = GetNode(callee);

        // Update context of callee computation based on the callsite and its
        // current context.
        CallContext context_to_add;
        if (callsite.context() == CallContext::kEmbedded) {
          context_to_add = CallContext::kEmbedded;
        } else {
          CHECK_EQ(callsite.context(), CallContext::kControlFlow);
          context_to_add = node->context();
        }
        CallContext new_context =
            UnionContexts(context_to_add, callee_node.context());

        if (new_context != callee_node.context()) {
          // Context of computation has been changed so add node to worklist.
          callee_node.set_context(new_context);
          worklist.push(&callee_node);
        }
      }
    }
  }

  // No node should have a kNone calling context.
  for (const HloComputation* computation : module_->computations()) {
    CHECK_NE(GetNode(computation).context(), CallContext::kNone);
  }
}

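// Computes the depth of each node: a root has depth 0, and every other node's
// depth is the length of the longest caller chain from any root down to it.
// A node is re-pushed onto the worklist whenever its depth increases, so the
// traversal converges on the maximum depth; it terminates because the call
// graph is acyclic.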
void CallGraph::SetNodeDepths() {
  std::queue<CallGraphNode*> worklist;

  // Initialize node depths to -1.
  for (CallGraphNode& node : nodes_) {
    node.set_depth(-1);
  }

  // Initialize worklist with all roots of the call graph (computations without
  // callers).
  for (const HloComputation* computation : module_->computations()) {
    CallGraphNode& node = GetNode(computation);
    if (node.callers().empty()) {
      node.set_depth(0);
      worklist.push(&node);
    }
  }

  while (!worklist.empty()) {
    CallGraphNode* node = worklist.front();
    worklist.pop();
    for (const HloComputation* callee : node->callees()) {
      CallGraphNode& callee_node = GetNode(callee);
      if (callee_node.depth() < node->depth() + 1) {
        callee_node.set_depth(node->depth() + 1);
        worklist.push(&callee_node);
      }
    }
  }

  for (CallGraphNode& node : nodes_) {
    CHECK_NE(node.depth(), -1);
  }
}

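// Builds the call graph in two passes over the module: the first pass creates
// a node per computation and collects the callsites appearing within it, and
// the second pass wires each callsite into its called computations as a
// caller callsite. Call contexts and node depths are then computed over the
// finished graph. A typical use, given an HloModule* 'module':
//
//   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
//   const CallGraphNode& entry_node =
//       call_graph->GetNode(module->entry_computation());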
/* static */
std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
  // Constructor for CallGraph is private so std::make_unique can't be used.
  auto call_graph = absl::WrapUnique<CallGraph>(new CallGraph(module));

  VLOG(3) << "Building call graph for:";
  XLA_VLOG_LINES(3, module->ToString());

  // Construct nodes of the call graph and populate the callsites.
  for (HloComputation* computation : module->computations()) {
    auto it_added = call_graph->node_indices_.insert(
        {computation, call_graph->nodes_.size()});
    // All computations should be unique, so the computation should not already
    // exist in the map.
    CHECK(it_added.second);
    call_graph->nodes_.emplace_back(computation);

    // Add all callsites in this computation.
    for (HloInstruction* instruction : computation->instructions()) {
      call_graph->nodes_.back().AddCallSiteForInstruction(instruction);
    }
  }

  // Add caller callsites to each node.
  for (const HloComputation* computation : module->computations()) {
    for (const CallSite& callsite :
         call_graph->GetNode(computation).callsites()) {
      for (auto* callee : callsite.called_computations()) {
        // Add caller callsites.
        call_graph->GetNode(callee).AddCallerCallSite(callsite);
      }
    }
  }

  call_graph->SetCallContexts();
  call_graph->SetNodeDepths();

  XLA_VLOG_LINES(2, call_graph->ToString());

  return call_graph;
}

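// Recursive helper for VisitNodes(): visits the callees of 'node' depth-first
// before applying 'visitor_func' to 'node' itself (a post-order traversal),
// using 'visited' to visit each node at most once.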
Status CallGraph::VisitNodesInternal(
    const VisitorFunction& visitor_func, const CallGraphNode& node,
    absl::flat_hash_set<const CallGraphNode*>* visited) const {
  auto pair = visited->insert(&node);
  if (!pair.second) {
    // Node was not inserted. Node has already been visited.
    return OkStatus();
  }

  for (const HloComputation* computation : node.callees()) {
    TF_RETURN_IF_ERROR(
        VisitNodesInternal(visitor_func, GetNode(computation), visited));
  }

  return visitor_func(node);
}

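// Applies 'visitor_func' to nodes in post-order (callees before callers). If
// 'visit_unreachable_nodes' is true the traversal starts from every root of
// the call graph; otherwise only nodes reachable from the entry computation
// are visited. The first non-OK status from 'visitor_func' aborts the
// traversal and is returned.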
Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
                             bool visit_unreachable_nodes) const {
  absl::flat_hash_set<const CallGraphNode*> visited;
  if (visit_unreachable_nodes) {
    // Traverse from all roots in the call graph.
    for (const CallGraphNode& node : nodes()) {
      if (node.callers().empty()) {
        TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited));
      }
    }
  } else {
    // Traverse only from the entry computation.
    TF_RETURN_IF_ERROR(VisitNodesInternal(
        visitor_func, GetNode(module_->entry_computation()), &visited));
  }

  return OkStatus();
}

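// Returns true if the call graph is flattened: no computation is called in
// both a control-flow and an embedded context (kBoth), and no non-async
// computation called as control flow has more than one caller callsite.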
bool CallGraph::IsFlattened() const {
  for (const CallGraphNode& node : nodes_) {
    if (node.context() == CallContext::kBoth) {
      return false;
    }
    if (node.context() == CallContext::kControlFlow &&
        !node.computation()->IsAsyncComputation() &&
        node.caller_callsites().size() > 1) {
      return false;
    }
  }
  return true;
}

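// Returns the calling instruction of every callsite which calls
// computation 'c', one entry per caller callsite.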
std::vector<HloInstruction*> CallGraph::GetComputationCallers(
    HloComputation* c) const {
  std::vector<HloInstruction*> callers;
  for (const auto& callsite : GetNode(c).caller_callsites()) {
    callers.push_back(callsite.instruction());
  }
  return callers;
}

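// Returns the nearest ancestors of instructions 'a' and 'b' which lie in the
// same computation: each instruction is first walked up its callee->caller
// chain until both sit at the same call depth, then both are advanced in
// lockstep until their computations match. Returns {nullptr, nullptr} if no
// such pair is found, which can happen when a computation on either chain has
// zero callers or more than one caller.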
std::pair<HloInstruction*, HloInstruction*>
CallGraph::NearestAncestorsInSameComputation(HloInstruction* a,
                                             HloInstruction* b) const {
  // Lambda which returns the next instruction in the callee->caller chain in
  // the call graph. This is the unique instruction which calls the computation
  // containing 'instruction'. If more than one instruction calls the
  // computation containing 'instruction' or no instructions call the
  // computation then nullptr is returned.
  auto next_caller = [this](HloInstruction* instruction) -> HloInstruction* {
    const CallGraphNode& node = GetNode(instruction->parent());
    if (node.caller_callsites().size() != 1) {
      return nullptr;
    }
    return node.caller_callsites()[0].instruction();
  };

  // Iterate through the callee->caller chains and find the earliest common
  // element.
  HloInstruction* a_ancestor = a;
  HloInstruction* b_ancestor = b;
  int a_depth = GetNode(a->parent()).depth();
  int b_depth = GetNode(b->parent()).depth();

  // Advance a_ancestor (b_ancestor) up the call chain until the call depths of
  // a_ancestor and b_ancestor are the same. Necessarily each call to
  // next_caller reduces the depth by exactly one.
  if (a_depth > b_depth) {
    for (int i = 0; i < a_depth - b_depth; ++i) {
      a_ancestor = next_caller(a_ancestor);
      if (a_ancestor == nullptr) {
        return {nullptr, nullptr};
      }
    }
  } else if (b_depth > a_depth) {
    for (int i = 0; i < b_depth - a_depth; ++i) {
      b_ancestor = next_caller(b_ancestor);
      if (b_ancestor == nullptr) {
        return {nullptr, nullptr};
      }
    }
  }

  while ((a_ancestor != nullptr) && (b_ancestor != nullptr)) {
    if (a_ancestor->parent() == b_ancestor->parent()) {
      return {a_ancestor, b_ancestor};
    }

    a_ancestor = next_caller(a_ancestor);
    b_ancestor = next_caller(b_ancestor);
  }
  return {nullptr, nullptr};
}

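// Returns a multi-line, human-readable dump of the call graph listing, for
// each computation, its callees, its callers, and its callsites.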
std::string CallGraph::ToString() const {
  std::string out;
  StrAppendFormat(&out, "Call graph for module %s:\n", module_->name());
  for (const CallGraphNode& node : nodes()) {
    StrAppendFormat(&out, "Computation %s:\n", node.computation()->name());
    StrAppendFormat(&out, "  calls:\n");
    for (const HloComputation* callee : node.callees()) {
      StrAppendFormat(&out, "    %s\n", callee->name());
    }
    StrAppendFormat(&out, "  called by:\n");
    for (const HloComputation* caller : node.callers()) {
      StrAppendFormat(&out, "    %s\n", caller->name());
    }
    StrAppendFormat(&out, "  callsites:\n");
    for (const CallSite& callsite : node.callsites()) {
      StrAppendFormat(&out, "    %s\n", callsite.ToString());
    }
  }
  return out;
}

}  // namespace xla