1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
17
18 #include <unistd.h>
19
20 #include <algorithm>
21 #include <atomic>
22 #include <deque>
23 #include <map>
24 #include <memory>
25 #include <queue>
26 #include <string>
27 #include <tuple>
28 #include <vector>
29
30 #include "absl/container/flat_hash_map.h"
31 #include "absl/strings/match.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_format.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/str_replace.h"
36 #include "absl/types/optional.h"
37 #include "tensorflow/compiler/xla/layout_util.h"
38 #include "tensorflow/compiler/xla/literal.h"
39 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
40 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
41 #include "tensorflow/compiler/xla/service/hlo_module.h"
42 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
43 #include "tensorflow/compiler/xla/shape_util.h"
44 #include "tensorflow/compiler/xla/types.h"
45 #include "tensorflow/compiler/xla/util.h"
46 #include "tensorflow/compiler/xla/window_util.h"
47 #include "tensorflow/core/lib/core/status.h"
48 #include "tensorflow/core/lib/gtl/map_util.h"
49 #include "tensorflow/core/lib/io/path.h"
50 #include "tensorflow/core/lib/strings/numbers.h"
51 #include "tensorflow/core/platform/env.h"
52 #include "tensorflow/core/platform/mutex.h"
53 #include "tensorflow/core/platform/protobuf.h"
54 #include "tensorflow/core/platform/regexp.h"
55
56 namespace xla {
57 namespace {
58
59 using absl::nullopt;
60 using absl::optional;
61 using absl::StrAppend;
62 using absl::StrCat;
63 using absl::StrFormat;
64 using absl::StrJoin;
65
// Used to indicate how we should treat a given HloInstruction in the graph:
// should we treat it like normal, hide it, and so on?
enum NodeFilterResult {
  // No special treatment.  This is what a default-constructed NodeFilter
  // returns for every instruction.
  kNormalNode,
  // Don't show the node (NodeFilter::Show() returns false for it).
  kHideNode,
  // Make the node easy to find in the final graph.
  kHighlightNode,
  // "Gray out" the node to indicate that some of its operands have been
  // omitted.
  kSomeOperandsOmitted,
  // Style the node the same as kSomeOperandsOmitted, but also don't connect it
  // to its operands, even if they're present in the graph.
  kOmitNodeOperands,
  // Same style as kSomeOperandsOmitted, but used to indicate that some of the
  // node's *users* have been omitted.
  kSomeUsersOmitted,
};
83
84 // NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
85 // It lets callers tell the graph-drawing routines which nodes they want to be
86 // shown, hidden, or highlighted.
87 class NodeFilter {
88 public:
__anon627048d20202(const HloInstruction*) 89 NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}
90
NodeFilter(std::function<NodeFilterResult (const HloInstruction * instr)> filter)91 explicit NodeFilter(
92 std::function<NodeFilterResult(const HloInstruction* instr)> filter)
93 : filter_(std::move(filter)) {}
94
Show(const HloInstruction * instr) const95 bool Show(const HloInstruction* instr) const {
96 return filter_(instr) != kHideNode;
97 }
Highlight(const HloInstruction * instr) const98 bool Highlight(const HloInstruction* instr) const {
99 return filter_(instr) == kHighlightNode;
100 }
OmitOperands(const HloInstruction * instr) const101 bool OmitOperands(const HloInstruction* instr) const {
102 return filter_(instr) == kOmitNodeOperands;
103 }
SomeOrAllOperandsOmitted(const HloInstruction * instr) const104 bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const {
105 auto result = filter_(instr);
106 return result == kOmitNodeOperands || result == kSomeOperandsOmitted;
107 }
Deemphasized(const HloInstruction * instr) const108 bool Deemphasized(const HloInstruction* instr) const {
109 auto result = filter_(instr);
110 return result == kOmitNodeOperands || result == kSomeOperandsOmitted ||
111 result == kSomeUsersOmitted;
112 }
113
114 private:
115 std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
116 };
117
118 // We arbitrarily set this as the boundary between "large" and "small"
119 // instructions.
IsSmall(const HloInstruction * instr)120 bool IsSmall(const HloInstruction* instr) {
121 if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE_TYPE) ||
122 ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
123 return true;
124 }
125 return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
126 }
127
// Node color schemes, used by NodeColorAttributes.  NodeColorsForScheme maps
// each enumerator to its concrete graphviz style and colors.
enum ColorScheme {
  kBlue,
  kBrown,
  kDarkBlue,
  kDarkGreen,
  kDarkOrange,
  kDarkRed,
  kGray,
  kGreen,
  kOrange,
  kPurple,
  kRed,
  kWhite,
  kYellow,

