/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/call_graph.h"

#include <queue>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

using absl::StrAppendFormat;
using absl::StrCat;

string CallContextToString(CallContext context) {
  switch (context) {
    case CallContext::kNone:
      return "kNone";
    case CallContext::kSequential:
      return "kSequential";
    case CallContext::kParallel:
      return "kParallel";
    case CallContext::kBoth:
      return "kBoth";
  }
}

std::ostream& operator<<(std::ostream& out, const CallContext& context) {
  out << CallContextToString(context);
  return out;
}

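// Maps an opcode to the context in which it invokes its called computations:
// control-flow style opcodes (kCall, kConditional, kWhile) call their
// computations sequentially, while opcodes such as kMap, kReduce, and kFusion
// apply their computations as embedded, parallel subcomputations. Opcodes
// which call no computations map to kNone.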
CallContext GetInstructionCallContext(HloOpcode opcode) {
  switch (opcode) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kWhile:
      return CallContext::kSequential;
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kMap:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kScatter:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kSort:
    case HloOpcode::kFusion:
    case HloOpcode::kCustomCall:
      return CallContext::kParallel;
    default:
      return CallContext::kNone;
  }
}

string CallSite::ToString() const {
  return StrCat(
      instruction()->name(), " calls in context ",
      CallContextToString(context()), ": ",
      absl::StrJoin(called_computations(), ", ",
                    [](string* out, const HloComputation* computation) {
                      out->append(computation->name());
                    }));
}

CallGraphNode::CallGraphNode(HloComputation* computation)
    : computation_(computation) {}

const CallSite* CallGraphNode::GetCallSite(
    const HloInstruction* instruction) const {
  auto it = callsite_instructions_.find(instruction);
  if (it == callsite_instructions_.end()) {
    return nullptr;
  }
  return &callsites_[it->second];
}

std::string CallGraphNode::ToString() const { return computation_->name(); }

void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
  caller_callsites_.push_back(caller_callsite);
  HloComputation* caller = caller_callsite.instruction()->parent();
  if (!ContainsKey(caller_set_, caller)) {
    callers_.push_back(caller);
    caller_set_.insert(caller);
  }
}

void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
  CHECK_EQ(instruction->parent(), computation());
  const CallContext context = GetInstructionCallContext(instruction->opcode());
  if (!instruction->called_computations().empty()) {
    CHECK(context == CallContext::kSequential ||
          context == CallContext::kParallel);
    callsite_instructions_.insert({instruction, callsites_.size()});
    callsites_.push_back(
        CallSite(instruction, instruction->called_computations(), context));
    // Update callee computations to include any new computations called by this
    // instruction.
    for (auto* callee : callsites_.back().called_computations()) {
      if (!ContainsKey(callee_set_, callee)) {
        callees_.push_back(callee);
        callee_set_.insert(callee);
      }
    }
  }
}

CallGraph::CallGraph(const HloModule* module) : module_(module) {}

const CallGraphNode& CallGraph::GetNode(
    const HloComputation* computation) const {
  auto it = node_indices_.find(computation);
  CHECK(it != node_indices_.end());
  return nodes_[it->second];
}

CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
  auto it = node_indices_.find(computation);
  CHECK(it != node_indices_.end());
  return nodes_[it->second];
}

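// Recursive helper for Dominates(): computation 'a' dominates computation 'b'
// iff every caller-to-callee path from a root of the call graph down to 'b'
// passes through 'a'. The helper walks each caller chain upward from 'b' and
// requires that every chain reaches 'a' before reaching a root.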
bool CallGraph::DominatesHelper(
    const HloComputation* a, const HloComputation* b,
    absl::flat_hash_set<const HloComputation*>* visited) const {
  if (a == b || ContainsKey(*visited, b)) {
    // The call graph is guaranteed to be acyclic so any previously visited node
    // we encounter was already determined to be dominated.
    return true;
  }

  const CallGraphNode& b_node = GetNode(b);
  if (b_node.callers().empty()) {
    // We reached a root node without hitting 'a'. 'a' does not dominate 'b'.
    return false;
  }

  // Walk up the callers of 'b' until we hit 'a' or a root node (no callers).
  visited->insert(b);
  for (const HloComputation* b_caller : b_node.callers()) {
    if (!DominatesHelper(a, b_caller, visited)) {
      return false;
    }
  }
  return true;
}

bool CallGraph::Dominates(const HloComputation* a,
                          const HloComputation* b) const {
  absl::flat_hash_set<const HloComputation*> visited;
  return DominatesHelper(a, b, &visited);
}

namespace {

// Returns the call context of a computation which is called from contexts 'a'
// and 'b'.
CallContext UnionContexts(CallContext a, CallContext b) {
  if (a == CallContext::kNone) {
    return b;
  } else if (b == CallContext::kNone) {
    return a;
  } else if (a == b) {
    return a;
  } else {
    // Contexts are different and neither is kNone, i.e., one is kSequential
    // and the other is kParallel.
    return CallContext::kBoth;
  }
}

}  // namespace

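// Propagates calling contexts from the roots of the call graph to all
// reachable nodes using a breadth-first worklist: a callee reached through a
// parallel callsite is marked kParallel, a callee reached through a
// sequential callsite inherits its caller's context, and conflicting contexts
// are merged via UnionContexts (yielding kBoth).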
void CallGraph::SetCallContexts() {
  std::queue<CallGraphNode*> worklist;

  // Initialize worklist with all roots of the call graph (computations without
  // callers).
  for (const HloComputation* computation : module_->computations()) {
    CallGraphNode& node = GetNode(computation);
    if (node.callers().empty()) {
      node.set_context(CallContext::kSequential);
      worklist.push(&node);
    }
  }

  while (!worklist.empty()) {
    CallGraphNode* node = worklist.front();
    worklist.pop();

    for (const CallSite& callsite : node->callsites()) {
      for (const HloComputation* callee : callsite.called_computations()) {
        CallGraphNode& callee_node = GetNode(callee);

        // Update context of callee computation based on the callsite and its
        // current context.
        CallContext context_to_add;
        if (callsite.context() == CallContext::kParallel) {
          context_to_add = CallContext::kParallel;
        } else {
          CHECK_EQ(callsite.context(), CallContext::kSequential);
          context_to_add = node->context();
        }
        CallContext new_context =
            UnionContexts(context_to_add, callee_node.context());

        if (new_context != callee_node.context()) {
          // Context of computation has been changed so add node to worklist.
          callee_node.set_context(new_context);
          worklist.push(&callee_node);
        }
      }
    }
  }

  // No node should have a kNone calling context.
  for (const HloComputation* computation : module_->computations()) {
    CHECK_NE(GetNode(computation).context(), CallContext::kNone);
  }
}

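// Assigns each node a depth equal to the length of the longest caller chain
// from a root computation (roots get depth 0). The worklist relaxation below
// raises a callee's depth whenever a deeper caller is discovered.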
void CallGraph::SetNodeDepths() {
  std::queue<CallGraphNode*> worklist;

  // Initialize node depths to -1.
  for (CallGraphNode& node : nodes_) {
    node.set_depth(-1);
  }

  // Initialize worklist with all roots of the call graph (computations without
  // callers).
  for (const HloComputation* computation : module_->computations()) {
    CallGraphNode& node = GetNode(computation);
    if (node.callers().empty()) {
      node.set_depth(0);
      worklist.push(&node);
    }
  }

  while (!worklist.empty()) {
    CallGraphNode* node = worklist.front();
    worklist.pop();
    for (const HloComputation* callee : node->callees()) {
      CallGraphNode& callee_node = GetNode(callee);
      if (callee_node.depth() < node->depth() + 1) {
        callee_node.set_depth(node->depth() + 1);
        worklist.push(&callee_node);
      }
    }
  }

  for (CallGraphNode& node : nodes_) {
    CHECK_NE(node.depth(), -1);
  }
}

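// Illustrative usage sketch (not part of this file; assumes an existing
// HloModule* 'module'):
//
//   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
//   const CallGraphNode& entry_node =
//       call_graph->GetNode(module->entry_computation());
//   for (const CallSite& callsite : entry_node.callsites()) {
//     VLOG(2) << callsite.ToString();
//   }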
/* static */
std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
  // Constructor for CallGraph is private so absl::make_unique can't be used.
  auto call_graph = absl::WrapUnique<CallGraph>(new CallGraph(module));

  VLOG(3) << "Building call graph for:";
  XLA_VLOG_LINES(3, module->ToString());

  // Construct nodes of the call graph and populate the callsites.
  for (HloComputation* computation : module->computations()) {
    auto it_added = call_graph->node_indices_.insert(
        {computation, call_graph->nodes_.size()});
    // All computations should be unique, so the computation should not already
    // exist in the map.
    CHECK(it_added.second);
    call_graph->nodes_.emplace_back(computation);

    // Add all callsites in this computation.
    for (HloInstruction* instruction : computation->instructions()) {
      call_graph->nodes_.back().AddCallSiteForInstruction(instruction);
    }
  }

  // Add caller callsites to each node.
  for (const HloComputation* computation : module->computations()) {
    for (const CallSite& callsite :
         call_graph->GetNode(computation).callsites()) {
      for (auto* callee : callsite.called_computations()) {
        // Add caller callsites.
        call_graph->GetNode(callee).AddCallerCallSite(callsite);
      }
    }
  }

  call_graph->SetCallContexts();
  call_graph->SetNodeDepths();

  XLA_VLOG_LINES(2, call_graph->ToString());

  return call_graph;
}

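// Recursive helper for VisitNodes: performs a post-order depth-first
// traversal, invoking 'visitor_func' on a node only after all of its callees
// have been visited. The 'visited' set guards against re-visiting nodes that
// are reachable through multiple callers.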
Status CallGraph::VisitNodesInternal(
    const VisitorFunction& visitor_func, const CallGraphNode& node,
    absl::flat_hash_set<const CallGraphNode*>* visited) const {
  auto pair = visited->insert(&node);
  if (!pair.second) {
    // Node was not inserted. Node has already been visited.
    return Status::OK();
  }

  for (const HloComputation* computation : node.callees()) {
    TF_RETURN_IF_ERROR(
        VisitNodesInternal(visitor_func, GetNode(computation), visited));
  }

  return visitor_func(node);
}

Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
                             bool visit_unreachable_nodes) const {
  absl::flat_hash_set<const CallGraphNode*> visited;
  if (visit_unreachable_nodes) {
    // Traverse from all roots in the call graph.
    for (const CallGraphNode& node : nodes()) {
      if (node.callers().empty()) {
        TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited));
      }
    }
  } else {
    // Traverse only from the entry computation.
    TF_RETURN_IF_ERROR(VisitNodesInternal(
        visitor_func, GetNode(module_->entry_computation()), &visited));
  }

  return Status::OK();
}

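// The graph is considered flattened if no computation is called in both
// sequential and parallel contexts (kBoth) and no sequentially called
// computation has more than one caller callsite.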
bool CallGraph::IsFlattened() const {
  for (const CallGraphNode& node : nodes_) {
    if (node.context() == CallContext::kBoth) {
      return false;
    }
    if (node.context() == CallContext::kSequential &&
        node.caller_callsites().size() > 1) {
      return false;
    }
  }
  return true;
}

std::vector<HloInstruction*> CallGraph::GetComputationCallers(
    HloComputation* c) {
  std::vector<HloInstruction*> callers;
  for (const auto& callsite : GetNode(c).caller_callsites()) {
    callers.push_back(callsite.instruction());
  }
  return callers;
}

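// Example (illustrative): if 'a' is an instruction inside the body of a while
// loop that lives in the entry computation and 'b' is an instruction in the
// entry computation itself, the returned pair is {the kWhile instruction
// calling a's computation, b}, assuming the while body has exactly one caller
// callsite.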
std::pair<HloInstruction*, HloInstruction*>
CallGraph::NearestAncestorsInSameComputation(HloInstruction* a,
                                             HloInstruction* b) const {
  // Lambda which returns the next instruction in the callee->caller chain in
  // the call graph. This is the unique instruction which calls the computation
  // containing 'instruction'. If more than one instruction calls the
  // computation containing 'instruction' or no instructions call the
  // computation then nullptr is returned.
  auto next_caller = [this](HloInstruction* instruction) -> HloInstruction* {
    const CallGraphNode& node = GetNode(instruction->parent());
    if (node.caller_callsites().size() != 1) {
      return nullptr;
    }
    return node.caller_callsites()[0].instruction();
  };

  // Iterate through the callee->caller chains and find the earliest common
  // element.
  HloInstruction* a_ancestor = a;
  HloInstruction* b_ancestor = b;
  int a_depth = GetNode(a->parent()).depth();
  int b_depth = GetNode(b->parent()).depth();

  // Advance a_ancestor (b_ancestor) up the call chain until the call depths of
  // a_ancestor and b_ancestor are the same. Necessarily each call to
  // next_caller reduces the depth by exactly one.
  if (a_depth > b_depth) {
    for (int i = 0; i < a_depth - b_depth; ++i) {
      a_ancestor = next_caller(a_ancestor);
      if (a_ancestor == nullptr) {
        return {nullptr, nullptr};
      }
    }
  } else if (b_depth > a_depth) {
    for (int i = 0; i < b_depth - a_depth; ++i) {
      b_ancestor = next_caller(b_ancestor);
      if (b_ancestor == nullptr) {
        return {nullptr, nullptr};
      }
    }
  }

  while ((a_ancestor != nullptr) && (b_ancestor != nullptr)) {
    if (a_ancestor->parent() == b_ancestor->parent()) {
      return {a_ancestor, b_ancestor};
    }

    a_ancestor = next_caller(a_ancestor);
    b_ancestor = next_caller(b_ancestor);
  }
  return {nullptr, nullptr};
}

string CallGraph::ToString() const {
  string out;
  StrAppendFormat(&out, "Call graph for module %s:\n", module_->name());
  for (const CallGraphNode& node : nodes()) {
    StrAppendFormat(&out, "Computation %s:\n", node.computation()->name());
    StrAppendFormat(&out, "  calls:\n");
    for (const HloComputation* callee : node.callees()) {
      StrAppendFormat(&out, "    %s\n", callee->name());
    }
    StrAppendFormat(&out, "  called by:\n");
    for (const HloComputation* caller : node.callers()) {
      StrAppendFormat(&out, "    %s\n", caller->name());
    }
    StrAppendFormat(&out, "  callsites:\n");
    for (const CallSite& callsite : node.callsites()) {
      StrAppendFormat(&out, "    %s\n", callsite.ToString());
    }
  }
  return out;
}

}  // namespace xla