/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/call_graph.h"

#include <queue>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

using absl::StrAppendFormat;
using absl::StrCat;

string CallContextToString(CallContext context) {
  switch (context) {
    case CallContext::kNone:
      return "kNone";
    case CallContext::kSequential:
      return "kSequential";
    case CallContext::kParallel:
      return "kParallel";
    case CallContext::kBoth:
      return "kBoth";
  }
}

std::ostream& operator<<(std::ostream& out, const CallContext& context) {
  out << CallContextToString(context);
  return out;
}

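// Maps an opcode to the context in which it calls its computations:
// control-flow style calls (kCall, kConditional, kWhile) run their callee as
// part of the caller's sequential execution, while embedded calls (kMap,
// kReduce, kFusion, etc.) may apply their callee many times over the operand
// data. Opcodes that call no computations map to kNone.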
CallContext GetInstructionCallContext(HloOpcode opcode) {
  switch (opcode) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kWhile:
      return CallContext::kSequential;
    case HloOpcode::kAllReduce:
    case HloOpcode::kMap:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kScatter:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kSort:
    case HloOpcode::kFusion:
    case HloOpcode::kCustomCall:
      return CallContext::kParallel;
    default:
      return CallContext::kNone;
  }
}

string CallSite::ToString() const {
  return StrCat(
      instruction()->name(), " calls in context ",
      CallContextToString(context()), ": ",
      absl::StrJoin(called_computations(), ", ",
                    [](string* out, const HloComputation* computation) {
                      out->append(computation->name());
                    }));
}

CallGraphNode::CallGraphNode(HloComputation* computation)
    : computation_(computation) {}

const CallSite* CallGraphNode::GetCallSite(
    const HloInstruction* instruction) const {
  auto it = callsite_instructions_.find(instruction);
  if (it == callsite_instructions_.end()) {
    return nullptr;
  }
  return &callsites_[it->second];
}

std::string CallGraphNode::ToString() const { return computation_->name(); }

void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
  caller_callsites_.push_back(caller_callsite);
  HloComputation* caller = caller_callsite.instruction()->parent();
  if (!ContainsKey(caller_set_, caller)) {
    callers_.push_back(caller);
    caller_set_.insert(caller);
  }
}

void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
  CHECK_EQ(instruction->parent(), computation());
  const CallContext context = GetInstructionCallContext(instruction->opcode());
  if (!instruction->called_computations().empty()) {
    CHECK(context == CallContext::kSequential ||
          context == CallContext::kParallel);
    callsite_instructions_.insert({instruction, callsites_.size()});
    callsites_.push_back(
        CallSite(instruction, instruction->called_computations(), context));
    // Update callee computations to include any new computations called by
    // this instruction.
    for (auto* callee : callsites_.back().called_computations()) {
      if (!ContainsKey(callee_set_, callee)) {
        callees_.push_back(callee);
        callee_set_.insert(callee);
      }
    }
  }
}

CallGraph::CallGraph(const HloModule* module) : module_(module) {}

const CallGraphNode& CallGraph::GetNode(
    const HloComputation* computation) const {
  auto it = node_indices_.find(computation);
  CHECK(it != node_indices_.end());
  return nodes_[it->second];
}

CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
  auto it = node_indices_.find(computation);
  CHECK(it != node_indices_.end());
  return nodes_[it->second];
}

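// Recursive helper for Dominates. Returns true if every call chain from a
// root of the call graph (a computation with no callers) down to 'b' passes
// through 'a'; equivalently, walking up from 'b' through its callers always
// reaches 'a' before reaching a callerless computation.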
bool CallGraph::DominatesHelper(
    const HloComputation* a, const HloComputation* b,
    absl::flat_hash_set<const HloComputation*>* visited) const {
  if (a == b || ContainsKey(*visited, b)) {
    // The call graph is guaranteed to be acyclic so any previously visited
    // node we encounter was already determined to be dominated.
    return true;
  }

  const CallGraphNode& b_node = GetNode(b);
  if (b_node.callers().empty()) {
    // We reached a root node without hitting 'a'. 'a' does not dominate 'b'.
    return false;
  }

  // Walk up the callers of 'b' until we hit 'a' or a root node (no callers).
  visited->insert(b);
  for (const HloComputation* b_caller : b_node.callers()) {
    if (!DominatesHelper(a, b_caller, visited)) {
      return false;
    }
  }
  return true;
}

bool CallGraph::Dominates(const HloComputation* a,
                          const HloComputation* b) const {
  absl::flat_hash_set<const HloComputation*> visited;
  return DominatesHelper(a, b, &visited);
}

namespace {

// Returns the call context of a computation which is called from contexts 'a'
// and 'b'.
CallContext UnionContexts(CallContext a, CallContext b) {
  if (a == CallContext::kNone) {
    return b;
  } else if (b == CallContext::kNone) {
    return a;
  } else if (a == b) {
    return a;
  } else {
    // Contexts are different and neither is kNone, i.e., one is kSequential
    // and the other is kParallel.
    return CallContext::kBoth;
  }
}

}  // namespace

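// Propagates calling contexts from the roots of the call graph to every node
// with a breadth-first worklist: a computation reached through a parallel
// callsite is marked kParallel, a computation reached through a sequential
// callsite inherits its caller's context, and a computation reached both ways
// ends up as kBoth.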
void CallGraph::SetCallContexts() {
  std::queue<CallGraphNode*> worklist;

  // Initialize worklist with all roots of the call graph (computations without
  // callers).
  for (const HloComputation* computation : module_->computations()) {
    CallGraphNode& node = GetNode(computation);
    if (node.callers().empty()) {
      node.set_context(CallContext::kSequential);
      worklist.push(&node);
    }
  }

  while (!worklist.empty()) {
    CallGraphNode* node = worklist.front();
    worklist.pop();

    for (const CallSite& callsite : node->callsites()) {
      for (const HloComputation* callee : callsite.called_computations()) {
        CallGraphNode& callee_node = GetNode(callee);

        // Update context of callee computation based on the callsite and its
        // current context.
        CallContext context_to_add;
        if (callsite.context() == CallContext::kParallel) {
          context_to_add = CallContext::kParallel;
        } else {
          CHECK_EQ(callsite.context(), CallContext::kSequential);
          context_to_add = node->context();
        }
        CallContext new_context =
            UnionContexts(context_to_add, callee_node.context());

        if (new_context != callee_node.context()) {
          // Context of computation has been changed so add node to worklist.
          callee_node.set_context(new_context);
          worklist.push(&callee_node);
        }
      }
    }
  }

  // No node should have a kNone calling context.
  for (const HloComputation* computation : module_->computations()) {
    CHECK_NE(GetNode(computation).context(), CallContext::kNone);
  }
}

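// Assigns each node its depth: the length of the longest call chain from a
// root of the call graph (a callerless computation) down to the node. Roots
// have depth zero, and a callee's depth is raised whenever a deeper caller is
// found.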
void CallGraph::SetNodeDepths() {
  std::queue<CallGraphNode*> worklist;

  // Initialize node depths to -1.
  for (CallGraphNode& node : nodes_) {
    node.set_depth(-1);
  }

  // Initialize worklist with all roots of the call graph (computations without
  // callers).
  for (const HloComputation* computation : module_->computations()) {
    CallGraphNode& node = GetNode(computation);
    if (node.callers().empty()) {
      node.set_depth(0);
      worklist.push(&node);
    }
  }

  while (!worklist.empty()) {
    CallGraphNode* node = worklist.front();
    worklist.pop();
    for (const HloComputation* callee : node->callees()) {
      CallGraphNode& callee_node = GetNode(callee);
      if (callee_node.depth() < node->depth() + 1) {
        callee_node.set_depth(node->depth() + 1);
        worklist.push(&callee_node);
      }
    }
  }

  for (CallGraphNode& node : nodes_) {
    CHECK_NE(node.depth(), -1);
  }
}

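// Typical usage, sketched with an illustrative 'module' pointer:
//
//   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
//   for (const CallGraphNode& node : call_graph->nodes()) {
//     VLOG(2) << node.ToString() << " depth=" << node.depth();
//   }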
/* static */
std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
  // Constructor for CallGraph is private so absl::make_unique can't be used.
  auto call_graph = absl::WrapUnique<CallGraph>(new CallGraph(module));

  VLOG(3) << "Building call graph for:";
  XLA_VLOG_LINES(3, module->ToString());

  // Construct nodes of the call graph and populate the callsites.
  for (HloComputation* computation : module->computations()) {
    auto it_added = call_graph->node_indices_.insert(
        {computation, call_graph->nodes_.size()});
    // All computations should be unique, so the computation should not already
    // exist in the map.
    CHECK(it_added.second);
    call_graph->nodes_.emplace_back(computation);

    // Add all callsites in this computation.
    for (HloInstruction* instruction : computation->instructions()) {
      call_graph->nodes_.back().AddCallSiteForInstruction(instruction);
    }
  }

  // Add caller callsites to each node.
  for (const HloComputation* computation : module->computations()) {
    for (const CallSite& callsite :
         call_graph->GetNode(computation).callsites()) {
      for (auto* callee : callsite.called_computations()) {
        // Add caller callsites.
        call_graph->GetNode(callee).AddCallerCallSite(callsite);
      }
    }
  }

  call_graph->SetCallContexts();
  call_graph->SetNodeDepths();

  XLA_VLOG_LINES(2, call_graph->ToString());

  return call_graph;
}

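// Recursive helper for VisitNodes. Performs a post-order traversal: all of a
// node's callees are visited before visitor_func is applied to the node
// itself, and 'visited' guards against visiting any node more than once.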
Status CallGraph::VisitNodesInternal(
    const VisitorFunction& visitor_func, const CallGraphNode& node,
    absl::flat_hash_set<const CallGraphNode*>* visited) const {
  auto pair = visited->insert(&node);
  if (!pair.second) {
    // Node was not inserted. Node has already been visited.
    return Status::OK();
  }

  for (const HloComputation* computation : node.callees()) {
    TF_RETURN_IF_ERROR(
        VisitNodesInternal(visitor_func, GetNode(computation), visited));
  }

  return visitor_func(node);
}

Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
                             bool visit_unreachable_nodes) const {
  absl::flat_hash_set<const CallGraphNode*> visited;
  if (visit_unreachable_nodes) {
    // Traverse from all roots in the call graph.
    for (const CallGraphNode& node : nodes()) {
      if (node.callers().empty()) {
        TF_RETURN_IF_ERROR(VisitNodesInternal(visitor_func, node, &visited));
      }
    }
  } else {
    // Traverse only from the entry computation.
    TF_RETURN_IF_ERROR(VisitNodesInternal(
        visitor_func, GetNode(module_->entry_computation()), &visited));
  }

  return Status::OK();
}

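// The call graph is flattened when no computation is called in both a
// sequential and a parallel context, and every computation called in a
// sequential context has at most one caller callsite.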
bool CallGraph::IsFlattened() const {
  for (const CallGraphNode& node : nodes_) {
    if (node.context() == CallContext::kBoth) {
      return false;
    }
    if (node.context() == CallContext::kSequential &&
        node.caller_callsites().size() > 1) {
      return false;
    }
  }
  return true;
}

std::vector<HloInstruction*> CallGraph::GetComputationCallers(
    HloComputation* c) {
  std::vector<HloInstruction*> callers;
  for (const auto& callsite : GetNode(c).caller_callsites()) {
    callers.push_back(callsite.instruction());
  }
  return callers;
}

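// Returns a pair of ancestor instructions of 'a' and 'b' (found by walking up
// through caller instructions) that lie in the same computation, or
// {nullptr, nullptr} if no such pair exists, e.g. because a computation along
// the way is called from more than one callsite.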
std::pair<HloInstruction*, HloInstruction*>
CallGraph::NearestAncestorsInSameComputation(HloInstruction* a,
                                             HloInstruction* b) const {
  // Lambda which returns the next instruction in the callee->caller chain in
  // the call graph. This is the unique instruction which calls the computation
  // containing 'instruction'. If more than one instruction calls the
  // computation containing 'instruction' or no instructions call the
  // computation then nullptr is returned.
  auto next_caller = [this](HloInstruction* instruction) -> HloInstruction* {
    const CallGraphNode& node = GetNode(instruction->parent());
    if (node.caller_callsites().size() != 1) {
      return nullptr;
    }
    return node.caller_callsites()[0].instruction();
  };

  // Iterate through the callee->caller chains and find the earliest common
  // element.
  HloInstruction* a_ancestor = a;
  HloInstruction* b_ancestor = b;
  int a_depth = GetNode(a->parent()).depth();
  int b_depth = GetNode(b->parent()).depth();

  // Advance a_ancestor (b_ancestor) up the call chain until the call depths of
  // a_ancestor and b_ancestor are the same. Necessarily each call to
  // next_caller reduces the depth by exactly one.
  if (a_depth > b_depth) {
    for (int i = 0; i < a_depth - b_depth; ++i) {
      a_ancestor = next_caller(a_ancestor);
      if (a_ancestor == nullptr) {
        return {nullptr, nullptr};
      }
    }
  } else if (b_depth > a_depth) {
    for (int i = 0; i < b_depth - a_depth; ++i) {
      b_ancestor = next_caller(b_ancestor);
      if (b_ancestor == nullptr) {
        return {nullptr, nullptr};
      }
    }
  }

  while ((a_ancestor != nullptr) && (b_ancestor != nullptr)) {
    if (a_ancestor->parent() == b_ancestor->parent()) {
      return {a_ancestor, b_ancestor};
    }

    a_ancestor = next_caller(a_ancestor);
    b_ancestor = next_caller(b_ancestor);
  }
  return {nullptr, nullptr};
}

string CallGraph::ToString() const {
  string out;
  StrAppendFormat(&out, "Call graph for module %s:\n", module_->name());
  for (const CallGraphNode& node : nodes()) {
    StrAppendFormat(&out, "Computation %s:\n", node.computation()->name());
    StrAppendFormat(&out, "  calls:\n");
    for (const HloComputation* callee : node.callees()) {
      StrAppendFormat(&out, "    %s\n", callee->name());
    }
    StrAppendFormat(&out, "  called by:\n");
    for (const HloComputation* caller : node.callers()) {
      StrAppendFormat(&out, "    %s\n", caller->name());
    }
    StrAppendFormat(&out, "  callsites:\n");
    for (const CallSite& callsite : node.callsites()) {
      StrAppendFormat(&out, "    %s\n", callsite.ToString());
    }
  }
  return out;
}

}  // namespace xla