  // Causes the node's border to be a dashed line, and its content to be gray
  // text on a white background, suggesting that this is an "unimportant" node.
  kDashedBorder,
};
148
// Graphviz attributes/colors that make up a color scheme.  NodeColorAttributes
// emits these as the node's `style`, `fillcolor`, `color` and `fontcolor`
// attributes respectively.
struct NodeColors {
  const char* style;         // value of the graphviz `style` attribute
  const char* fill_color;    // value of the graphviz `fillcolor` attribute
  const char* stroke_color;  // value of the graphviz `color` attribute
  const char* font_color;    // value of the graphviz `fontcolor` attribute
};
156
NodeColorsForScheme(ColorScheme color)157 NodeColors NodeColorsForScheme(ColorScheme color) {
158 switch (color) {
159 case kBlue:
160 return NodeColors{"filled", "#bbdefb", "#8aacc8", "black"};
161 case kBrown:
162 return NodeColors{"filled", "#bcaaa4", "#8c7b75", "black"};
163 case kDarkBlue:
164 return NodeColors{"filled", "#1565c0", "#003c8f", "white"};
165 case kDarkGreen:
166 return NodeColors{"filled", "#2e7d32", "#005005", "white"};
167 case kDarkOrange:
168 // This is more of a "medium" orange, made to look close to kOrange;
169 // there's probably room for a darker weight if desired.
170 return NodeColors{"filled", "#ffb74d", "#c88719", "black"};
171 case kDarkRed:
172 return NodeColors{"filled", "#b71c1c", "#7f0000", "white"};
173 case kGray:
174 return NodeColors{"filled", "#cfd8dc", "#9ea7aa", "black"};
175 case kGreen:
176 return NodeColors{"filled", "#c8e6c9", "#97b498", "black"};
177 case kOrange:
178 return NodeColors{"filled", "#ffe0b2", "#cbae82", "black"};
179 case kPurple:
180 return NodeColors{"filled", "#e1bee7", "#af8eb5", "black"};
181 case kRed:
182 return NodeColors{"filled", "#ffcdd2", "#cb9ca1", "black"};
183 case kWhite:
184 return NodeColors{"filled", "white", "black", "black"};
185 case kYellow:
186 return NodeColors{"filled", "#fff9c4", "#cbc693", "black"};
187 case kDashedBorder:
188 // "filled,dashed" looks the same as "dashed", since we have a white
189 // background. But we use "filled,dashed" so that when you hover over
190 // any part of the node (not just the text inside the node), our css
191 // :hover rule is triggered.
192 return NodeColors{"filled,dashed", "white", "#757575", "#757575"};
193 }
194 }
195
196 // Given a ColorScheme, returns an attribute string for a node of that color.
197 // Sets the node's style and fill/stroke/text colors.
198 //
199 // Colors are from https://material.io/color.
NodeColorAttributes(ColorScheme color)200 string NodeColorAttributes(ColorScheme color) {
201 NodeColors node_colors = NodeColorsForScheme(color);
202
203 return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")",
204 node_colors.style, node_colors.font_color,
205 node_colors.stroke_color, node_colors.fill_color);
206 }
207
208 // Replaces <> with <>, so that this string is safe(er) for use in a
209 // graphviz HTML-like string.
HtmlLikeStringSanitize(absl::string_view s)210 string HtmlLikeStringSanitize(absl::string_view s) {
211 return absl::StrReplaceAll(s, {{"<", "<"}, {">", ">"}});
212 }
213
IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction * instr)214 bool IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction* instr) {
215 namespace m = match;
216 return instr->parent()->IsFusionComputation() &&
217 Match(instr, m::Broadcast(m::ConstantEffectiveScalar()));
218 }
219
220 // Tries to generates a human-readable one-word description of the given
221 // computation.
222 //
223 // Currently we support:
224 //
225 // "return param0 + param1;" --> "add"
226 // "return param0 * param1;" --> "multiply"
227 // "return min(param0, param1);" --> "min"
228 // "return max(param0, param1);" --> "max"
229 // "return param0 <= param1;" --> "less-or-equal"
230 // "return param0 >= param1;" --> "greater-or-equal"
231 // "return param0 > param1;" --> "greater-than"
232 // "return param0 < param1;" --> "less-than"
233 // "return param0 == param1;" --> "equal-to"
234 // "return param0 != param1;" --> "not-equal-to"
235 //
236 // where param0 and param1 are effective scalars. For the ops that are
237 // commutative, we also support them with param0 and param1 swapped.
238 //
239 // This is useful primarily for reduce and map nodes. These take a
240 // subcomputation which is almost always one of the above, and pattern matching
241 // it to a short string lets us tell the user what the subcomputation is without
242 // drawing it as a graph.
MatchTrivialComputation(const HloComputation * computation)243 optional<string> MatchTrivialComputation(const HloComputation* computation) {
244 namespace m = match;
245
246 if (computation->instruction_count() != 3) {
247 return nullopt;
248 }
249 HloInstruction* root = computation->root_instruction();
250 const HloInstruction *param0, *param1;
251 if (!Match(root, m::Op()
252 .WithNumOperands(2)
253 .WithShape(m::Shape().IsEffectiveScalar())
254 .WithBinaryOperandsAnyOrder(
255 m::Parameter(¶m0, 0)
256 .WithShape(m::Shape().IsEffectiveScalar()),
257 m::Parameter(¶m1, 1)
258 .WithShape(m::Shape().IsEffectiveScalar())))) {
259 return nullopt;
260 }
261
262 // If the params are reversed (i.e. operand0 is param1 and operand1 is
263 // param0), check that the operation being performed is commutative.
264 if (root->operand(0) == param1) {
265 CHECK_EQ(root->operand(1), param0);
266 if (root->opcode() == HloOpcode()) {
267 switch (root->comparison_direction()) {
268 case ComparisonDirection::kLe:
269 case ComparisonDirection::kGe:
270 case ComparisonDirection::kGt:
271 case ComparisonDirection::kLt:
272 return nullopt;
273 default:
274 break;
275 }
276 }
277 }
278
279 // If we recognize the root's opcode, we've successfully pattern-matched!
280 switch (root->opcode()) {
281 case HloOpcode::kAdd:
282 return "add";
283 case HloOpcode::kMultiply:
284 return "multiply";
285 case HloOpcode::kMinimum:
286 return "min";
287 case HloOpcode::kMaximum:
288 return "max";
289 case HloOpcode::kCompare: {
290 switch (root->comparison_direction()) {
291 case ComparisonDirection::kLe:
292 return "less-or-equal";
293 case ComparisonDirection::kGe:
294 return "greater-or-equal";
295 case ComparisonDirection::kGt:
296 return "greater-than";
297 case ComparisonDirection::kLt:
298 return "less-than";
299 case ComparisonDirection::kEq:
300 return "equal-to";
301 case ComparisonDirection::kNe:
302 return "not-equal-to";
303 }
304 }
305 default:
306 return nullopt;
307 }
308 }
309
310 // Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax).
311 class HloDotDumper {
312 public:
HloDotDumper(const HloComputation * computation,absl::string_view label,const DebugOptions & debug_options,bool show_backend_config,const HloExecutionProfile * profile,NodeFilter filter)313 HloDotDumper(const HloComputation* computation, absl::string_view label,
314 const DebugOptions& debug_options, bool show_backend_config,
315 const HloExecutionProfile* profile, NodeFilter filter)
316 : computation_(computation),
317 label_(label),
318 debug_options_(debug_options),
319 show_backend_config_(show_backend_config),
320 profile_(profile),
321 filter_(std::move(filter)) {}
322
323 string Dump();
324
325 private:
326 // Returns the dot graph identifier for the given instruction.
InstructionId(const HloInstruction * instruction)327 string InstructionId(const HloInstruction* instruction) {
328 return StrCat(reinterpret_cast<uint64>(instruction));
329 }
330
331 // Returns the dot graph identifier for the given computation.
SubcomputationId(const HloComputation * computation)332 string SubcomputationId(const HloComputation* computation) {
333 return StrCat("cluster_", reinterpret_cast<uint64>(computation));
334 }
335
336 // Generates graph header/footer. These should be called *after* dumping all
337 // of the instructions and subcomputations for the graph, as they both use
338 // data generated while dumping the graph.
339 string Header();
340 string Footer();
341
342 bool ShouldShowSubcomputation(const HloComputation* subcomp);
343 bool ShouldShowFusionSubcomputation(const HloInstruction* instr);
344
345 // We omit some nodes from the graph, instead drawing them inlined into the
346 // nodes that use them.
347 bool ShouldMergeIntoUsers(const HloInstruction* instr) const;
348
349 string DumpSubcomputation(const HloComputation* subcomp,
350 const HloInstruction* parent_instr);
351 string DumpComputation(const HloComputation* comp);
352 string DumpRootTag();
353 string DumpInstruction(const HloInstruction* instr);
354 ColorScheme GetInstructionColor(const HloInstruction* instr);
355 string GetInstructionNodeShape(const HloInstruction* instr);
356 string GetInstructionNodeLabel(const HloInstruction* instr);
357 string GetInstructionNodeMetadata(const HloInstruction* instr);
358 string GetInstructionNodeBackendConfig(const HloInstruction* instr);
359 string GetInstructionNodeExtraInfo(const HloInstruction* instr);
360 string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
361 void AddInstructionIncomingEdges(const HloInstruction* instr);
362
363 // For most instructions, GetNodeForEdge(instr) returns instr.
364 //
365 // The exception is fusion nodes. For these, we walk up the chain of nested
366 // fusion nodes starting at instr until we reach a node that either (a) isn't
367 // a fusion node, or (b) is a fusion node for which
368 // ShouldShowFusionSubcomputation is false.
369 //
370 // We do this because fusion nodes are expanded inline -- if
371 // ShouldShowFusionSubcomputation is true, the fusion node won't be present in
372 // the graph.
373 //
374 // In general when you want to draw an edge from A to B, you should actually
375 // draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B).
376 const HloInstruction* GetNodeForEdge(const HloInstruction* instr);
377
378 // If instr has just one computation and it's trivial (e.g. "return param0 +
379 // param1"), returns a string you can put into the node's body that names the
380 // subcomputation, e.g. "Subcomputation: <b>add</b>".
381 string GetInstructionTrivialComputationStr(const HloInstruction* instr);
382
383 const HloComputation* computation_; // never null
384 const string label_; // overall name for the graph
385 const DebugOptions& debug_options_;
386 const bool show_backend_config_;
387 const HloExecutionProfile* profile_; // may be null
388 const NodeFilter filter_;
389
390 // Each HloInstruction dumped gets a monotonically-increasing node ID. This
391 // must start at 1, because that's where graphviz's accounting starts.
392 int64 next_node_id_ = 1;
393 absl::flat_hash_map<const HloInstruction*, int64> node_ids_;
394
395 // The "root" tag doesn't have an associated HloInstruction pointer, so we
396 // need to store it outside the map.
397 int64 root_node_id_;
398
399 // Each (from, to) edge gets a monotonically-increasing ID. This is a
400 // multimap because it's possible for the same edge to appear multiple times
401 // in the graph (e.g. x^2 may be represented as mul(x, x)).
402 int64 next_edge_id_ = 1;
403 std::unordered_multimap<
404 std::pair<const HloInstruction*, const HloInstruction*>, int64,
405 tensorflow::hash<std::pair<const HloInstruction*, const HloInstruction*>>>
406 edge_ids_;
407
408 // Each HloComputation that's emitted gets a monotonically-increasing ID.
409 int64 next_cluster_id_ = 1;
410 absl::flat_hash_map<const HloComputation*, int64> cluster_ids_;
411
412 // Edges to print from Footer(). Edges come at the end because graphviz is
413 // unhappy if an edge from a subcomputation to a node in the outer computation
414 // appears before both the inner computation and the destination node are
415 // defined.
416 std::vector<string> edges_;
417
418 // When coloring by sharding information, we track the sharding string
419 // representation to color association, by round-robin the color schemes.
420 absl::flat_hash_map<HloSharding, ColorScheme, HloSharding::Hasher>
421 sharding_colors_;
422 int64 next_shard_color_ = 0;
423 };
424
Dump()425 string HloDotDumper::Dump() {
426 string body;
427 StrAppend(&body, DumpComputation(computation_));
428 StrAppend(&body, DumpRootTag());
429
430 // By contract, Header() and Footer() have to be called after we've dumped all
431 // our instructions, because they use state generated during that process.
432 string g = Header();
433 StrAppend(&g, body);
434 StrAppend(&g, Footer());
435 return g;
436 }
437
Header()438 string HloDotDumper::Header() {
439 constexpr char fmt[] = R"(digraph G {
440 rankdir = TB;
441 compound = true;
442 label = <<b>%s</b>>;
443 labelloc = t;
444 // Disable the tooltip. Interestingly, "" doesn't work!
445 tooltip = " ";
446 // DOT graphs accept a stylesheet as a URI. So naturally, an inline
447 // stylesheet is a data URI!
448 stylesheet=<
449 data:text/css,
450 @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
451 svg text {
452 font-family: 'Roboto';
453 font-size: 12px;
454 }
455
456 %s
457 >
458
459 )";
460
461 VLOG(3) << "Generating Header";
462
463 string graph_label =
464 StrCat(label_, "<br/>Computation ", computation_->name());
465 if (computation_->IsFusionComputation()) {
466 StrAppend(&graph_label, " (in fusion instruction ",
467 computation_->FusionInstruction()->name(), ")");
468 }
469 if (profile_ != nullptr) {
470 auto cycles = profile_->total_cycles_executed(*computation_);
471 absl::StrAppendFormat(&graph_label, "<br/>total cycles = %d (%s)", cycles,
472 tensorflow::strings::HumanReadableNum(cycles));
473 }
474
475 // Create CSS rules that say, when you hover over the given node or cluster,
476 // turn the given edge the given color.
477 //
478 // We rely on a few properties of how graphviz generates SVGs:
479 //
480 // - Nodes are named "nodeN", where N corresponds to the 1-based index of
481 // the node in our DOT (i.e. the first node in the DOT is "node1", etc.).
482 // Edges are similarly named "edgeN", and clusters are named "clustN".
483 // - Nodes come before their in- and out-edges in the SVG. We need this
484 // because the "X ~ Y" CSS selector finds a sibling of X that *comes
485 // after X in the DOM* and matches Y.
486 std::vector<string> edge_css_rules;
487 const char* kBlue = "#1976d2";
488 const char* kRed = "#d32f2f";
489 for (const auto& kv : edge_ids_) {
490 const HloInstruction* from_node = kv.first.first;
491 const HloInstruction* to_node = kv.first.second;
492 int64 edge_id = kv.second;
493
494 auto add_hover_css_rule = [&](string elem_type, int64 elem_id,
495 const char* color) {
496 // One could imagine other ways of writing this CSS rule that involve
497 // less duplication, but this way seems to be relatively performant.
498 edge_css_rules.push_back(
499 StrFormat(" #%s%d:hover ~ #edge%d text { fill: %s; }\n"
500 " #%s%d:hover ~ #edge%d path { "
501 "stroke: %s; stroke-width: .2em; }\n"
502 " #%s%d:hover ~ #edge%d polygon { "
503 "fill: %s; stroke: %s; stroke-width: .2em; }\n",
504 elem_type, elem_id, edge_id, color, //
505 elem_type, elem_id, edge_id, color, //
506 elem_type, elem_id, edge_id, color, color));
507 };
508
509 // The "to_node" value may be a NULL, indicating that this points to the
510 // "root" tag rather than a normal node.
511 int64 from_node_id =
512 tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1);
513 if (from_node_id == -1) {
514 LOG(FATAL) << from_node->name() << " was added to edges but not to nodes";
515 }
516 int64 to_node_id =
517 to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1)
518 : root_node_id_;
519 if (to_node != nullptr && to_node_id == -1) {
520 LOG(FATAL) << to_node->name() << " was added to edges but not to nodes";
521 }
522
523 add_hover_css_rule("node", from_node_id, kBlue);
524 add_hover_css_rule("node", to_node_id, kRed);
525
526 if (to_node) {
527 VLOG(3) << "Adding css for edge " << edge_id << " from node "
528 << from_node->name() << " to node " << to_node->name();
529 } else {
530 VLOG(3) << "Adding css for edge " << edge_id << " from node "
531 << from_node->name() << " to root tag";
532 }
533
534 // If this edge crosses a fusion cluster boundary, highlight it when the
535 // cluster is hovered over.
536 if (to_node) {
537 if (from_node->IsFused() &&
538 from_node->parent()->root_instruction() == from_node) {
539 int64 cluster_id = cluster_ids_.at(from_node->parent());
540 add_hover_css_rule("clust", cluster_id, kBlue);
541 }
542 if (to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter) {
543 int64 cluster_id = cluster_ids_.at(to_node->parent());
544 add_hover_css_rule("clust", cluster_id, kRed);
545 }
546 }
547 }
548
549 // Browsers require that we URI-encode the contents of our data URI. (It
550 // seems this was a relatively recent change?) In practice, this means that we
551 // need to escape '#'.
552 return StrFormat(
553 fmt, graph_label,
554 absl::StrReplaceAll(StrJoin(edge_css_rules, "\n"), {{"#", "%23"}}));
555 }
556
Footer()557 string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); }
558
ShouldShowFusionSubcomputation(const HloInstruction * instr)559 bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
560 CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
561 return ShouldShowSubcomputation(instr->fused_instructions_computation());
562 }
563
ShouldShowSubcomputation(const HloComputation * subcomp)564 bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
565 if (subcomp->IsFusionComputation()) {
566 const HloInstruction* fusion = subcomp->FusionInstruction();
567 if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) {
568 return false;
569 }
570 }
571
572 // Don't show trivial subcomputations on non-fusion nodes -- these are inlined
573 // into the graph.
574 if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) {
575 return false;
576 }
577
578 // Show the subcomputation if we're showing any of its members.
579 return absl::c_any_of(
580 subcomp->instructions(),
581 [&](const HloInstruction* instr) { return filter_.Show(instr); });
582 }
583
DumpSubcomputation(const HloComputation * subcomp,const HloInstruction * parent_instr)584 string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
585 const HloInstruction* parent_instr) {
586 VLOG(2) << "Dumping subcomputation " << subcomp->name();
587 // Add an edge from the subcomputation to its parent node. If subcomp
588 // belongs to a fusion node, it's drawn in place of the fusion instruction,
589 // so there's no need to link those.
590 if (parent_instr->opcode() != HloOpcode::kFusion) {
591 const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
592 VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
593 << " as " << next_edge_id_;
594 edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
595 constexpr char edge_fmt[] =
596 R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
597 edges_.push_back(StrFormat(
598 edge_fmt, InstructionId(from), InstructionId(parent_instr),
599 SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
600 }
601
602 // Have we already dumped this subcomputation? If so, generating the edge
603 // linking it and parent_instr is all we want to do in this function.
604 if (cluster_ids_.find(subcomp) != cluster_ids_.end()) {
605 return "";
606 }
607
608 cluster_ids_[subcomp] = next_cluster_id_++;
609
610 string id = SubcomputationId(subcomp);
611
612 string subcomp_label, style;
613 if (parent_instr->opcode() == HloOpcode::kFusion) {
614 subcomp_label =
615 StrFormat("Fused expression for <b>%s</b><br/>%s",
616 HtmlLikeStringSanitize(parent_instr->name()),
617 HtmlLikeStringSanitize(parent_instr->ToCategory()));
618 string extra_info = GetInstructionNodeExtraInfo(parent_instr);
619 if (!extra_info.empty()) {
620 StrAppend(&subcomp_label, "<br/>", extra_info);
621 }
622 string node_backend_config = GetInstructionNodeBackendConfig(parent_instr);
623 if (!node_backend_config.empty()) {
624 StrAppend(&subcomp_label, "<br/>", node_backend_config);
625 }
626
627 bool highlight = filter_.Highlight(parent_instr);
628 const char* fillcolor;
629 const char* strokecolor;
630 if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) {
631 // Use the sharding color, if the node isn't highlighted.
632 NodeColors node_colors =
633 NodeColorsForScheme(GetInstructionColor(parent_instr));
634 fillcolor = node_colors.fill_color;
635 strokecolor = node_colors.stroke_color;
636 } else {
637 // Subcomputation's fill/stroke color is light/dark red/gray, depending on
638 // whether or not the subcomputation's fusion node is highlighted.
639 fillcolor = highlight ? "#ffcdd2" : "#f5f5f5";
640 strokecolor = highlight ? "#b71c1c" : "#c2c2c2";
641 }
642 style =
643 StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")",
644 fillcolor, strokecolor);
645 } else {
646 subcomp_label = StrFormat("Subcomputation for <b>%s</b><br/>%s",
647 HtmlLikeStringSanitize(parent_instr->name()),
648 HtmlLikeStringSanitize(subcomp->name()));
649 style = "style=rounded; color=black;";
650 }
651
652 string comp_body = DumpComputation(subcomp);
653
654 constexpr char computation_fmt[] = R"(subgraph %s {
655 %s
656 label = <%s>;
657 labelloc = t;
658 tooltip = " ";
659 %s
660 } // %s
661
662 )";
663 return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id);
664 }
665
DumpComputation(const HloComputation * comp)666 string HloDotDumper::DumpComputation(const HloComputation* comp) {
667 string g;
668 for (const auto* instr : comp->instructions()) {
669 if (!filter_.Show(instr)) {
670 continue;
671 }
672
673 // Dump subcomputations within instr.
674 for (const HloComputation* subcomp : instr->called_computations()) {
675 if (ShouldShowSubcomputation(subcomp)) {
676 StrAppend(&g, DumpSubcomputation(subcomp, instr));
677 }
678 }
679
680 StrAppend(&g, DumpInstruction(instr));
681 }
682 return g;
683 }
684
DumpRootTag()685 string HloDotDumper::DumpRootTag() {
686 const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());
687
688 // We didn't display constants or broadcasts of effective scalars within
689 // fusions as separate nodes; so if the root is a constant/broadcast of
690 // scalar, we don't add root tag or edge for it.
691 if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
692 IsFusedBroadcastOfConstantEffectiveScalar(from)) {
693 return "";
694 }
695
696 auto from_id = InstructionId(from);
697
698 // The ID of the root computation is otherwise unused, so it makes a good ID
699 // to use for the root-tag node. However, the edge_ids_ map requires a
700 // HloInstruction* pointer for the 'to' value, so we use a NULL value there
701 // (rather than a pointer type-cast) to make it obvious if it is erroneously
702 // dereferenced.
703 HloInstruction* to = nullptr;
704 auto to_id = SubcomputationId(computation_);
705
706 string node_body = "ROOT";
707 string node_shape = "circle";
708 ColorScheme color = kBrown;
709
710 VLOG(2) << "Adding root tag as node " << next_node_id_;
711 root_node_id_ = next_node_id_++;
712
713 VLOG(2) << "Adding edge from " << from->name() << " to root tag as "
714 << next_edge_id_;
715 edge_ids_.insert({{from, to}, next_edge_id_++});
716 edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id));
717
718 return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
719 "\n",
720 to_id, node_body, node_shape, NodeColorAttributes(color));
721 }
722
TryGetFusionParameterConstant(const HloInstruction * instr)723 static const HloConstantInstruction* TryGetFusionParameterConstant(
724 const HloInstruction* instr) {
725 if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
726 return nullptr;
727 }
728 const HloInstruction* fusion = instr->parent()->FusionInstruction();
729 const HloInstruction* operand = fusion->operand(instr->parameter_number());
730 return DynCast<HloConstantInstruction>(operand);
731 }
732
ShouldMergeIntoUsers(const HloInstruction * instr) const733 bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
734 // If a node:
735 //
736 // - is a parameter of a fusion node which is bound to a constant,
737 //
738 // or
739 //
740 // - is a tuple-shaped parameter, and
741 // - is not a parameter to a fusion node, and
742 // - has at least kMinUsersToOmit users shown, and
743 // - all of the shown users are get-tuple-elements,
744 //
745 // then we omit it from the graph, merging it with its users.
746 //
747 // This helps us handle the common case where a while loop body has one big
748 // tuple-shaped parameter.
749 if (TryGetFusionParameterConstant(instr) != nullptr) {
750 return true;
751 }
752 const int kMinUsersToOmit = 3;
753 return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() &&
754 !instr->IsFused() &&
755 absl::c_count_if(instr->users(),
756 [&](const HloInstruction* user) {
757 return filter_.Show(user);
758 }) > kMinUsersToOmit &&
759 absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
760 return !filter_.Show(user) ||
761 user->opcode() == HloOpcode::kGetTupleElement;
762 });
763 }
764
DumpInstruction(const HloInstruction * instr)765 string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
766 // We don't display constants or broadcasts of effective scalar constants
767 // within fusions as separate nodes; they're merged into their users.
768 if (instr->opcode() == HloOpcode::kConstant ||
769 IsFusedBroadcastOfConstantEffectiveScalar(instr)) {
770 return "";
771 }
772 // Skip this node if it's merged into its users.
773 if (ShouldMergeIntoUsers(instr)) {
774 return "";
775 }
776 // Omit the fusion node if its subcomputation is drawn, since the
777 // subcomputation will be drawn inline.
778 if (instr->opcode() == HloOpcode::kFusion &&
779 ShouldShowFusionSubcomputation(instr)) {
780 return "";
781 }
782
783 VLOG(2) << "Adding node " << instr->name() << " as " << next_node_id_;
784 node_ids_[instr] = next_node_id_++;
785
786 ColorScheme color = GetInstructionColor(instr);
787 string node_shape = GetInstructionNodeShape(instr);
788 string node_label = GetInstructionNodeLabel(instr);
789 string node_metadata = GetInstructionNodeMetadata(instr);
790 string node_backend_config = GetInstructionNodeBackendConfig(instr);
791 string extra_info = GetInstructionNodeExtraInfo(instr);
792 string inlined_constants = GetInstructionNodeInlinedOperands(instr);
793 string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
794 AddInstructionIncomingEdges(instr);
795
796 if (!debug_options_.xla_hlo_graph_sharding_color()) {
797 // Override the node's styling if it should be (de-)emphasized.
798 if (filter_.Deemphasized(instr)) {
799 color = kDashedBorder;
800 }
801 if (filter_.Highlight(instr)) {
802 node_shape = "diamond";
803 color = kDarkRed;
804 }
805 }
806 // Build the text that will be displayed inside the node.
807 string node_body = node_label;
808 for (const string& s : {trivial_subcomputation, node_backend_config,
809 extra_info, inlined_constants}) {
810 if (!s.empty()) {
811 StrAppend(&node_body, "<br/>", s);
812 }
813 }
814
815 return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)"
816 "\n",
817 InstructionId(instr), node_body, node_shape, node_metadata,
818 NodeColorAttributes(color));
819 }
820
GetInstructionNodeInlinedOperands(const HloInstruction * instr)821 string HloDotDumper::GetInstructionNodeInlinedOperands(
822 const HloInstruction* instr) {
823 // The constant's shape is a parameter because, in the case of a broadcasted
824 // scalar constant, we want to show the broadcasted shape, not the constant's
825 // scalar shape.
826 auto stringify_constant = [](const HloConstantInstruction* constant,
827 const Shape& shape) {
828 // If the shape has a dimension of size zero, print it as e.g.
829 // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(),
830 // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which
831 // is just noise.
832 if (ShapeUtil::IsZeroElementArray(shape)) {
833 return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape()));
834 }
835
836 // Print the literal value of constants with <= K elements. Note that we
837 // use `constant->shape()` rather than `shape`, because if `constant` is a
838 // scalar that's broadcasted into `shape`, we want to print the constant.
839 optional<int64> elem_count;
840 if (shape.IsArray()) {
841 elem_count = ShapeUtil::ElementsIn(constant->shape());
842 }
843 // Allow HloDotDumper to print HloInstruction reconstructed from HloProto
844 // collected from profiling tools. Those constants may not have a valid
845 // literal.
846 if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
847 return StrFormat("%s %s", shape.ToString(),
848 constant->literal().ToStringWithoutShape());
849 }
850
851 // Otherwise, print e.g. "%constant.42 (s32[100])".
852 string constant_name;
853 if (absl::StartsWith(constant->name(), "constant")) {
854 constant_name = constant->name();
855 } else {
856 constant_name = StrCat("constant ", constant->name());
857 }
858 return StrFormat("%s %s", constant_name, ShapeUtil::HumanString(shape));
859 };
860
861 std::vector<string> lines;
862 for (int64 i = 0; i < instr->operand_count(); ++i) {
863 const HloInstruction* operand = instr->operand(i);
864 optional<string> operand_str;
865 if (const auto* constant_operand =
866 DynCast<HloConstantInstruction>(operand)) {
867 operand_str =
868 stringify_constant(constant_operand, constant_operand->shape());
869 } else if (IsFusedBroadcastOfConstantEffectiveScalar(operand)) {
870 operand_str = stringify_constant(
871 Cast<HloConstantInstruction>(operand->operand(0)), operand->shape());
872 } else if (ShouldMergeIntoUsers(operand)) {
873 // Special case: If the operand is a parameter to a fusion node and it
874 // always has a constant value, display it like a regular constant.
875 //
876 // For other parameters, use the parameter number rather than the proper
877 // name, because that's generally how people think of the node.
878 if (operand->opcode() == HloOpcode::kParameter) {
879 if (const HloConstantInstruction* constant =
880 TryGetFusionParameterConstant(operand)) {
881 operand_str = stringify_constant(constant, constant->shape());
882 } else {
883 operand_str = StrFormat("Parameter %d", operand->parameter_number());
884 }
885 } else {
886 operand_str = operand->name();
887 }
888 }
889
890 if (operand_str) {
891 if (instr->operand_count() > 1) {
892 lines.push_back(StrFormat("<b>operand %d</b> = %s", i, *operand_str));
893 } else {
894 lines.push_back(StrFormat("<b>operand</b> = %s", *operand_str));
895 }
896 }
897 }
898 return StrJoin(lines, "<br/>");
899 }
900
// Chooses the color scheme used to draw the node for `instr`.
//
// When the xla_hlo_graph_sharding_color debug option is set, colors encode
// sharding instead of opcode: instructions without sharding get a dashed
// border, and every distinct sharding is assigned (and remembered in
// sharding_colors_) the next color in the range [kBlue, kDashedBorder).
//
// Otherwise the color is derived from the opcode and, for some opcodes, from
// the shape of the instruction and whether it is fused.  Constants must never
// reach this function; they do not get nodes of their own.
ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
  if (debug_options_.xla_hlo_graph_sharding_color()) {
    if (!instr->has_sharding()) {
      return kDashedBorder;
    }
    // Reuse the color already assigned to this sharding, if any.
    auto it = sharding_colors_.find(instr->sharding());
    if (it != sharding_colors_.end()) {
      return it->second;
    }
    // Otherwise hand out the next color, wrapping around before
    // kDashedBorder, which is reserved for unsharded instructions above.
    ColorScheme color = static_cast<ColorScheme>(
        kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
    sharding_colors_.emplace(instr->sharding(), color);
    return color;
  }

  // Choose different weights of orange for small vs large parameters.  This
  // distinction is often important, especially in fusion nodes.
  auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange;

  // Special case: If this instruction has a parameter merged into it, paint it
  // the same color as a parameter.  Unless the merged-in parameter is a
  // parameter to a fusion node that is bound to a constant -- these aren't
  // "real" parameters from the user's perspective.
  if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) {
        return operand->opcode() == HloOpcode::kParameter &&
               ShouldMergeIntoUsers(operand) &&
               TryGetFusionParameterConstant(operand) == nullptr;
      })) {
    return parameter_color;
  }

  // Pick different colors or shapes for instructions which are particularly
  // expensive (eg, dot) and those which are unusual in some way or unique
  // (eg, parameter).
  //
  // The switch deliberately has no default case; every opcode handled here is
  // listed explicitly.
  switch (instr->opcode()) {
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConvert:
    case HloOpcode::kCos:
    case HloOpcode::kDivide:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIota:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kRemainder:
    case HloOpcode::kRng:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSelect:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kSqrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
      // De-emphasize scalar-shaped elementwise ops -- they're generally
      // uninteresting.
      if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
        return kWhite;
      }
      return kYellow;
    case HloOpcode::kBitcast:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTrace:
    case HloOpcode::kAfterAll:
    case HloOpcode::kAddDependency:
    case HloOpcode::kTuple:
      return kWhite;
    case HloOpcode::kBroadcast:
      // De-emphasize nodes which broadcast a scalar within a fusion node --
      // these are essentially free.
      if (instr->IsFused() &&
          ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) {
        return kWhite;
      }
      return kGreen;
    case HloOpcode::kConcatenate:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kGather:
    case HloOpcode::kPad:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kTupleSelect:
    case HloOpcode::kTranspose:
      // De-emphasize scalar-shaped data movement ops and all data movement ops
      // inside fusion nodes, both of which are essentially free.
      if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) {
        return kWhite;
      }
      return kGreen;
    case HloOpcode::kDynamicUpdateSlice:
      // Unlike the data-movement ops above, dynamic-update-slice is not ~free
      // inside of fusion nodes, so we de-emphasize it only if it's
      // scalar-shaped.
      if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
        return kWhite;
      }
      return kGreen;
    case HloOpcode::kScatter:
      // Do not de-emphasize Scatter, since it involves significant work.
      // (Intentionally shares the unconditional kGreen of the copy ops below.)
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
      // Emphasize copy nodes, which are either physical transposes (and thus
      // significant), or copies of read-only buffers (and thus dead weight).
      return kGreen;
    case HloOpcode::kConvolution:
    case HloOpcode::kDot:
    case HloOpcode::kFft:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
      return kDarkBlue;
    case HloOpcode::kReducePrecision:
      return kRed;
    case HloOpcode::kParameter:
      return parameter_color;
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
      return kPurple;
    case HloOpcode::kDomain:
    case HloOpcode::kFusion:
    case HloOpcode::kMap:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
      return kGray;
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kPartitionId:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kReplicaId:
      return kBrown;
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kCustomCall:
    case HloOpcode::kWhile:
      return kDarkGreen;
    case HloOpcode::kConstant:
      // DumpInstruction() returns early for constants, so getting here
      // indicates a bug in the caller.
      LOG(FATAL) << "Constants don't get their own nodes in the graph.";
  }
}
1078
GetInstructionNodeShape(const HloInstruction * instr)1079 string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) {
1080 // Give while loops a different shape so they're easier to pick out.
1081 switch (instr->opcode()) {
1082 case HloOpcode::kWhile:
1083 return "ellipse";
1084 default:
1085 return "rect";
1086 }
1087 }
1088
GetInstructionNodeLabel(const HloInstruction * instr)1089 string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
1090 // If we have a parameter, put the param number in the name.
1091 if (instr->opcode() == HloOpcode::kParameter) {
1092 return StrFormat("<b>Parameter %d</b>", instr->parameter_number());
1093 }
1094
1095 // The HLO instruction name contains usually the opcode, e.g. "%add.42" is
1096 // an add instruction. In this case we render just the name.
1097 if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) {
1098 return StrFormat("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
1099 }
1100 string extended_opcode =
1101 StrCat(HloOpcodeString(instr->opcode()),
1102 instr->opcode() != HloOpcode::kFusion
1103 ? ""
1104 : StrCat(":", xla::ToString(instr->fusion_kind())));
1105 // If the name does not contain the opcode, render both.
1106 return StrFormat("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
1107 HtmlLikeStringSanitize(instr->name()));
1108 }
1109
GetInstructionNodeMetadata(const HloInstruction * instr)1110 string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
1111 std::vector<string> lines;
1112 if (!instr->metadata().op_name().empty()) {
1113 lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name()));
1114 }
1115 if (!instr->metadata().op_type().empty()) {
1116 lines.push_back(StrFormat(
1117 "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type())));
1118 }
1119 if (!instr->metadata().source_file().empty() &&
1120 instr->metadata().source_line() != 0) {
1121 lines.push_back(StrFormat("op_type: %s:%d", instr->metadata().source_file(),
1122 instr->metadata().source_line()));
1123 }
1124
1125 return StrJoin(lines, "\n");
1126 }
1127
GetInstructionNodeBackendConfig(const HloInstruction * instr)1128 string HloDotDumper::GetInstructionNodeBackendConfig(
1129 const HloInstruction* instr) {
1130 if (!show_backend_config_ || instr->raw_backend_config_string().empty()) {
1131 return "";
1132 }
1133
1134 return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\"");
1135 }
1136
GetInstructionNodeExtraInfo(const HloInstruction * instr)1137 string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
1138 std::vector<string> lines;
1139
1140 // Get the instruction's extra attributes excluding the names of its
1141 // subcomputations, since those are drawn explicitly in the graph.
1142 for (const auto& line : instr->ExtraAttributesToString(
1143 HloPrintOptions().set_print_subcomputation_mode(
1144 HloPrintOptions::PrintSubcomputationMode::kOff))) {
1145 lines.push_back(HtmlLikeStringSanitize(line));
1146 }
1147
1148 // Show the shape and layout of the instruction, unless it's an inlined fusion
1149 // node -- there the shape and layout is present in the output node.
1150 if (instr->opcode() != HloOpcode::kFusion ||
1151 !ShouldShowFusionSubcomputation(instr)) {
1152 // Show layout of instructions with more than one dimension. Don't show
1153 // layout on tuples or tensors with just one dimension (which only have one
1154 // possible layout) to avoid visual noise.
1155 bool shape_is_multidim = false;
1156 ShapeUtil::ForEachSubshape(instr->shape(),
1157 [&](const Shape& s, const ShapeIndex&) {
1158 shape_is_multidim |= s.dimensions_size() > 1;
1159 });
1160 string instr_shape;
1161 if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) {
1162 instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape());
1163 } else {
1164 instr_shape = ShapeUtil::HumanString(instr->shape());
1165 }
1166
1167 // Some instructions have giant tuples as their shapes, so truncate the
1168 // HLO's shape to kMaxShapeLen characters.
1169 constexpr int kMaxShapeLen = 64;
1170 if (instr_shape.length() > kMaxShapeLen) {
1171 instr_shape = StrCat(
1172 absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "...");
1173 }
1174 lines.push_back(instr_shape);
1175 }
1176 if (debug_options_.xla_hlo_graph_addresses()) {
1177 lines.push_back(StrFormat("[%p]", instr));
1178 }
1179 if (profile_ != nullptr) {
1180 double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr);
1181 double total_cycles_executed =
1182 profile_->total_cycles_executed(*instr->parent());
1183 if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
1184 lines.push_back(
1185 StrFormat("%% of cycles executed=%.2f",
1186 100 * hlo_cycles_executed / total_cycles_executed));
1187 }
1188 }
1189 return StrJoin(lines, "<br/>");
1190 }
1191
AddInstructionIncomingEdges(const HloInstruction * instr)1192 void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
1193 auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
1194 int64 operand_num, bool control_edge = false) {
1195 from = GetNodeForEdge(from);
1196
1197 if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
1198 IsFusedBroadcastOfConstantEffectiveScalar(from) ||
1199 ShouldMergeIntoUsers(from)) {
1200 return;
1201 }
1202 VLOG(2) << "Adding edge from " << from->name() << " to " << to->name()
1203 << " as " << next_edge_id_;
1204 edge_ids_.insert({{from, to}, next_edge_id_++});
1205
1206 string edge_label;
1207 if (instr->operand_count() > 1 && !control_edge) {
1208 edge_label =
1209 StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num);
1210 } else if (control_edge) {
1211 edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"";
1212 }
1213
1214 // We print "small" arrays using a hollow arrowhead and "large" arrays using
1215 // a filled arrowhead.
1216 constexpr char kEdgeFmt[] =
1217 R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
1218 edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to),
1219 (IsSmall(from) ? "empty" : "normal"),
1220 from->name(), to->name(), edge_label));
1221 };
1222
1223 // Add edges from instr's operands to instr. Parameters within fusion
1224 // expressions are handled specially -- we draw an edge from the corresponding
1225 // operand on the fusion node itself to the parameter.
1226 if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
1227 // Only add the edge if this is not the outermost computation; otherwise it
1228 // will lead from a node we're not drawing.
1229 if (instr->parent() != computation_) {
1230 const HloInstruction* fusion = instr->parent()->FusionInstruction();
1231 add_edge(fusion->operand(instr->parameter_number()), instr,
1232 /*operand_num=*/0);
1233 }
1234 } else {
1235 for (int64 i = 0; i < instr->operand_count(); ++i) {
1236 add_edge(instr->operand(i), instr, i);
1237 }
1238 for (const HloInstruction* pred : instr->control_predecessors()) {
1239 add_edge(pred, instr, /*operand_num=*/0, /*control_edge=*/true);
1240 }
1241 }
1242 }
1243
GetInstructionTrivialComputationStr(const HloInstruction * instr)1244 string HloDotDumper::GetInstructionTrivialComputationStr(
1245 const HloInstruction* instr) {
1246 // called_computations() on a fusion node "inherits" any called computations
1247 // of the fused root, which isn't what we want. Just ignore fusion nodes
1248 // here; they're handled separately.
1249 if (instr->opcode() == HloOpcode::kFusion) {
1250 return "";
1251 }
1252
1253 std::vector<string> lines;
1254 for (int64 i = 0; i < instr->called_computations().size(); ++i) {
1255 optional<string> computation_type =
1256 MatchTrivialComputation(instr->called_computations()[i]);
1257 if (!computation_type) {
1258 continue;
1259 }
1260 if (instr->called_computations().size() == 1) {
1261 lines.push_back(StrFormat("Subcomputation: <b>%s</b>",
1262 HtmlLikeStringSanitize(*computation_type)));
1263 } else {
1264 lines.push_back(StrFormat("Subcomputation %d: <b>%s</b>", i,
1265 HtmlLikeStringSanitize(*computation_type)));
1266 }
1267 }
1268 return StrJoin(lines, "<br/>");
1269 }
1270
// Returns the instruction whose node an edge touching `instr` should attach
// to.  A fusion node whose subcomputation is shown is replaced by its fused
// expression root, repeatedly, until that no longer applies.
const HloInstruction* HloDotDumper::GetNodeForEdge(
    const HloInstruction* instr) {
  const HloInstruction* node = instr;
  for (;;) {
    const bool is_inlined_fusion = node->opcode() == HloOpcode::kFusion &&
                                   ShouldShowFusionSubcomputation(node);
    if (!is_inlined_fusion) {
      return node;
    }
    node = node->fused_expression_root();
  }
}
1279
// Gets a NodeFilter that includes roughly all instructions whose distance from
// root is <= radius.
//
// The neighborhood is explored through operands, called computations and
// users.  Instructions in `boundary` are included but not explored further.
// In the returned filter, `root` is highlighted, instructions outside the
// neighborhood but in root's computation are hidden, and instructions in any
// other computation are shown normally.
NodeFilter MakeNodeRadiusAroundFilter(
    const HloInstruction* root, int64 radius,
    const absl::flat_hash_set<const HloInstruction*>& boundary) {
  // First, find the neighborhood of nodes with distance from root <= radius.
  // These nodes are our initial set of "normal" nodes.
  absl::flat_hash_map<const HloInstruction*, NodeFilterResult> nodes;
  std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
  worklist.push_back({root, 0});
  while (!worklist.empty()) {
    const HloInstruction* instr;
    int64 depth;
    std::tie(instr, depth) = worklist.front();
    worklist.pop_front();

    // NOTE(review): an instruction can be queued more than once before it is
    // first popped; each pop resets its status to kNormalNode here.
    nodes[instr] = kNormalNode;
    // Nodes at the edge of the neighborhood, and boundary nodes, are included
    // but their neighbors are not explored.
    if (depth == radius) {
      continue;
    }
    if (boundary.contains(instr)) {
      continue;
    }

    // Traverse into instr's operands.
    //
    // Don't traverse into tuples' operands unless the tuple is the root.
    // Usually a tuple is the bottommost node in the graph, and so its operands
    // are not interesting to the graph at hand.
    if (instr == root || instr->opcode() != HloOpcode::kTuple) {
      for (const HloInstruction* operand : instr->operands()) {
        if (!nodes.contains(operand)) {
          worklist.push_back({operand, depth + 1});
        }
      }
    }

    // Traverse into instr's nested computations.
    for (const HloComputation* computation : instr->called_computations()) {
      worklist.push_back({computation->root_instruction(), depth + 1});
    }

    // Traverse into instr's users, unless:
    //
    //  - there are a ton of them, in which case they're probably not
    //    interesting (and anyway, rendering them all would make the graph
    //    unreadable), or
    //  - instr is a constant, in which case its users are probably not
    //    interesting.
    if (instr->opcode() == HloOpcode::kConstant) {
      continue;
    }
    constexpr int kMaxUsersToRender = 16;
    if (instr->user_count() > kMaxUsersToRender) {
      // If we're going to skip this node's users, style it as such.
      nodes[instr] = kSomeUsersOmitted;
      continue;
    }
    for (const HloInstruction* user : instr->users()) {
      if (!nodes.contains(user)) {
        worklist.push_back({user, depth + 1});
      }
    }
  }

  // Whether `instr` ends up visible in the rendered graph.
  auto is_displayed = [&](const HloInstruction* instr) {
    // Constants are displayed inline with their users; they're never omitted.
    // Nodes in subcomputations are always shown.
    return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant ||
           instr->parent() != root->parent();
  };

  // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we
  // know which nodes will be included in the graph.
  for (auto& kv : nodes) {
    const HloInstruction* instr = kv.first;
    NodeFilterResult& filter_result = kv.second;
    const auto& operands = instr->operands();

    if (absl::c_any_of(operands, is_displayed) &&
        !absl::c_all_of(operands, is_displayed)) {
      // Mark nodes with some operands omitted appropriately.
      filter_result = kSomeOperandsOmitted;
    } else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) {
      // Mark nodes with *all* operands omitted appropriately.
      filter_result = kOmitNodeOperands;
    }

    // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
    // users made it into the graph.
    if (filter_result == kSomeUsersOmitted &&
        absl::c_all_of(instr->users(), is_displayed)) {
      filter_result = kNormalNode;
    }
  }

  // Highlight the root node.
  nodes[root] = kHighlightNode;

  // The filter captures `nodes` and `root` by copy ([=]), so it does not refer
  // to this function's locals after returning.
  return NodeFilter([=](const HloInstruction* instr) {
    auto it = nodes.find(instr);
    if (it != nodes.end()) {
      return it->second;
    }
    // Show all nodes in subcomputations.
    if (instr->parent() != root->parent()) {
      return kNormalNode;
    }
    return kHideNode;
  });
}
1391
1392 // Gets a node filter that includes nodes on all paths from `from` to `to`. If
1393 // the all-paths set contains more than max_nodes elements, includes the nodes
1394 // on the shortest paths and sets hit_limit to true.
MakeNodeFromToFilter(const HloInstruction * from,const HloInstruction * to,int64 max_nodes,bool * hit_limit)1395 NodeFilter MakeNodeFromToFilter(const HloInstruction* from,
1396 const HloInstruction* to, int64 max_nodes,
1397 bool* hit_limit) {
1398 *hit_limit = false;
1399
1400 // Elements in the queue are paths through the graph.
1401 std::deque<std::vector<const HloInstruction*>> queue;
1402 queue.push_front({from});
1403
1404 // Compute the set of nodes we want to show using a slightly-modified
1405 // Djikstra's algorithm. The only real difference is, rather than stopping
1406 // when we find a (shortest) path, we continue until we've found max_nodes
1407 // nodes on some path.
1408 std::unordered_set<const HloInstruction*> visited;
1409 std::unordered_set<const HloInstruction*> to_display = {from, to};
1410 while (!queue.empty() && to_display.size() < max_nodes) {
1411 std::vector<const HloInstruction*> path = std::move(queue.front());
1412 queue.pop_front();
1413 if (!visited.insert(path.back()).second) {
1414 continue;
1415 }
1416
1417 for (const auto* user : path.back()->users()) {
1418 if (user == to) {
1419 auto it = path.begin();
1420 for (; it != path.end() && to_display.size() < max_nodes; ++it) {
1421 to_display.insert(*it);
1422 }
1423 if (it != path.end()) {
1424 *hit_limit = true;
1425 }
1426 } else if (!visited.count(user)) {
1427 auto new_path = path;
1428 new_path.push_back(user);
1429 queue.push_back(std::move(new_path));
1430 }
1431 }
1432 }
1433
1434 return NodeFilter([=](const HloInstruction* instr) {
1435 if (instr == from || instr == to) {
1436 return kHighlightNode;
1437 }
1438 return to_display.count(instr) ? kNormalNode : kHideNode;
1439 });
1440 }
1441
// Wraps the graphviz source `dot` in a standalone HTML page and returns the
// page's text.
//
// The page loads viz.js and svg-pan-zoom from cdn.jsdelivr.net, renders the
// graph to SVG in the browser, and appends links for saving the page, the SVG
// and the original dot.  The result is html_prefix + dot + html_suffix.
//
// NOTE: `dot` is pasted verbatim between the backticks of a JavaScript
// template literal (`var data = ...`); no escaping is applied here.
string WrapDotInHtml(absl::string_view dot) {
  // Everything up to and including the opening backtick of `data`.
  static const char html_prefix[] = R"html(
<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <style type="text/css">
    body {
      height: 100vh;
      margin: 0;
    }
  </style>
</head>
<body>
  <!-- Integrity hash is generated by https://www.srihash.org/ -->
  <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/viz.js"
     integrity="sha384-aD1MJYb0WKIUT+CtwJp5LTuV3U4pLAS6B/nUxL7ECimC2pN9N8vjlMr/yQCAkzxE"
     crossorigin="anonymous"></script>
  <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/full.render.js"
     integrity="sha384-bAixY275aIpCj6Te19y0MILZ4V+VEC8CVFujFEH+Lf7W+4XYYeYLwW5IBI6yQmMT"
     crossorigin="anonymous"></script>
  <script src="https://cdn.jsdelivr.net/npm/svg-pan-zoom@3.6.0/dist/svg-pan-zoom.min.js"
     integrity="sha384-3008WpYB2pOBvE7lwkrKf+qTmbTPGGPYxA9C1YVhvbPukns4ZFj7E98QPLkNW9dS"
     crossorigin="anonymous"></script>
  <div id="container" style="height:95vh; border:1px solid black; "></div>
  <script>
    var data = `
)html";

  // Closes the template literal, then holds the script that extracts the
  // stylesheet from the dot, renders the SVG and adds the page controls.
  static const char html_suffix[] = R"html(
`;
    var cssregex = new RegExp('stylesheet=<([^]*)\n>\n', 'gm');
    var results = cssregex.exec(data)
    // graphviz has problem dealing with large stylesheets.
    // https://github.com/tensorflow/tensorflow/issues/17220#issuecomment-369228492
    // In order to avoid the problem, remove the stylesheet from the dot and
    // insert it directly info the rendered SVG.
    var dot_data = data;
    var css_data = ''
    if (results !== null) {
        css_data = results[1].replace(/\s*data:.*\s*,/,''); // Strip content-type field.
        // CSS inside DOT is URL-escaped, so we must unescape it
        // before we can insert it into SVG.
        css_data = unescape(css_data);
        dot_data = data.replace(cssregex, ''); // Remove the stylesheet
    }

    var render_start = performance.now()
    function add_controls(svg) {
        var htmlblob = new Blob([document.documentElement.innerHTML],
                                {type: 'text/html'});
        var savehtml = document.createElement('a');
        savehtml.setAttribute('href', URL.createObjectURL(htmlblob));
        savehtml.setAttribute('download', 'graph.html');
        savehtml.innerHTML = " [Save HTML+SVG] ";
        document.body.append(savehtml);
        var svgblob = new Blob([svg.outerHTML], {type: 'image/svg'});
        var savesvg = document.createElement('a');
        savesvg.setAttribute('href', URL.createObjectURL(svgblob));
        savesvg.setAttribute('download', 'graph.svg');
        savesvg.innerHTML = " [Save SVG] ";
        document.body.append(savesvg);
        var dotblob =  new Blob([data], {type: 'text/dot'});
        var savedot = document.createElement('a');
        savedot.setAttribute('href', URL.createObjectURL(dotblob));
        savedot.setAttribute('download', 'graph.dot');
        savedot.innerHTML = " [Save DOT] ";
        document.body.append(savedot);
        // Will get called after embed element was loaded
        var panzoom = svgPanZoom(svg, {
            zoomEnabled: true,
            controlIconsEnabled: true,
        });
        document.getElementsByTagName("BODY")[0].onresize = function() {
            panzoom.resize();
            panzoom.fit();
            panzoom.center();
        };
        var render_end = performance.now();
        var render_note = document.createElement('div')
        render_note.innerHTML = 'Rendering took '
                                + (render_end - render_start).toFixed(2) + "ms."
        document.body.append(render_note);
    }
    var svg = document.getElementById('graph')
    if (svg == null) {
        // Need to render SVG first.
        var viz = new Viz();
        viz.renderSVGElement(dot_data)
            .then(function(svg){
                var container = document.getElementById('container')
                var style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
                var node = document.createTextNode(css_data);
                style.appendChild(node);
                svg.setAttribute('width', '100%');
                svg.setAttribute('height', '100%');
                svg.setAttribute('id', 'graph');
                svg.appendChild(style);
                container.appendChild(svg);
                add_controls(svg);
            })
    } else {
        // HTML already has rendered SVG embedded, so we just need to add
        // controls.
        add_controls(svg);
    }
  </script>
</body>
</html>
)html";

  return absl::StrCat(html_prefix, dot, html_suffix);
}
1555
// Guards `url_renderer`.
tensorflow::mutex url_renderer_mu(tensorflow::LINKER_INITIALIZED);
// Callback that turns dot source into a URL, used for
// RenderedGraphFormat::kUrl.  Null until RegisterGraphToURLRenderer() installs
// one; the pointee is heap-allocated and replaced there.
std::function<StatusOr<string>(absl::string_view)>* url_renderer
    GUARDED_BY(url_renderer_mu) = nullptr;
