/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/call_graph.h"

#include <queue>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

using absl::StrAppendFormat;
using absl::StrCat;

string CallContextToString(CallContext context) {
  switch (context) {
    case CallContext::kNone:
      return "kNone";
    case CallContext::kSequential:
      return "kSequential";
    case CallContext::kParallel:
      return "kParallel";
    case CallContext::kBoth:
      return "kBoth";
  }
}

std::ostream& operator<<(std::ostream& out, const CallContext& context) {
  out << CallContextToString(context);
  return out;
}

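// Maps an opcode to the context in which it calls its computations: kCall,
// kConditional, and kWhile call their computations sequentially on the full
// operand shapes, while opcodes such as kMap, kReduce, and kFusion are treated
// as calling their computations in a parallel context. Opcodes that call no
// computations map to kNone.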
CallContext GetInstructionCallContext(HloOpcode opcode) {
  switch (opcode) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kWhile:
      return CallContext::kSequential;
    case HloOpcode::kAllReduce:
    case HloOpcode::kMap:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kScatter:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kSort:
    case HloOpcode::kFusion:
      return CallContext::kParallel;
    default:
      return CallContext::kNone;
  }
}

string CallSite::ToString() const {
  return StrCat(
      instruction()->name(), " calls in context ",
      CallContextToString(context()), ": ",
      absl::StrJoin(called_computations(), ", ",
                    [](string* out, const HloComputation* computation) {
                      out->append(computation->name());
                    }));
}

CallGraphNode::CallGraphNode(HloComputation* computation)
    : computation_(computation) {}

const CallSite* CallGraphNode::GetCallSite(
    const HloInstruction* instruction) const {
  auto it = callsite_instructions_.find(instruction);
  if (it == callsite_instructions_.end()) {
    return nullptr;
  }
  return &callsites_[it->second];
}

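// Registers 'caller_callsite' (a callsite in another computation which calls
// this node's computation) and records the calling computation, deduplicating
// callers via caller_set_.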
void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
  caller_callsites_.push_back(caller_callsite);
  HloComputation* caller = caller_callsite.instruction()->parent();
  if (!ContainsKey(caller_set_, caller)) {
    callers_.push_back(caller);
    caller_set_.insert(caller);
  }
}

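// Examines an instruction in this node's computation and, if the instruction
// calls any computations, records a callsite for it and adds any newly seen
// callee computations.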
void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
  CHECK_EQ(instruction->parent(), computation());
  const CallContext context = GetInstructionCallContext(instruction->opcode());
  if (!instruction->called_computations().empty()) {
    CHECK(context == CallContext::kSequential ||
          context == CallContext::kParallel);
    callsite_instructions_.insert({instruction, callsites_.size()});
    callsites_.push_back(
        CallSite(instruction, instruction->called_computations(), context));
    // Update callee computations to include any new computations called by
    // this instruction.
    for (auto* callee : callsites_.back().called_computations()) {
      if (!ContainsKey(callee_set_, callee)) {
        callees_.push_back(callee);
        callee_set_.insert(callee);
      }
    }
  }
}

CallGraph::CallGraph(const HloModule* module) : module_(module) {}

const CallGraphNode& CallGraph::GetNode(
    const HloComputation* computation) const {
  auto it = node_indices_.find(computation);
  CHECK(it != node_indices_.end());
  return nodes_[it->second];
}

CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
  auto it = node_indices_.find(computation);
  CHECK(it != node_indices_.end());
  return nodes_[it->second];
}

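// Recursive helper for Dominates. Returns true if every caller chain from 'b'
// up to a root of the call graph passes through 'a'. 'visited' holds
// computations already shown to be dominated, which bounds the traversal since
// the call graph is acyclic.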
bool CallGraph::DominatesHelper(
    const HloComputation* a, const HloComputation* b,
    absl::flat_hash_set<const HloComputation*>* visited) const {
  if (a == b || ContainsKey(*visited, b)) {
    // The call graph is guaranteed to be acyclic, so any previously visited
    // node we encounter was already determined to be dominated.
    return true;
  }

  const CallGraphNode& b_node = GetNode(b);
  if (b_node.callers().empty()) {
    // We reached a root node without hitting 'a'. 'a' does not dominate 'b'.
    return false;
  }

  // Walk up the callers of 'b' until we hit 'a' or a root node (no callers).
  visited->insert(b);
  for (const HloComputation* b_caller : b_node.callers()) {
    if (!DominatesHelper(a, b_caller, visited)) {
      return false;
    }
  }
  return true;
}

bool CallGraph::Dominates(const HloComputation* a,
                          const HloComputation* b) const {
  absl::flat_hash_set<const HloComputation*> visited;
  return DominatesHelper(a, b, &visited);
}

namespace {

// Returns the call context of a computation which is called from contexts 'a'
// and 'b'.
CallContext UnionContexts(CallContext a, CallContext b) {
  if (a == CallContext::kNone) {
    return b;
  } else if (b == CallContext::kNone) {
    return a;
  } else if (a == b) {
    return a;
  } else {
    // Contexts are different and neither is kNone, i.e. one is kSequential and
    // the other is kParallel.
    return CallContext::kBoth;
  }
}

}  // namespace

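// Propagates calling contexts from the roots of the call graph down to every
// node using a breadth-first worklist. A callee reached through a parallel
// callsite is marked kParallel regardless of the caller's context; a callee
// reached through a sequential callsite inherits the caller's context. A node
// reachable in both contexts ends up as kBoth.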
void CallGraph::SetCallContexts() {
  std::queue<CallGraphNode*> worklist;

  // Initialize worklist with all roots of the call graph (computations without
  // callers).
  for (const HloComputation* computation : module_->computations()) {
    CallGraphNode& node = GetNode(computation);
    if (node.callers().empty()) {
      node.set_context(CallContext::kSequential);
      worklist.push(&node);
    }
  }

  while (!worklist.empty()) {
    CallGraphNode* node = worklist.front();
    worklist.pop();

    for (const CallSite& callsite : node->callsites()) {
      for (const HloComputation* callee : callsite.called_computations()) {
        CallGraphNode& callee_node = GetNode(callee);

        // Update context of callee computation based on the callsite and its
        // current context.
        CallContext context_to_add;
        if (callsite.context() == CallContext::kParallel) {
          context_to_add = CallContext::kParallel;
        } else {
          CHECK_EQ(callsite.context(), CallContext::kSequential);
          context_to_add = node->context();
        }
        CallContext new_context =
            UnionContexts(context_to_add, callee_node.context());

        if (new_context != callee_node.context()) {
          // Context of computation has been changed so add node to worklist.
          callee_node.set_context(new_context);
          worklist.push(&callee_node);
        }
      }
    }
  }

  // No node should have a kNone calling context.
  for (const HloComputation* computation : module_->computations()) {
    CHECK_NE(GetNode(computation).context(), CallContext::kNone);
  }
}

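// Computes the depth of each node: roots (computations with no callers) have
// depth 0, and each callee's depth is relaxed upward to the length of the
// longest caller chain reaching it from a root.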
void CallGraph::SetNodeDepths() {
  std::queue<CallGraphNode*> worklist;

  // Initialize node depths to -1.
  for (CallGraphNode& node : nodes_) {
    node.set_depth(-1);
  }

  // Initialize worklist with all roots of the call graph (computations without
  // callers).
  for (const HloComputation* computation : module_->computations()) {
    CallGraphNode& node = GetNode(computation);
    if (node.callers().empty()) {
      node.set_depth(0);
      worklist.push(&node);
    }
  }

  while (!worklist.empty()) {
    CallGraphNode* node = worklist.front();
    worklist.pop();
    for (const HloComputation* callee : node->callees()) {
      CallGraphNode& callee_node = GetNode(callee);
      if (callee_node.depth() < node->depth() + 1) {
        callee_node.set_depth(node->depth() + 1);
        worklist.push(&callee_node);
      }
    }
  }

  for (CallGraphNode& node : nodes_) {
    CHECK_NE(node.depth(), -1);
  }
}

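// Builds the call graph in two passes over the module's computations: the
// first pass creates a node per computation and records its callsites, and the
// second pass back-fills caller callsites on the callee nodes. Call contexts
// and node depths are then computed over the completed graph.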
/* static */
std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
  // Constructor for CallGraph is private so absl::make_unique can't be used.
  auto call_graph = absl::WrapUnique<CallGraph>(new CallGraph(module));

  VLOG(2) << "Building call graph for:";
  XLA_VLOG_LINES(2, module->ToString());

  // Construct nodes of the call graph and populate the callsites.
  for (HloComputation* computation : module->computations()) {
    auto it_added = call_graph->node_indices_.insert(
        {computation, call_graph->nodes_.size()});
    // All computations should be unique, so the computation should not already
    // exist in the map.
    CHECK(it_added.second);
    call_graph->nodes_.emplace_back(computation);

    // Add all callsites in this computation.
    for (HloInstruction* instruction : computation->instructions()) {
      call_graph->nodes_.back().AddCallSiteForInstruction(instruction);
    }
  }

  // Add caller callsites to each node.
  for (const HloComputation* computation : module->computations()) {
    for (const CallSite& callsite :
         call_graph->GetNode(computation).callsites()) {
      for (auto* callee : callsite.called_computations()) {
        // Add caller callsites.
        call_graph->GetNode(callee).AddCallerCallSite(callsite);
      }
    }
  }

  call_graph->SetCallContexts();
  call_graph->SetNodeDepths();

  XLA_VLOG_LINES(1, call_graph->ToString());

  return call_graph;
}

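// Depth-first post-order traversal helper: all of a node's callees are visited
// before visitor_func is applied to the node itself. 'visited' ensures each
// node is visited at most once.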
Status CallGraph::VisitNodesInternal(
    const VisitorFunction& visitor_func, const CallGraphNode& node,
    absl::flat_hash_set<const CallGraphNode*>* visited) const {
  auto pair = visited->insert(&node);
  if (!pair.second) {
    // Node was not inserted. Node has already been visited.
    return Status::OK();
  }

  for (const HloComputation* computation : node.callees()) {
    TF_RETURN_IF_ERROR(
        VisitNodesInternal(visitor_func, GetNode(computation), visited));
  }

  return visitor_func(node);
}

Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
                             bool visit_unreachable_nodes) const {
  absl::flat_hash_set<const CallGraphNode*> visited;
  if (visit_unreachable_nodes) {
    // Traverse from all roots in the call graph.
    for (const CallGraphNode& node : nodes()) {
      if (node.callers().empty()) {
        TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited));
      }
    }
  } else {
    // Traverse only from the entry computation.
    TF_RETURN_IF_ERROR(VisitNodesInternal(
        visitor_func, GetNode(module_->entry_computation()), &visited));
  }

  return Status::OK();
}

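// The call graph is flattened if no computation is called in both a sequential
// and a parallel context (kBoth), and no computation called in a sequential
// context has more than one caller callsite.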
bool CallGraph::IsFlattened() const {
  for (const CallGraphNode& node : nodes_) {
    if (node.context() == CallContext::kBoth) {
      return false;
    }
    if (node.context() == CallContext::kSequential &&
        node.caller_callsites().size() > 1) {
      return false;
    }
  }
  return true;
}

std::vector<HloInstruction*> CallGraph::GetComputationCallers(
    HloComputation* c) {
  std::vector<HloInstruction*> callers;
  for (auto callsite : GetNode(c).caller_callsites()) {
    callers.push_back(callsite.instruction());
  }
  return callers;
}

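// Walks 'a' and 'b' up their callee->caller chains, first equalizing call
// depths, until the two ancestors lie in the same computation. Returns
// {nullptr, nullptr} if a chain becomes ambiguous (a computation with zero or
// more than one caller) before a common computation is found.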
std::pair<HloInstruction*, HloInstruction*>
CallGraph::NearestAncestorsInSameComputation(HloInstruction* a,
                                             HloInstruction* b) const {
  // Lambda which returns the next instruction in the callee->caller chain in
  // the call graph. This is the unique instruction which calls the computation
  // containing 'instruction'. If more than one instruction calls the
  // computation containing 'instruction' or no instructions call the
  // computation then nullptr is returned.
  auto next_caller = [this](HloInstruction* instruction) -> HloInstruction* {
    const CallGraphNode& node = GetNode(instruction->parent());
    if (node.caller_callsites().size() != 1) {
      return nullptr;
    }
    return node.caller_callsites()[0].instruction();
  };

  // Iterate through the callee->caller chains and find the earliest common
  // element.
  HloInstruction* a_ancestor = a;
  HloInstruction* b_ancestor = b;
  int a_depth = GetNode(a->parent()).depth();
  int b_depth = GetNode(b->parent()).depth();

  // Advance a_ancestor (b_ancestor) up the call chain until the call depth of
  // a_ancestor or b_ancestor are the same. Necessarily each call to
  // next_caller reduces the depth by exactly one.
  if (a_depth > b_depth) {
    for (int i = 0; i < a_depth - b_depth; ++i) {
      a_ancestor = next_caller(a_ancestor);
      if (a_ancestor == nullptr) {
        return {nullptr, nullptr};
      }
    }
  } else if (b_depth > a_depth) {
    for (int i = 0; i < b_depth - a_depth; ++i) {
      b_ancestor = next_caller(b_ancestor);
      if (b_ancestor == nullptr) {
        return {nullptr, nullptr};
      }
    }
  }

  while ((a_ancestor != nullptr) && (b_ancestor != nullptr)) {
    if (a_ancestor->parent() == b_ancestor->parent()) {
      return {a_ancestor, b_ancestor};
    }

    a_ancestor = next_caller(a_ancestor);
    b_ancestor = next_caller(b_ancestor);
  }
  return {nullptr, nullptr};
}

string CallGraph::ToString() const {
  string out;
  StrAppendFormat(&out, "Call graph for module %s:\n", module_->name());
  for (const CallGraphNode& node : nodes()) {
    StrAppendFormat(&out, "Computation %s:\n", node.computation()->name());
    StrAppendFormat(&out, "  calls:\n");
    for (const HloComputation* callee : node.callees()) {
      StrAppendFormat(&out, "    %s\n", callee->name());
    }
    StrAppendFormat(&out, "  called by:\n");
    for (const HloComputation* caller : node.callers()) {
      StrAppendFormat(&out, "    %s\n", caller->name());
    }
    StrAppendFormat(&out, "  callsites:\n");
    for (const CallSite& callsite : node.callsites()) {
      StrAppendFormat(&out, "    %s\n", callsite.ToString());
    }
  }
  return out;
}

}  // namespace xla