/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/call_graph.h"

#include <queue>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

using absl::StrAppendFormat;
using absl::StrCat;

string CallContextToString(CallContext context) {
  switch (context) {
    case CallContext::kNone:
      return "kNone";
    case CallContext::kSequential:
      return "kSequential";
    case CallContext::kParallel:
      return "kParallel";
    case CallContext::kBoth:
      return "kBoth";
  }
}

std::ostream& operator<<(std::ostream& out, const CallContext& context) {
  out << CallContextToString(context);
  return out;
}

CallContext GetInstructionCallContext(HloOpcode opcode) {
  switch (opcode) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kWhile:
      return CallContext::kSequential;
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kMap:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kScatter:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kSort:
    case HloOpcode::kFusion:
    case HloOpcode::kCustomCall:
      return CallContext::kParallel;
    default:
      return CallContext::kNone;
  }
}

string CallSite::ToString() const {
  return StrCat(
      instruction()->name(), " calls in context ",
      CallContextToString(context()), ": ",
      absl::StrJoin(called_computations(), ", ",
                    [](string* out, const HloComputation* computation) {
                      out->append(computation->name());
                    }));
}

CallGraphNode::CallGraphNode(HloComputation* computation)
    : computation_(computation) {}

const CallSite* CallGraphNode::GetCallSite(
    const HloInstruction* instruction) const {
  auto it = callsite_instructions_.find(instruction);
  if (it == callsite_instructions_.end()) {
    return nullptr;
  }
  return &callsites_[it->second];
}

std::string CallGraphNode::ToString() const { return computation_->name(); }

void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
  caller_callsites_.push_back(caller_callsite);
  HloComputation* caller = caller_callsite.instruction()->parent();
  if (!ContainsKey(caller_set_, caller)) {
    callers_.push_back(caller);
    caller_set_.insert(caller);
  }
}

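// As an illustration of the callsite bookkeeping below: a kWhile instruction
// produces a single CallSite whose called_computations() holds both the
// condition and the body computation in kSequential context, while a kMap
// instruction produces a CallSite for its mapped computation in kParallel
// context.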
void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
  CHECK_EQ(instruction->parent(), computation());
  const CallContext context = GetInstructionCallContext(instruction->opcode());
  if (!instruction->called_computations().empty()) {
    CHECK(context == CallContext::kSequential ||
          context == CallContext::kParallel);
    callsite_instructions_.insert({instruction, callsites_.size()});
    callsites_.push_back(
        CallSite(instruction, instruction->called_computations(), context));
    // Update callee computations to include any new computations called by
    // this instruction.
    for (auto* callee : callsites_.back().called_computations()) {
      if (!ContainsKey(callee_set_, callee)) {
        callees_.push_back(callee);
        callee_set_.insert(callee);
      }
    }
  }
}

CallGraph::CallGraph(const HloModule* module) : module_(module) {}

const CallGraphNode& CallGraph::GetNode(
    const HloComputation* computation) const {
  auto it = node_indices_.find(computation);
  CHECK(it != node_indices_.end());
  return nodes_[it->second];
}

CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
  auto it = node_indices_.find(computation);
  CHECK(it != node_indices_.end());
  return nodes_[it->second];
}

bool CallGraph::DominatesHelper(
    const HloComputation* a, const HloComputation* b,
    absl::flat_hash_set<const HloComputation*>* visited) const {
  if (a == b || ContainsKey(*visited, b)) {
    // The call graph is guaranteed to be acyclic so any previously visited
    // node we encounter was already determined to be dominated.
    return true;
  }

  const CallGraphNode& b_node = GetNode(b);
  if (b_node.callers().empty()) {
    // We reached a root node without hitting 'a'. 'a' does not dominate 'b'.
    return false;
  }

  // Walk up the callers of 'b' until we hit 'a' or a root node (no callers).
  visited->insert(b);
  for (const HloComputation* b_caller : b_node.callers()) {
    if (!DominatesHelper(a, b_caller, visited)) {
      return false;
    }
  }
  return true;
}

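// Worked example (computation names are hypothetical): if the entry
// computation calls computation A, and A is the only computation that calls
// computation B, then every path from a call graph root to B passes through
// A, so A dominates B; every computation also trivially dominates itself.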
bool CallGraph::Dominates(const HloComputation* a,
                          const HloComputation* b) const {
  absl::flat_hash_set<const HloComputation*> visited;
  return DominatesHelper(a, b, &visited);
}

namespace {

// Returns the call context of a computation which is called from contexts 'a'
// and 'b'.
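// For example, UnionContexts(CallContext::kSequential, CallContext::kParallel)
// is CallContext::kBoth, and UnionContexts(CallContext::kNone, x) is x.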
CallContext UnionContexts(CallContext a, CallContext b) {
  if (a == CallContext::kNone) {
    return b;
  } else if (b == CallContext::kNone) {
    return a;
  } else if (a == b) {
    return a;
  } else {
    // Contexts are different and neither is kNone, i.e., one is kSequential
    // and the other is kParallel.
    return CallContext::kBoth;
  }
}

}  // namespace

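// Sketch of the propagation below, using hypothetical computation names: if
// the entry computation calls `body` via a kWhile (a kSequential callsite) and
// `body` calls `mapped` via a kMap (a kParallel callsite), then `entry` and
// `body` end up with context kSequential and `mapped` with kParallel. Had
// `mapped` also been called via kCall from a computation whose context is
// kSequential, its context would become kBoth.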
void CallGraph::SetCallContexts() {
  std::queue<CallGraphNode*> worklist;

  // Initialize worklist with all roots of the call graph (computations without
  // callers).
  for (const HloComputation* computation : module_->computations()) {
    CallGraphNode& node = GetNode(computation);
    if (node.callers().empty()) {
      node.set_context(CallContext::kSequential);
      worklist.push(&node);
    }
  }

  while (!worklist.empty()) {
    CallGraphNode* node = worklist.front();
    worklist.pop();

    for (const CallSite& callsite : node->callsites()) {
      for (const HloComputation* callee : callsite.called_computations()) {
        CallGraphNode& callee_node = GetNode(callee);

        // Update the context of the callee computation based on the callsite
        // and its current context.
        CallContext context_to_add;
        if (callsite.context() == CallContext::kParallel) {
          context_to_add = CallContext::kParallel;
        } else {
          CHECK_EQ(callsite.context(), CallContext::kSequential);
          context_to_add = node->context();
        }
        CallContext new_context =
            UnionContexts(context_to_add, callee_node.context());

        if (new_context != callee_node.context()) {
          // The context of the callee computation has changed, so add its node
          // to the worklist.
          callee_node.set_context(new_context);
          worklist.push(&callee_node);
        }
      }
    }
  }

  // No node should have a kNone calling context.
  for (const HloComputation* computation : module_->computations()) {
    CHECK_NE(GetNode(computation).context(), CallContext::kNone);
  }
}

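// Sketch of the depth assignment below: roots (computations with no callers)
// get depth 0, and a callee ends up with the length of the longest caller
// chain above it. For a hypothetical module where `entry` calls `body` and
// `body` calls `mapped`, the depths are 0, 1 and 2 respectively.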
void CallGraph::SetNodeDepths() {
  std::queue<CallGraphNode*> worklist;

  // Initialize node depths to -1.
  for (CallGraphNode& node : nodes_) {
    node.set_depth(-1);
  }

  // Initialize worklist with all roots of the call graph (computations without
  // callers).
  for (const HloComputation* computation : module_->computations()) {
    CallGraphNode& node = GetNode(computation);
    if (node.callers().empty()) {
      node.set_depth(0);
      worklist.push(&node);
    }
  }

  while (!worklist.empty()) {
    CallGraphNode* node = worklist.front();
    worklist.pop();
    for (const HloComputation* callee : node->callees()) {
      CallGraphNode& callee_node = GetNode(callee);
      if (callee_node.depth() < node->depth() + 1) {
        callee_node.set_depth(node->depth() + 1);
        worklist.push(&callee_node);
      }
    }
  }

  for (CallGraphNode& node : nodes_) {
    CHECK_NE(node.depth(), -1);
  }
}

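// Typical usage sketch (assuming an HloModule* `module` is in scope):
//   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
//   const CallGraphNode& entry_node =
//       call_graph->GetNode(module->entry_computation());
//   for (const CallSite& callsite : entry_node.callsites()) {
//     VLOG(1) << callsite.ToString();
//   }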
/* static */
std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
  // Constructor for CallGraph is private so absl::make_unique can't be used.
  auto call_graph = absl::WrapUnique<CallGraph>(new CallGraph(module));

  VLOG(3) << "Building call graph for:";
  XLA_VLOG_LINES(3, module->ToString());

  // Construct nodes of the call graph and populate the callsites.
  for (HloComputation* computation : module->computations()) {
    auto it_added = call_graph->node_indices_.insert(
        {computation, call_graph->nodes_.size()});
    // All computations should be unique, so the computation should not already
    // exist in the map.
    CHECK(it_added.second);
    call_graph->nodes_.emplace_back(computation);

    // Add all callsites in this computation.
    for (HloInstruction* instruction : computation->instructions()) {
      call_graph->nodes_.back().AddCallSiteForInstruction(instruction);
    }
  }

  // Add caller callsites to each node.
  for (const HloComputation* computation : module->computations()) {
    for (const CallSite& callsite :
         call_graph->GetNode(computation).callsites()) {
      for (auto* callee : callsite.called_computations()) {
        // Add caller callsites.
        call_graph->GetNode(callee).AddCallerCallSite(callsite);
      }
    }
  }

  call_graph->SetCallContexts();
  call_graph->SetNodeDepths();

  XLA_VLOG_LINES(2, call_graph->ToString());

  return call_graph;
}

Status CallGraph::VisitNodesInternal(
    const VisitorFunction& visitor_func, const CallGraphNode& node,
    absl::flat_hash_set<const CallGraphNode*>* visited) const {
  auto pair = visited->insert(&node);
  if (!pair.second) {
    // Node was not inserted. Node has already been visited.
    return Status::OK();
  }

  for (const HloComputation* computation : node.callees()) {
    TF_RETURN_IF_ERROR(
        VisitNodesInternal(visitor_func, GetNode(computation), visited));
  }

  return visitor_func(node);
}

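// Note: because VisitNodesInternal recurses into callees before applying
// visitor_func, VisitNodes applies visitor_func in post order, i.e. callees
// are visited before their callers.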
Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
                             bool visit_unreachable_nodes) const {
  absl::flat_hash_set<const CallGraphNode*> visited;
  if (visit_unreachable_nodes) {
    // Traverse from all roots in the call graph.
    for (const CallGraphNode& node : nodes()) {
      if (node.callers().empty()) {
        TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited));
      }
    }
  } else {
    // Traverse only from the entry computation.
    TF_RETURN_IF_ERROR(VisitNodesInternal(
        visitor_func, GetNode(module_->entry_computation()), &visited));
  }

  return Status::OK();
}

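// For example, a computation with kSequential context that is the target of
// two different kCall instructions has two caller callsites, so the call
// graph is not considered flattened.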
bool CallGraph::IsFlattened() const {
  for (const CallGraphNode& node : nodes_) {
    if (node.context() == CallContext::kBoth) {
      return false;
    }
    if (node.context() == CallContext::kSequential &&
        node.caller_callsites().size() > 1) {
      return false;
    }
  }
  return true;
}

std::vector<HloInstruction*> CallGraph::GetComputationCallers(
    HloComputation* c) {
  std::vector<HloInstruction*> callers;
  for (const auto& callsite : GetNode(c).caller_callsites()) {
    callers.push_back(callsite.instruction());
  }
  return callers;
}

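// Sketch of the ancestor search below: the deeper of `a` and `b` is first
// walked up the callee->caller chain until both candidates sit at the same
// call depth, then both chains are walked upward in lock step until the two
// candidates share a parent computation; {nullptr, nullptr} is returned if a
// computation along the way has zero or more than one caller callsite.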
std::pair<HloInstruction*, HloInstruction*>
CallGraph::NearestAncestorsInSameComputation(HloInstruction* a,
                                             HloInstruction* b) const {
  // Lambda which returns the next instruction in the callee->caller chain in
  // the call graph. This is the unique instruction which calls the computation
  // containing 'instruction'. If more than one instruction calls the
  // computation containing 'instruction', or no instruction calls the
  // computation, then nullptr is returned.
  auto next_caller = [this](HloInstruction* instruction) -> HloInstruction* {
    const CallGraphNode& node = GetNode(instruction->parent());
    if (node.caller_callsites().size() != 1) {
      return nullptr;
    }
    return node.caller_callsites()[0].instruction();
  };

  // Iterate through the callee->caller chains and find the earliest common
  // element.
  HloInstruction* a_ancestor = a;
  HloInstruction* b_ancestor = b;
  int a_depth = GetNode(a->parent()).depth();
  int b_depth = GetNode(b->parent()).depth();

  // Advance a_ancestor (b_ancestor) up the call chain until the call depths of
  // a_ancestor and b_ancestor are the same. Necessarily each call to
  // next_caller reduces the depth by exactly one.
  if (a_depth > b_depth) {
    for (int i = 0; i < a_depth - b_depth; ++i) {
      a_ancestor = next_caller(a_ancestor);
      if (a_ancestor == nullptr) {
        return {nullptr, nullptr};
      }
    }
  } else if (b_depth > a_depth) {
    for (int i = 0; i < b_depth - a_depth; ++i) {
      b_ancestor = next_caller(b_ancestor);
      if (b_ancestor == nullptr) {
        return {nullptr, nullptr};
      }
    }
  }

  while ((a_ancestor != nullptr) && (b_ancestor != nullptr)) {
    if (a_ancestor->parent() == b_ancestor->parent()) {
      return {a_ancestor, b_ancestor};
    }

    a_ancestor = next_caller(a_ancestor);
    b_ancestor = next_caller(b_ancestor);
  }
  return {nullptr, nullptr};
}

string CallGraph::ToString() const {
  string out;
  StrAppendFormat(&out, "Call graph for module %s:\n", module_->name());
  for (const CallGraphNode& node : nodes()) {
    StrAppendFormat(&out, "Computation %s:\n", node.computation()->name());
    StrAppendFormat(&out, "  calls:\n");
    for (const HloComputation* callee : node.callees()) {
      StrAppendFormat(&out, "    %s\n", callee->name());
    }
    StrAppendFormat(&out, "  called by:\n");
    for (const HloComputation* caller : node.callers()) {
      StrAppendFormat(&out, "    %s\n", caller->name());
    }
    StrAppendFormat(&out, "  callsites:\n");
    for (const CallSite& callsite : node.callsites()) {
      StrAppendFormat(&out, "    %s\n", callsite.ToString());
    }
  }
  return out;
}

}  // namespace xla