1559
1560 // Precondition: url_renderer != nullptr.
1561 //
1562 // (We specify this as a precondition rather than checking it in here and
1563 // returning an error because we want to fail quickly when there's no URL
1564 // renderer available, and this function runs only after we've done all the work
1565 // of producing dot for the graph.)
1566 StatusOr<string> WrapDotInFormat(absl::string_view dot,
1567 RenderedGraphFormat format)
1568 EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
1569 switch (format) {
1570 case RenderedGraphFormat::kUrl:
1571 CHECK(url_renderer != nullptr)
1572 << "Should have checked url_renderer != null before calling.";
1573 return (*url_renderer)(dot);
1574 case RenderedGraphFormat::kHtml:
1575 return WrapDotInHtml(dot);
1576 case RenderedGraphFormat::kDot:
1577 return string(dot);
1578 }
1579 }
1580
1581 } // namespace
1582
1583 void RegisterGraphToURLRenderer(
1584 std::function<StatusOr<string>(absl::string_view)> renderer) {
1585 tensorflow::mutex_lock lock(url_renderer_mu);
1586 if (url_renderer != nullptr) {
1587 LOG(WARNING) << "Multiple calls to RegisterGraphToURLRenderer. Last call "
1588 "wins, but because order of initialization in C++ is "
1589 "nondeterministic, this may not be what you want.";
1590 }
1591 delete url_renderer;
1592 url_renderer = new std::function<StatusOr<string>(absl::string_view)>(
1593 std::move(renderer));
1594 }
1595
1596 StatusOr<string> RenderGraph(const HloComputation& computation,
1597 absl::string_view label,
1598 const DebugOptions& debug_options,
1599 RenderedGraphFormat format,
1600 const HloExecutionProfile* hlo_execution_profile,
1601 bool show_backend_config) {
1602 tensorflow::mutex_lock lock(url_renderer_mu);
1603 if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1604 return Unavailable("Can't render as URL; no URL renderer was registered.");
1605 }
1606
1607 string rendered_dot =
1608 HloDotDumper(&computation, label, debug_options, show_backend_config,
1609 hlo_execution_profile, NodeFilter())
1610 .Dump();
1611 return WrapDotInFormat(rendered_dot, format);
1612 }
1613
1614 StatusOr<string> RenderNeighborhoodAround(
1615 const HloInstruction& node, int radius, RenderedGraphFormat format,
1616 bool show_backend_config,
1617 const absl::flat_hash_set<const HloInstruction*>& boundary) {
1618 tensorflow::mutex_lock lock(url_renderer_mu);
1619 if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1620 return FailedPrecondition(
1621 "Can't render as URL; no URL renderer was registered.");
1622 }
1623
1624 string label =
1625 StrCat("Neighborhood of ", radius, " nodes around ", node.name());
1626 string rendered_dot =
1627 HloDotDumper(node.parent(), label,
1628 node.GetModule()->config().debug_options(),
1629 show_backend_config, /*profile=*/nullptr,
1630 MakeNodeRadiusAroundFilter(&node, radius, boundary))
1631 .Dump();
1632 return WrapDotInFormat(rendered_dot, format);
1633 }
1634
1635 StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
1636 const HloInstruction& to, int64 max_nodes,
1637 RenderedGraphFormat format,
1638 bool show_backend_config) {
1639 tensorflow::mutex_lock lock(url_renderer_mu);
1640 if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1641 return FailedPrecondition(
1642 "Can't render as URL; no URL renderer was registered.");
1643 }
1644
1645 CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!";
1646 auto debug_options = from.GetModule()->config().debug_options();
1647
1648 bool hit_limit = false;
1649 NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit);
1650 string label;
1651 if (!hit_limit) {
1652 label = StrCat("All paths from ", from.name(), " to ", to.name());
1653 } else {
1654 label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(),
1655 " to ", to.name(),
1656 "<br/><br/>***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN "
1657 "NODES***<br/><br/>");
1658 }
1659 string rendered_dot =
1660 HloDotDumper(from.parent(), label, debug_options, show_backend_config,
1661 /*profile=*/nullptr, filter)
1662 .Dump();
1663 return WrapDotInFormat(rendered_dot, format);
1664 }
1665
1666 } // namespace xla
1667