/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/call_graph.h"

#include <memory>
#include <queue>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace xla {

using absl::StrAppendFormat;
using absl::StrCat;

std::string CallContextToString(CallContext context) {
  switch (context) {
    case CallContext::kNone:
      return "kNone";
    case CallContext::kControlFlow:
      return "kControlFlow";
    case CallContext::kEmbedded:
      return "kEmbedded";
    case CallContext::kBoth:
      return "kBoth";
  }
}

std::ostream& operator<<(std::ostream& out, const CallContext& context) {
  out << CallContextToString(context);
  return out;
}

CallContext GetInstructionCallContext(HloOpcode opcode) {
  switch (opcode) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kWhile:
    case HloOpcode::kAsyncStart:
    case HloOpcode::kAsyncUpdate:
    case HloOpcode::kAsyncDone:
      return CallContext::kControlFlow;
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kMap:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kScatter:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kSort:
    case HloOpcode::kFusion:
    case HloOpcode::kCustomCall:
      return CallContext::kEmbedded;
    default:
      return CallContext::kNone;
  }
}
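
// Illustrative sketch (not part of this file's API surface): opcodes that
// transfer control to a computation report kControlFlow, opcodes that embed a
// computation as a subroutine of a single op report kEmbedded, and all other
// opcodes report kNone:
//
//   GetInstructionCallContext(HloOpcode::kWhile);  // CallContext::kControlFlow
//   GetInstructionCallContext(HloOpcode::kMap);    // CallContext::kEmbedded
//   GetInstructionCallContext(HloOpcode::kAdd);    // CallContext::kNone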

std::string CallSite::ToString() const {
  return StrCat(
      instruction()->name(), " calls in context ",
      CallContextToString(context()), ": ",
      absl::StrJoin(called_computations(), ", ",
                    [](std::string* out, const HloComputation* computation) {
                      out->append(computation->name());
                    }));
}

CallGraphNode::CallGraphNode(HloComputation* computation)
    : computation_(computation) {}

const CallSite* CallGraphNode::GetCallSite(
    const HloInstruction* instruction) const {
  auto it = callsite_instructions_.find(instruction);
  if (it == callsite_instructions_.end()) {
    return nullptr;
  }
  return &callsites_[it->second];
}

std::string CallGraphNode::ToString() const { return computation_->name(); }

void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
  caller_callsites_.push_back(caller_callsite);
  HloComputation* caller = caller_callsite.instruction()->parent();
  if (!ContainsKey(caller_set_, caller)) {
    callers_.push_back(caller);
    caller_set_.insert(caller);
  }
}

void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
  CHECK_EQ(instruction->parent(), computation());
  const CallContext context = GetInstructionCallContext(instruction->opcode());
  if (!instruction->called_computations().empty()) {
    CHECK(context == CallContext::kControlFlow ||
          context == CallContext::kEmbedded);
    callsite_instructions_.insert({instruction, callsites_.size()});
    callsites_.push_back(
        CallSite(instruction, instruction->called_computations(), context));
    // Update callee computations to include any new computations called by
    // this instruction.
    for (auto* callee : callsites_.back().called_computations()) {
      if (!ContainsKey(callee_set_, callee)) {
        callees_.push_back(callee);
        callee_set_.insert(callee);
      }
    }
  }
}

CallGraph::CallGraph(const HloModule* module) : module_(module) {}

const CallGraphNode& CallGraph::GetNode(
    const HloComputation* computation) const {
  auto it = node_indices_.find(computation);
  CHECK(it != node_indices_.end());
  return nodes_[it->second];
}

CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
  auto it = node_indices_.find(computation);
  CHECK(it != node_indices_.end());
  return nodes_[it->second];
}

bool CallGraph::DominatesHelper(
    const HloComputation* a, const HloComputation* b,
    absl::flat_hash_set<const HloComputation*>* visited) const {
  if (a == b || ContainsKey(*visited, b)) {
    // The call graph is guaranteed to be acyclic, so any previously visited
    // node we encounter was already determined to be dominated.
    return true;
  }

  const CallGraphNode& b_node = GetNode(b);
  if (b_node.callers().empty()) {
    // We reached a root node without hitting 'a'. 'a' does not dominate 'b'.
    return false;
  }

  // Walk up the callers of 'b' until we hit 'a' or a root node (no callers).
  visited->insert(b);
  for (const HloComputation* b_caller : b_node.callers()) {
    if (!DominatesHelper(a, b_caller, visited)) {
      return false;
    }
  }
  return true;
}

bool CallGraph::Dominates(const HloComputation* a,
                          const HloComputation* b) const {
  absl::flat_hash_set<const HloComputation*> visited;
  return DominatesHelper(a, b, &visited);
}
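
// Hedged example of the dominance relation computed above, assuming a module
// whose call graph is entry -> {x, y} and x -> z (computation names
// hypothetical):
//
//   call_graph->Dominates(entry, z);  // true: every caller chain from z up
//                                     //       to a root passes through entry.
//   call_graph->Dominates(y, z);      // false: z reaches the root via x
//                                     //        without passing through y.
//   call_graph->Dominates(z, z);      // true: a computation dominates itself.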

namespace {

// Returns the call context of a computation which is called from contexts 'a'
// and 'b'.
CallContext UnionContexts(CallContext a, CallContext b) {
  if (a == CallContext::kNone) {
    return b;
  } else if (b == CallContext::kNone) {
    return a;
  } else if (a == b) {
    return a;
  } else {
    // Contexts are different and neither is kNone, i.e. one is kControlFlow
    // and the other is kEmbedded.
    return CallContext::kBoth;
  }
}
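
// A minimal sketch of the union semantics defined above: kNone is the
// identity, equal contexts are idempotent, and mixing the two non-kNone
// contexts yields kBoth:
//
//   UnionContexts(CallContext::kNone, CallContext::kEmbedded);
//       // -> CallContext::kEmbedded
//   UnionContexts(CallContext::kControlFlow, CallContext::kControlFlow);
//       // -> CallContext::kControlFlow
//   UnionContexts(CallContext::kControlFlow, CallContext::kEmbedded);
//       // -> CallContext::kBoth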

}  // namespace

void CallGraph::SetCallContexts() {
  std::queue<CallGraphNode*> worklist;

  // Initialize worklist with all roots of the call graph (computations without
  // callers).
  for (const HloComputation* computation : module_->computations()) {
    CallGraphNode& node = GetNode(computation);
    if (node.callers().empty()) {
      node.set_context(CallContext::kControlFlow);
      worklist.push(&node);
    }
  }

  while (!worklist.empty()) {
    CallGraphNode* node = worklist.front();
    worklist.pop();

    for (const CallSite& callsite : node->callsites()) {
      for (const HloComputation* callee : callsite.called_computations()) {
        CallGraphNode& callee_node = GetNode(callee);

        // Update context of callee computation based on the callsite and its
        // current context.
        CallContext context_to_add;
        if (callsite.context() == CallContext::kEmbedded) {
          context_to_add = CallContext::kEmbedded;
        } else {
          CHECK_EQ(callsite.context(), CallContext::kControlFlow);
          context_to_add = node->context();
        }
        CallContext new_context =
            UnionContexts(context_to_add, callee_node.context());

        if (new_context != callee_node.context()) {
          // Context of computation has been changed so add node to worklist.
          callee_node.set_context(new_context);
          worklist.push(&callee_node);
        }
      }
    }
  }

  // No node should have a kNone calling context.
  for (const HloComputation* computation : module_->computations()) {
    CHECK_NE(GetNode(computation).context(), CallContext::kNone);
  }
}
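
// Illustrative trace of the fixed-point iteration above (computation names
// hypothetical): suppose computation 'body' is called by a kWhile instruction
// in the entry computation and also by a kMap instruction in the entry
// computation.
//
//   entry is a root, so it is seeded with kControlFlow.
//   kWhile callsite: context_to_add = entry's context = kControlFlow,
//       so body.context() becomes kControlFlow and body is queued.
//   kMap callsite: context_to_add = kEmbedded, so
//       UnionContexts(kEmbedded, kControlFlow) == kBoth; body's context
//       changes again and it is re-queued.
//
// Iteration stops once no node's context changes, i.e. at a fixed point.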

void CallGraph::SetNodeDepths() {
  std::queue<CallGraphNode*> worklist;

  // Initialize node depths to -1.
  for (CallGraphNode& node : nodes_) {
    node.set_depth(-1);
  }

  // Initialize worklist with all roots of the call graph (computations without
  // callers).
  for (const HloComputation* computation : module_->computations()) {
    CallGraphNode& node = GetNode(computation);
    if (node.callers().empty()) {
      node.set_depth(0);
      worklist.push(&node);
    }
  }

  while (!worklist.empty()) {
    CallGraphNode* node = worklist.front();
    worklist.pop();
    for (const HloComputation* callee : node->callees()) {
      CallGraphNode& callee_node = GetNode(callee);
      if (callee_node.depth() < node->depth() + 1) {
        callee_node.set_depth(node->depth() + 1);
        worklist.push(&callee_node);
      }
    }
  }

  for (CallGraphNode& node : nodes_) {
    CHECK_NE(node.depth(), -1);
  }
}
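
// A small depth example under the rule above (call graph hypothetical): with
// edges entry -> a, entry -> b, and a -> b, roots get depth 0 and each node's
// final depth is the length of the longest caller chain above it:
//
//   entry.depth() == 0
//   a.depth()     == 1
//   b.depth()     == 2  // max over callers: max(0, 1) + 1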

/* static */
std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
  // Constructor for CallGraph is private so std::make_unique can't be used.
  auto call_graph = absl::WrapUnique<CallGraph>(new CallGraph(module));

  VLOG(3) << "Building call graph for:";
  XLA_VLOG_LINES(3, module->ToString());

  // Construct nodes of the call graph and populate the callsites.
  for (HloComputation* computation : module->computations()) {
    auto it_added = call_graph->node_indices_.insert(
        {computation, call_graph->nodes_.size()});
    // All computations should be unique, so the computation should not already
    // exist in the map.
    CHECK(it_added.second);
    call_graph->nodes_.emplace_back(computation);

    // Add all callsites in this computation.
    for (HloInstruction* instruction : computation->instructions()) {
      call_graph->nodes_.back().AddCallSiteForInstruction(instruction);
    }
  }

  // Add caller callsites to each node.
  for (const HloComputation* computation : module->computations()) {
    for (const CallSite& callsite :
         call_graph->GetNode(computation).callsites()) {
      for (auto* callee : callsite.called_computations()) {
        // Add caller callsites.
        call_graph->GetNode(callee).AddCallerCallSite(callsite);
      }
    }
  }

  call_graph->SetCallContexts();
  call_graph->SetNodeDepths();

  XLA_VLOG_LINES(2, call_graph->ToString());

  return call_graph;
}
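
// Typical usage, sketched (assumes an existing HloModule* module):
//
//   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
//   const CallGraphNode& entry_node =
//       call_graph->GetNode(module->entry_computation());
//   for (const CallSite& callsite : entry_node.callsites()) {
//     VLOG(1) << callsite.ToString();
//   }
//
// Note the graph is a snapshot of the module at build time; if the module is
// mutated afterwards, the call graph should be rebuilt.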

Status CallGraph::VisitNodesInternal(
    const VisitorFunction& visitor_func, const CallGraphNode& node,
    absl::flat_hash_set<const CallGraphNode*>* visited) const {
  auto pair = visited->insert(&node);
  if (!pair.second) {
    // Node was not inserted. Node has already been visited.
    return OkStatus();
  }

  for (const HloComputation* computation : node.callees()) {
    TF_RETURN_IF_ERROR(
        VisitNodesInternal(visitor_func, GetNode(computation), visited));
  }

  return visitor_func(node);
}

Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
                             bool visit_unreachable_nodes) const {
  absl::flat_hash_set<const CallGraphNode*> visited;
  if (visit_unreachable_nodes) {
    // Traverse from all roots in the call graph.
    for (const CallGraphNode& node : nodes()) {
      if (node.callers().empty()) {
        TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited));
      }
    }
  } else {
    // Traverse only from the entry computation.
    TF_RETURN_IF_ERROR(VisitNodesInternal(
        visitor_func, GetNode(module_->entry_computation()), &visited));
  }

  return OkStatus();
}
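
// Sketch of a post-order traversal with VisitNodes (the visitor body is
// hypothetical): callees are visited before their callers, so the entry
// computation's node is visited last along any reachable chain.
//
//   TF_RETURN_IF_ERROR(call_graph->VisitNodes(
//       [](const CallGraphNode& node) -> Status {
//         VLOG(1) << "Visiting: " << node.ToString();
//         return OkStatus();
//       },
//       /*visit_unreachable_nodes=*/true));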

bool CallGraph::IsFlattened() const {
  for (const CallGraphNode& node : nodes_) {
    if (node.context() == CallContext::kBoth) {
      return false;
    }
    if (node.context() == CallContext::kControlFlow &&
        !node.computation()->IsAsyncComputation() &&
        node.caller_callsites().size() > 1) {
      return false;
    }
  }
  return true;
}
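
// In other words, a module is "flattened" when no computation is reachable in
// both contexts and every non-async control-flow computation has at most one
// caller callsite. For example (hypothetical), if two separate kWhile
// instructions share the same body computation, IsFlattened() returns false;
// after duplicating the body so each kWhile calls its own copy, it returns
// true.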

std::vector<HloInstruction*> CallGraph::GetComputationCallers(
    HloComputation* c) const {
  std::vector<HloInstruction*> callers;
  for (const auto& callsite : GetNode(c).caller_callsites()) {
    callers.push_back(callsite.instruction());
  }
  return callers;
}

std::pair<HloInstruction*, HloInstruction*>
CallGraph::NearestAncestorsInSameComputation(HloInstruction* a,
                                             HloInstruction* b) const {
  // Lambda which returns the next instruction in the callee->caller chain in
  // the call graph. This is the unique instruction which calls the computation
  // containing 'instruction'. If more than one instruction calls the
  // computation containing 'instruction', or no instructions call the
  // computation, then nullptr is returned.
  auto next_caller = [this](HloInstruction* instruction) -> HloInstruction* {
    const CallGraphNode& node = GetNode(instruction->parent());
    if (node.caller_callsites().size() != 1) {
      return nullptr;
    }
    return node.caller_callsites()[0].instruction();
  };

  // Iterate through the callee->caller chains and find the earliest common
  // element.
  HloInstruction* a_ancestor = a;
  HloInstruction* b_ancestor = b;
  int a_depth = GetNode(a->parent()).depth();
  int b_depth = GetNode(b->parent()).depth();

  // Advance a_ancestor (b_ancestor) up the call chain until the call depths
  // of a_ancestor and b_ancestor are the same. Each call to next_caller
  // necessarily reduces the depth by exactly one.
  if (a_depth > b_depth) {
    for (int i = 0; i < a_depth - b_depth; ++i) {
      a_ancestor = next_caller(a_ancestor);
      if (a_ancestor == nullptr) {
        return {nullptr, nullptr};
      }
    }
  } else if (b_depth > a_depth) {
    for (int i = 0; i < b_depth - a_depth; ++i) {
      b_ancestor = next_caller(b_ancestor);
      if (b_ancestor == nullptr) {
        return {nullptr, nullptr};
      }
    }
  }

  while ((a_ancestor != nullptr) && (b_ancestor != nullptr)) {
    if (a_ancestor->parent() == b_ancestor->parent()) {
      return {a_ancestor, b_ancestor};
    }

    a_ancestor = next_caller(a_ancestor);
    b_ancestor = next_caller(b_ancestor);
  }
  return {nullptr, nullptr};
}
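
// Hedged example of the ancestor walk above, assuming the entry computation
// contains a kWhile instruction 'w' whose body contains instruction 'i', and
// entry also contains instruction 'j' (all names hypothetical):
//
//   call_graph->NearestAncestorsInSameComputation(i, j);  // returns {w, j}
//       // 'i' sits one level deeper than 'j', so it is walked up to its
//       // unique caller 'w', which lives in the same computation as 'j'.
//
// If any step of the walk is ambiguous (a computation with zero or multiple
// callers), {nullptr, nullptr} is returned.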

std::string CallGraph::ToString() const {
  std::string out;
  StrAppendFormat(&out, "Call graph for module %s:\n", module_->name());
  for (const CallGraphNode& node : nodes()) {
    StrAppendFormat(&out, "Computation %s:\n", node.computation()->name());
    StrAppendFormat(&out, "  calls:\n");
    for (const HloComputation* callee : node.callees()) {
      StrAppendFormat(&out, "    %s\n", callee->name());
    }
    StrAppendFormat(&out, "  called by:\n");
    for (const HloComputation* caller : node.callers()) {
      StrAppendFormat(&out, "    %s\n", caller->name());
    }
    StrAppendFormat(&out, "  callsites:\n");
    for (const CallSite& callsite : node.callsites()) {
      StrAppendFormat(&out, "    %s\n", callsite.ToString());
    }
  }
  return out;
}

}  // namespace xla