1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
17
18 #include <unistd.h>
19
20 #include <algorithm>
21 #include <atomic>
22 #include <deque>
23 #include <map>
24 #include <memory>
25 #include <queue>
26 #include <string>
27 #include <tuple>
28 #include <vector>
29
30 #include "absl/container/flat_hash_map.h"
31 #include "absl/strings/match.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_format.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/str_replace.h"
36 #include "absl/types/optional.h"
37 #include "tensorflow/compiler/xla/layout_util.h"
38 #include "tensorflow/compiler/xla/literal.h"
39 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
40 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
41 #include "tensorflow/compiler/xla/service/hlo_module.h"
42 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
43 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
44 #include "tensorflow/compiler/xla/shape_util.h"
45 #include "tensorflow/compiler/xla/types.h"
46 #include "tensorflow/compiler/xla/util.h"
47 #include "tensorflow/compiler/xla/window_util.h"
48 #include "tensorflow/core/lib/core/status.h"
49 #include "tensorflow/core/lib/gtl/map_util.h"
50 #include "tensorflow/core/lib/io/path.h"
51 #include "tensorflow/core/lib/strings/numbers.h"
52 #include "tensorflow/core/platform/env.h"
53 #include "tensorflow/core/platform/mutex.h"
54 #include "tensorflow/core/platform/protobuf.h"
55 #include "tensorflow/core/platform/regexp.h"
56
57 namespace xla {
58 namespace {
59
60 using absl::nullopt;
61 using absl::optional;
62 using absl::StrAppend;
63 using absl::StrCat;
64 using absl::StrFormat;
65 using absl::StrJoin;
66
// Describes how a given HloInstruction should be treated when the graph is
// drawn: shown as usual, hidden, highlighted, or de-emphasized.
enum NodeFilterResult {
  // No special treatment. This is what a default-constructed NodeFilter
  // reports for every instruction.
  kNormalNode,
  // Leave the node out of the graph.
  kHideNode,
  // Make the node easy to find in the final graph.
  kHighlightNode,
  // "Gray out" the node to indicate that some of its operands have been
  // omitted.
  kSomeOperandsOmitted,
  // Style the node the same as kSomeOperandsOmitted, but also don't connect it
  // to its operands, even if they're present in the graph.
  kOmitNodeOperands,
  // Same style as kSomeOperandsOmitted, but used to indicate that some of the
  // node's *users* have been omitted.
  kSomeUsersOmitted,
};
84
85 // NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
86 // It lets callers tell the graph-drawing routines which nodes they want to be
87 // shown, hidden, or highlighted.
88 class NodeFilter {
89 public:
__anonf346b5ae0202(const HloInstruction*) 90 NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}
91
NodeFilter(std::function<NodeFilterResult (const HloInstruction * instr)> filter)92 explicit NodeFilter(
93 std::function<NodeFilterResult(const HloInstruction* instr)> filter)
94 : filter_(std::move(filter)) {}
95
Show(const HloInstruction * instr) const96 bool Show(const HloInstruction* instr) const {
97 return filter_(instr) != kHideNode;
98 }
Highlight(const HloInstruction * instr) const99 bool Highlight(const HloInstruction* instr) const {
100 return filter_(instr) == kHighlightNode;
101 }
OmitOperands(const HloInstruction * instr) const102 bool OmitOperands(const HloInstruction* instr) const {
103 return filter_(instr) == kOmitNodeOperands;
104 }
SomeOrAllOperandsOmitted(const HloInstruction * instr) const105 bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const {
106 auto result = filter_(instr);
107 return result == kOmitNodeOperands || result == kSomeOperandsOmitted;
108 }
Deemphasized(const HloInstruction * instr) const109 bool Deemphasized(const HloInstruction* instr) const {
110 auto result = filter_(instr);
111 return result == kOmitNodeOperands || result == kSomeOperandsOmitted ||
112 result == kSomeUsersOmitted;
113 }
114
115 private:
116 std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
117 };
118
119 // We arbitrarily set this as the boundary between "large" and "small"
120 // instructions.
IsSmall(const HloInstruction * instr)121 bool IsSmall(const HloInstruction* instr) {
122 if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE_TYPE) ||
123 ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
124 return true;
125 }
126 return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
127 }
128
// Node color schemes, used by NodeColorAttributes.
enum ColorScheme {
  kBlue,
  kBrown,
  kDarkBlue,
  kDarkGreen,
  kDarkOrange,
  kDarkRed,
  kGray,
  kGreen,
  kOrange,
  kPurple,
  kRed,
  kWhite,
  kYellow,

  // Causes the node's border to be a dashed line, and its content to be gray
  // text on a white background, suggesting that this is an "unimportant" node.
  kDashedBorder,
};

// Graphviz attributes/colors that make up a color scheme.
struct NodeColors {
  const char* style;         // value for the graphviz "style" attribute
  const char* fill_color;    // value for "fillcolor"
  const char* stroke_color;  // value for "color" (the outline)
  const char* font_color;    // value for "fontcolor"
};

// Returns the graphviz style and fill/stroke/font colors for `color`.
//
// A value that matches no enumerator (e.g. one produced by a bad cast) gets
// the kWhite scheme, so the function never flows off its end, which would be
// undefined behavior.
NodeColors NodeColorsForScheme(ColorScheme color) {
  switch (color) {
    case kBlue:
      return NodeColors{"filled", "#bbdefb", "#8aacc8", "black"};
    case kBrown:
      return NodeColors{"filled", "#bcaaa4", "#8c7b75", "black"};
    case kDarkBlue:
      return NodeColors{"filled", "#1565c0", "#003c8f", "white"};
    case kDarkGreen:
      return NodeColors{"filled", "#2e7d32", "#005005", "white"};
    case kDarkOrange:
      // This is more of a "medium" orange, made to look close to kOrange;
      // there's probably room for a darker weight if desired.
      return NodeColors{"filled", "#ffb74d", "#c88719", "black"};
    case kDarkRed:
      return NodeColors{"filled", "#b71c1c", "#7f0000", "white"};
    case kGray:
      return NodeColors{"filled", "#cfd8dc", "#9ea7aa", "black"};
    case kGreen:
      return NodeColors{"filled", "#c8e6c9", "#97b498", "black"};
    case kOrange:
      return NodeColors{"filled", "#ffe0b2", "#cbae82", "black"};
    case kPurple:
      return NodeColors{"filled", "#e1bee7", "#af8eb5", "black"};
    case kRed:
      return NodeColors{"filled", "#ffcdd2", "#cb9ca1", "black"};
    case kWhite:
      return NodeColors{"filled", "white", "black", "black"};
    case kYellow:
      return NodeColors{"filled", "#fff9c4", "#cbc693", "black"};
    case kDashedBorder:
      // "filled,dashed" looks the same as "dashed", since we have a white
      // background. But we use "filled,dashed" so that when you hover over
      // any part of the node (not just the text inside the node), our css
      // :hover rule is triggered.
      return NodeColors{"filled,dashed", "white", "#757575", "#757575"};
  }
  // Deliberately placed after the switch rather than in a `default:` label,
  // so compilers can still warn when a new enumerator lacks a case.
  return NodeColors{"filled", "white", "black", "black"};
}
196
197 // Given a ColorScheme, returns an attribute string for a node of that color.
198 // Sets the node's style and fill/stroke/text colors.
199 //
200 // Colors are from https://material.io/color.
NodeColorAttributes(ColorScheme color)201 string NodeColorAttributes(ColorScheme color) {
202 NodeColors node_colors = NodeColorsForScheme(color);
203
204 return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")",
205 node_colors.style, node_colors.font_color,
206 node_colors.stroke_color, node_colors.fill_color);
207 }
208
209 // Replaces <> with <>, so that this string is safe(er) for use in a
210 // graphviz HTML-like string.
HtmlLikeStringSanitize(absl::string_view s)211 string HtmlLikeStringSanitize(absl::string_view s) {
212 return absl::StrReplaceAll(s, {{"<", "<"}, {">", ">"}});
213 }
214
IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction * instr)215 bool IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction* instr) {
216 namespace m = match;
217 return instr->parent()->IsFusionComputation() &&
218 Match(instr, m::Broadcast(m::ConstantEffectiveScalar()));
219 }
220
221 // Tries to generates a human-readable one-word description of the given
222 // computation.
223 //
224 // Currently we support:
225 //
226 // "return param0 + param1;" --> "add"
227 // "return param0 * param1;" --> "multiply"
228 // "return min(param0, param1);" --> "min"
229 // "return max(param0, param1);" --> "max"
230 // "return param0 <= param1;" --> "less-or-equal"
231 // "return param0 >= param1;" --> "greater-or-equal"
232 // "return param0 > param1;" --> "greater-than"
233 // "return param0 < param1;" --> "less-than"
234 // "return param0 == param1;" --> "equal-to"
235 // "return param0 != param1;" --> "not-equal-to"
236 //
237 // where param0 and param1 are effective scalars. For the ops that are
238 // commutative, we also support them with param0 and param1 swapped.
239 //
240 // This is useful primarily for reduce and map nodes. These take a
241 // subcomputation which is almost always one of the above, and pattern matching
242 // it to a short string lets us tell the user what the subcomputation is without
243 // drawing it as a graph.
MatchTrivialComputation(const HloComputation * computation)244 optional<string> MatchTrivialComputation(const HloComputation* computation) {
245 namespace m = match;
246
247 if (computation->instruction_count() != 3) {
248 return nullopt;
249 }
250 HloInstruction* root = computation->root_instruction();
251 const HloInstruction *param0, *param1;
252 if (!Match(root, m::Op()
253 .WithNumOperands(2)
254 .WithShape(m::Shape().IsEffectiveScalar())
255 .WithBinaryOperandsAnyOrder(
256 m::Parameter(¶m0, 0)
257 .WithShape(m::Shape().IsEffectiveScalar()),
258 m::Parameter(¶m1, 1)
259 .WithShape(m::Shape().IsEffectiveScalar())))) {
260 return nullopt;
261 }
262
263 // If the params are reversed (i.e. operand0 is param1 and operand1 is
264 // param0), check that the operation being performed is commutative.
265 if (root->operand(0) == param1) {
266 CHECK_EQ(root->operand(1), param0);
267 if (root->opcode() == HloOpcode()) {
268 switch (root->comparison_direction()) {
269 case ComparisonDirection::kLe:
270 case ComparisonDirection::kGe:
271 case ComparisonDirection::kGt:
272 case ComparisonDirection::kLt:
273 return nullopt;
274 default:
275 break;
276 }
277 }
278 }
279
280 // If we recognize the root's opcode, we've successfully pattern-matched!
281 switch (root->opcode()) {
282 case HloOpcode::kAdd:
283 return "add";
284 case HloOpcode::kMultiply:
285 return "multiply";
286 case HloOpcode::kMinimum:
287 return "min";
288 case HloOpcode::kMaximum:
289 return "max";
290 case HloOpcode::kCompare: {
291 switch (root->comparison_direction()) {
292 case ComparisonDirection::kLe:
293 return "less-or-equal";
294 case ComparisonDirection::kGe:
295 return "greater-or-equal";
296 case ComparisonDirection::kGt:
297 return "greater-than";
298 case ComparisonDirection::kLt:
299 return "less-than";
300 case ComparisonDirection::kEq:
301 return "equal-to";
302 case ComparisonDirection::kNe:
303 return "not-equal-to";
304 }
305 }
306 default:
307 return nullopt;
308 }
309 }
310
311 // Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax).
312 class HloDotDumper {
313 public:
HloDotDumper(const HloComputation * computation,absl::string_view label,const DebugOptions & debug_options,HloRenderOptions hlo_render_options,const HloExecutionProfile * profile,NodeFilter filter)314 HloDotDumper(const HloComputation* computation, absl::string_view label,
315 const DebugOptions& debug_options,
316 HloRenderOptions hlo_render_options,
317 const HloExecutionProfile* profile, NodeFilter filter)
318 : computation_(computation),
319 label_(label),
320 debug_options_(debug_options),
321 hlo_render_options_(hlo_render_options),
322 profile_(profile),
323 filter_(std::move(filter)) {}
324
325 string Dump();
326
327 private:
328 // Returns the dot graph identifier for the given instruction.
InstructionId(const HloInstruction * instruction)329 string InstructionId(const HloInstruction* instruction) {
330 return StrCat(reinterpret_cast<uint64>(instruction));
331 }
332
333 // Returns the dot graph identifier for the given computation.
SubcomputationId(const HloComputation * computation)334 string SubcomputationId(const HloComputation* computation) {
335 return StrCat("cluster_", reinterpret_cast<uint64>(computation));
336 }
337
338 // Generates graph header/footer. These should be called *after* dumping all
339 // of the instructions and subcomputations for the graph, as they both use
340 // data generated while dumping the graph.
341 string Header();
342 string Footer();
343
344 bool ShouldShowSubcomputation(const HloComputation* subcomp);
345 bool ShouldShowFusionSubcomputation(const HloInstruction* instr);
346
347 // We omit some nodes from the graph, instead drawing them inlined into the
348 // nodes that use them.
349 bool ShouldMergeIntoUsers(const HloInstruction* instr) const;
350
351 string DumpSubcomputation(const HloComputation* subcomp,
352 const HloInstruction* parent_instr);
353 string DumpComputation(const HloComputation* comp);
354 string DumpRootTag();
355 string DumpInstruction(const HloInstruction* instr);
356 ColorScheme GetInstructionColor(const HloInstruction* instr);
357 string GetInstructionNodeShape(const HloInstruction* instr);
358 string GetInstructionNodeLabel(const HloInstruction* instr);
359 string GetInstructionNodeMetadata(const HloInstruction* instr);
360 string GetInstructionNodeBackendConfig(const HloInstruction* instr);
361 string GetInstructionNodeExtraInfo(const HloInstruction* instr);
362 string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
363 void AddInstructionIncomingEdges(const HloInstruction* instr);
364
365 // For most instructions, GetNodeForEdge(instr) returns instr.
366 //
367 // The exception is fusion nodes. For these, we walk up the chain of nested
368 // fusion nodes starting at instr until we reach a node that either (a) isn't
369 // a fusion node, or (b) is a fusion node for which
370 // ShouldShowFusionSubcomputation is false.
371 //
372 // We do this because fusion nodes are expanded inline -- if
373 // ShouldShowFusionSubcomputation is true, the fusion node won't be present in
374 // the graph.
375 //
376 // In general when you want to draw an edge from A to B, you should actually
377 // draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B).
378 const HloInstruction* GetNodeForEdge(const HloInstruction* instr);
379
380 // If instr has just one computation and it's trivial (e.g. "return param0 +
381 // param1"), returns a string you can put into the node's body that names the
382 // subcomputation, e.g. "Subcomputation: <b>add</b>".
383 string GetInstructionTrivialComputationStr(const HloInstruction* instr);
384
385 const HloComputation* computation_; // never null
386 const string label_; // overall name for the graph
387 const DebugOptions& debug_options_;
388 const HloRenderOptions hlo_render_options_;
389 const HloExecutionProfile* profile_; // may be null
390 const NodeFilter filter_;
391
392 // Each HloInstruction dumped gets a monotonically-increasing node ID. This
393 // must start at 1, because that's where graphviz's accounting starts.
394 int64 next_node_id_ = 1;
395 absl::flat_hash_map<const HloInstruction*, int64> node_ids_;
396
397 // The "root" tag doesn't have an associated HloInstruction pointer, so we
398 // need to store it outside the map.
399 int64 root_node_id_;
400
401 // Each (from, to) edge gets a monotonically-increasing ID. This is a
402 // multimap because it's possible for the same edge to appear multiple times
403 // in the graph (e.g. x^2 may be represented as mul(x, x)).
404 int64 next_edge_id_ = 1;
405 std::unordered_multimap<
406 std::pair<const HloInstruction*, const HloInstruction*>, int64,
407 tensorflow::hash<std::pair<const HloInstruction*, const HloInstruction*>>>
408 edge_ids_;
409
410 // Each HloComputation that's emitted gets a monotonically-increasing ID.
411 int64 next_cluster_id_ = 1;
412 absl::flat_hash_map<const HloComputation*, int64> cluster_ids_;
413
414 // Edges to print from Footer(). Edges come at the end because graphviz is
415 // unhappy if an edge from a subcomputation to a node in the outer computation
416 // appears before both the inner computation and the destination node are
417 // defined.
418 std::vector<string> edges_;
419
420 // When coloring by sharding information, we track the sharding string
421 // representation to color association, by round-robin the color schemes.
422 absl::flat_hash_map<HloSharding, ColorScheme, HloSharding::Hasher>
423 sharding_colors_;
424 int64 next_shard_color_ = 0;
425 };
426
Dump()427 string HloDotDumper::Dump() {
428 string body;
429 StrAppend(&body, DumpComputation(computation_));
430 StrAppend(&body, DumpRootTag());
431
432 // By contract, Header() and Footer() have to be called after we've dumped all
433 // our instructions, because they use state generated during that process.
434 string g = Header();
435 StrAppend(&g, body);
436 StrAppend(&g, Footer());
437 return g;
438 }
439
Header()440 string HloDotDumper::Header() {
441 constexpr char fmt[] = R"(digraph G {
442 rankdir = TB;
443 compound = true;
444 label = <<b>%s</b>>;
445 labelloc = t;
446 // Disable the tooltip. Interestingly, "" doesn't work!
447 tooltip = " ";
448 // DOT graphs accept a stylesheet as a URI. So naturally, an inline
449 // stylesheet is a data URI!
450 stylesheet=<
451 data:text/css,
452 @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
453 svg text {
454 font-family: 'Roboto';
455 font-size: 12px;
456 }
457
458 %s
459 >
460
461 )";
462
463 VLOG(3) << "Generating Header";
464
465 string graph_label =
466 StrCat(label_, "<br/>Computation ", computation_->name());
467 if (computation_->IsFusionComputation()) {
468 StrAppend(&graph_label, " (in fusion instruction ",
469 computation_->FusionInstruction()->name(), ")");
470 }
471 if (profile_ != nullptr) {
472 auto cycles = profile_->total_cycles_executed(*computation_);
473 absl::StrAppendFormat(&graph_label, "<br/>total cycles = %d (%s)", cycles,
474 tensorflow::strings::HumanReadableNum(cycles));
475 }
476
477 // Create CSS rules that say, when you hover over the given node or cluster,
478 // turn the given edge the given color.
479 //
480 // We rely on a few properties of how graphviz generates SVGs:
481 //
482 // - Nodes are named "nodeN", where N corresponds to the 1-based index of
483 // the node in our DOT (i.e. the first node in the DOT is "node1", etc.).
484 // Edges are similarly named "edgeN", and clusters are named "clustN".
485 // - Nodes come before their in- and out-edges in the SVG. We need this
486 // because the "X ~ Y" CSS selector finds a sibling of X that *comes
487 // after X in the DOM* and matches Y.
488 std::vector<string> edge_css_rules;
489 const char* kBlue = "#1976d2";
490 const char* kRed = "#d32f2f";
491 for (const auto& kv : edge_ids_) {
492 const HloInstruction* from_node = kv.first.first;
493 const HloInstruction* to_node = kv.first.second;
494 int64 edge_id = kv.second;
495
496 auto add_hover_css_rule = [&](string elem_type, int64 elem_id,
497 const char* color) {
498 // One could imagine other ways of writing this CSS rule that involve
499 // less duplication, but this way seems to be relatively performant.
500 edge_css_rules.push_back(
501 StrFormat(" #%s%d:hover ~ #edge%d text { fill: %s; }\n"
502 " #%s%d:hover ~ #edge%d path { "
503 "stroke: %s; stroke-width: .2em; }\n"
504 " #%s%d:hover ~ #edge%d polygon { "
505 "fill: %s; stroke: %s; stroke-width: .2em; }\n",
506 elem_type, elem_id, edge_id, color, //
507 elem_type, elem_id, edge_id, color, //
508 elem_type, elem_id, edge_id, color, color));
509 };
510
511 // The "to_node" value may be a NULL, indicating that this points to the
512 // "root" tag rather than a normal node.
513 int64 from_node_id =
514 tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1);
515 if (from_node_id == -1) {
516 LOG(FATAL) << from_node->name() << " was added to edges but not to nodes";
517 }
518 int64 to_node_id =
519 to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1)
520 : root_node_id_;
521 if (to_node != nullptr && to_node_id == -1) {
522 LOG(FATAL) << to_node->name() << " was added to edges but not to nodes";
523 }
524
525 add_hover_css_rule("node", from_node_id, kBlue);
526 add_hover_css_rule("node", to_node_id, kRed);
527
528 if (to_node) {
529 VLOG(3) << "Adding css for edge " << edge_id << " from node "
530 << from_node->name() << " to node " << to_node->name();
531 } else {
532 VLOG(3) << "Adding css for edge " << edge_id << " from node "
533 << from_node->name() << " to root tag";
534 }
535
536 // If this edge crosses a fusion cluster boundary, highlight it when the
537 // cluster is hovered over.
538 if (to_node) {
539 if (from_node->IsFused() &&
540 from_node->parent()->root_instruction() == from_node) {
541 int64 cluster_id = cluster_ids_.at(from_node->parent());
542 add_hover_css_rule("clust", cluster_id, kBlue);
543 }
544 if (to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter) {
545 int64 cluster_id = cluster_ids_.at(to_node->parent());
546 add_hover_css_rule("clust", cluster_id, kRed);
547 }
548 }
549 }
550
551 // Browsers require that we URI-encode the contents of our data URI. (It
552 // seems this was a relatively recent change?) In practice, this means that we
553 // need to escape '#'.
554 return StrFormat(
555 fmt, graph_label,
556 absl::StrReplaceAll(StrJoin(edge_css_rules, "\n"), {{"#", "%23"}}));
557 }
558
Footer()559 string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); }
560
ShouldShowFusionSubcomputation(const HloInstruction * instr)561 bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
562 CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
563 return ShouldShowSubcomputation(instr->fused_instructions_computation());
564 }
565
ShouldShowSubcomputation(const HloComputation * subcomp)566 bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
567 if (subcomp->IsFusionComputation()) {
568 const HloInstruction* fusion = subcomp->FusionInstruction();
569 if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion) ||
570 !hlo_render_options_.show_fusion_subcomputations) {
571 return false;
572 }
573 }
574
575 // Don't show trivial subcomputations on non-fusion nodes -- these are inlined
576 // into the graph.
577 if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) {
578 return false;
579 }
580
581 // Show the subcomputation if we're showing any of its members.
582 return absl::c_any_of(
583 subcomp->instructions(),
584 [&](const HloInstruction* instr) { return filter_.Show(instr); });
585 }
586
DumpSubcomputation(const HloComputation * subcomp,const HloInstruction * parent_instr)587 string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
588 const HloInstruction* parent_instr) {
589 VLOG(2) << "Dumping subcomputation " << subcomp->name();
590 // Add an edge from the subcomputation to its parent node. If subcomp
591 // belongs to a fusion node, it's drawn in place of the fusion instruction,
592 // so there's no need to link those.
593 if (parent_instr->opcode() != HloOpcode::kFusion) {
594 const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
595 VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
596 << " as " << next_edge_id_;
597 edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
598 constexpr char edge_fmt[] =
599 R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
600 edges_.push_back(StrFormat(
601 edge_fmt, InstructionId(from), InstructionId(parent_instr),
602 SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
603 }
604
605 // Have we already dumped this subcomputation? If so, generating the edge
606 // linking it and parent_instr is all we want to do in this function.
607 if (cluster_ids_.find(subcomp) != cluster_ids_.end()) {
608 return "";
609 }
610
611 cluster_ids_[subcomp] = next_cluster_id_++;
612
613 string id = SubcomputationId(subcomp);
614
615 string subcomp_label, style;
616 if (parent_instr->opcode() == HloOpcode::kFusion) {
617 subcomp_label =
618 StrFormat("Fused expression for <b>%s</b><br/>%s",
619 HtmlLikeStringSanitize(parent_instr->name()),
620 HtmlLikeStringSanitize(parent_instr->ToCategory()));
621 string extra_info = GetInstructionNodeExtraInfo(parent_instr);
622 if (!extra_info.empty()) {
623 StrAppend(&subcomp_label, "<br/>", extra_info);
624 }
625 string node_backend_config = GetInstructionNodeBackendConfig(parent_instr);
626 if (!node_backend_config.empty()) {
627 StrAppend(&subcomp_label, "<br/>", node_backend_config);
628 }
629
630 bool highlight = filter_.Highlight(parent_instr);
631 const char* fillcolor;
632 const char* strokecolor;
633 if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) {
634 // Use the sharding color, if the node isn't highlighted.
635 NodeColors node_colors =
636 NodeColorsForScheme(GetInstructionColor(parent_instr));
637 fillcolor = node_colors.fill_color;
638 strokecolor = node_colors.stroke_color;
639 } else {
640 // Subcomputation's fill/stroke color is light/dark red/gray, depending on
641 // whether or not the subcomputation's fusion node is highlighted.
642 fillcolor = highlight ? "#ffcdd2" : "#f5f5f5";
643 strokecolor = highlight ? "#b71c1c" : "#c2c2c2";
644 }
645 style =
646 StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")",
647 fillcolor, strokecolor);
648 } else {
649 subcomp_label = StrFormat("Subcomputation for <b>%s</b><br/>%s",
650 HtmlLikeStringSanitize(parent_instr->name()),
651 HtmlLikeStringSanitize(subcomp->name()));
652 style = "style=rounded; color=black;";
653 }
654
655 string comp_body = DumpComputation(subcomp);
656
657 constexpr char computation_fmt[] = R"(subgraph %s {
658 %s
659 label = <%s>;
660 labelloc = t;
661 tooltip = " ";
662 %s
663 } // %s
664
665 )";
666 return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id);
667 }
668
DumpComputation(const HloComputation * comp)669 string HloDotDumper::DumpComputation(const HloComputation* comp) {
670 string g;
671 for (const auto* instr : comp->instructions()) {
672 if (!filter_.Show(instr)) {
673 continue;
674 }
675
676 // Dump subcomputations within instr.
677 for (const HloComputation* subcomp : instr->called_computations()) {
678 if (ShouldShowSubcomputation(subcomp)) {
679 StrAppend(&g, DumpSubcomputation(subcomp, instr));
680 }
681 }
682
683 StrAppend(&g, DumpInstruction(instr));
684 }
685 return g;
686 }
687
DumpRootTag()688 string HloDotDumper::DumpRootTag() {
689 const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());
690
691 // We didn't display constants or broadcasts of effective scalars within
692 // fusions as separate nodes; so if the root is a constant/broadcast of
693 // scalar, we don't add root tag or edge for it.
694 if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
695 IsFusedBroadcastOfConstantEffectiveScalar(from)) {
696 return "";
697 }
698
699 auto from_id = InstructionId(from);
700
701 // The ID of the root computation is otherwise unused, so it makes a good ID
702 // to use for the root-tag node. However, the edge_ids_ map requires a
703 // HloInstruction* pointer for the 'to' value, so we use a NULL value there
704 // (rather than a pointer type-cast) to make it obvious if it is erroneously
705 // dereferenced.
706 HloInstruction* to = nullptr;
707 auto to_id = SubcomputationId(computation_);
708
709 string node_body = "ROOT";
710 string node_shape = "circle";
711 ColorScheme color = kBrown;
712
713 VLOG(2) << "Adding root tag as node " << next_node_id_;
714 root_node_id_ = next_node_id_++;
715
716 VLOG(2) << "Adding edge from " << from->name() << " to root tag as "
717 << next_edge_id_;
718 edge_ids_.insert({{from, to}, next_edge_id_++});
719 edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id));
720
721 return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
722 "\n",
723 to_id, node_body, node_shape, NodeColorAttributes(color));
724 }
725
TryGetFusionParameterConstant(const HloInstruction * instr)726 static const HloConstantInstruction* TryGetFusionParameterConstant(
727 const HloInstruction* instr) {
728 if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
729 return nullptr;
730 }
731 const HloInstruction* fusion = instr->parent()->FusionInstruction();
732 const HloInstruction* operand = fusion->operand(instr->parameter_number());
733 return DynCast<HloConstantInstruction>(operand);
734 }
735
ShouldMergeIntoUsers(const HloInstruction * instr) const736 bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
737 // If a node:
738 //
739 // - is a parameter of a fusion node which is bound to a constant,
740 //
741 // or
742 //
743 // - is a tuple-shaped parameter, and
744 // - is not a parameter to a fusion node, and
745 // - has at least kMinUsersToOmit users shown, and
746 // - all of the shown users are get-tuple-elements,
747 //
748 // then we omit it from the graph, merging it with its users.
749 //
750 // This helps us handle the common case where a while loop body has one big
751 // tuple-shaped parameter.
752 if (TryGetFusionParameterConstant(instr) != nullptr) {
753 return true;
754 }
755 const int kMinUsersToOmit = 3;
756 return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() &&
757 !instr->IsFused() &&
758 absl::c_count_if(instr->users(),
759 [&](const HloInstruction* user) {
760 return filter_.Show(user);
761 }) > kMinUsersToOmit &&
762 absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
763 return !filter_.Show(user) ||
764 user->opcode() == HloOpcode::kGetTupleElement;
765 });
766 }
767
DumpInstruction(const HloInstruction * instr)768 string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
769 // We don't display constants or broadcasts of effective scalar constants
770 // within fusions as separate nodes; they're merged into their users.
771 if (instr->opcode() == HloOpcode::kConstant ||
772 IsFusedBroadcastOfConstantEffectiveScalar(instr)) {
773 return "";
774 }
775 // Skip this node if it's merged into its users.
776 if (ShouldMergeIntoUsers(instr)) {
777 return "";
778 }
779 // Omit the fusion node if its subcomputation is drawn, since the
780 // subcomputation will be drawn inline.
781 if (instr->opcode() == HloOpcode::kFusion &&
782 ShouldShowFusionSubcomputation(instr)) {
783 return "";
784 }
785
786 VLOG(2) << "Adding node " << instr->name() << " as " << next_node_id_;
787 node_ids_[instr] = next_node_id_++;
788
789 ColorScheme color = GetInstructionColor(instr);
790 string node_shape = GetInstructionNodeShape(instr);
791 string node_label = GetInstructionNodeLabel(instr);
792 string node_metadata = GetInstructionNodeMetadata(instr);
793 string node_backend_config = GetInstructionNodeBackendConfig(instr);
794 string extra_info = GetInstructionNodeExtraInfo(instr);
795 string inlined_constants = GetInstructionNodeInlinedOperands(instr);
796 string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
797 AddInstructionIncomingEdges(instr);
798
799 if (!debug_options_.xla_hlo_graph_sharding_color()) {
800 // Override the node's styling if it should be (de-)emphasized.
801 if (filter_.Deemphasized(instr)) {
802 color = kDashedBorder;
803 }
804 if (filter_.Highlight(instr)) {
805 node_shape = "diamond";
806 color = kDarkRed;
807 }
808 }
809 // Build the text that will be displayed inside the node.
810 string node_body = node_label;
811 for (const string& s : {trivial_subcomputation, node_backend_config,
812 extra_info, inlined_constants}) {
813 if (!s.empty()) {
814 StrAppend(&node_body, "<br/>", s);
815 }
816 }
817
818 return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)"
819 "\n",
820 InstructionId(instr), node_body, node_shape, node_metadata,
821 NodeColorAttributes(color));
822 }
823
GetInstructionNodeInlinedOperands(const HloInstruction * instr)824 string HloDotDumper::GetInstructionNodeInlinedOperands(
825 const HloInstruction* instr) {
826 // The constant's shape is a parameter because, in the case of a broadcasted
827 // scalar constant, we want to show the broadcasted shape, not the constant's
828 // scalar shape.
829 auto stringify_constant = [](const HloConstantInstruction* constant,
830 const Shape& shape) {
831 // If the shape has a dimension of size zero, print it as e.g.
832 // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(),
833 // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which
834 // is just noise.
835 if (ShapeUtil::IsZeroElementArray(shape)) {
836 return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape()));
837 }
838
839 // Print the literal value of constants with <= K elements. Note that we
840 // use `constant->shape()` rather than `shape`, because if `constant` is a
841 // scalar that's broadcasted into `shape`, we want to print the constant.
842 optional<int64> elem_count;
843 if (shape.IsArray()) {
844 elem_count = ShapeUtil::ElementsIn(constant->shape());
845 }
846 // Allow HloDotDumper to print HloInstruction reconstructed from HloProto
847 // collected from profiling tools. Those constants may not have a valid
848 // literal.
849 if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
850 return StrFormat("%s %s", shape.ToString(),
851 constant->literal().ToStringWithoutShape());
852 }
853
854 // Otherwise, print e.g. "%constant.42 (s32[100])".
855 string constant_name;
856 if (absl::StartsWith(constant->name(), "constant")) {
857 constant_name = constant->name();
858 } else {
859 constant_name = StrCat("constant ", constant->name());
860 }
861 return StrFormat("%s %s", constant_name, ShapeUtil::HumanString(shape));
862 };
863
864 std::vector<string> lines;
865 for (int64 i = 0; i < instr->operand_count(); ++i) {
866 const HloInstruction* operand = instr->operand(i);
867 optional<string> operand_str;
868 if (const auto* constant_operand =
869 DynCast<HloConstantInstruction>(operand)) {
870 operand_str =
871 stringify_constant(constant_operand, constant_operand->shape());
872 } else if (IsFusedBroadcastOfConstantEffectiveScalar(operand)) {
873 operand_str = stringify_constant(
874 Cast<HloConstantInstruction>(operand->operand(0)), operand->shape());
875 } else if (ShouldMergeIntoUsers(operand)) {
876 // Special case: If the operand is a parameter to a fusion node and it
877 // always has a constant value, display it like a regular constant.
878 //
879 // For other parameters, use the parameter number rather than the proper
880 // name, because that's generally how people think of the node.
881 if (operand->opcode() == HloOpcode::kParameter) {
882 if (const HloConstantInstruction* constant =
883 TryGetFusionParameterConstant(operand)) {
884 operand_str = stringify_constant(constant, constant->shape());
885 } else {
886 operand_str = StrFormat("Parameter %d", operand->parameter_number());
887 }
888 } else {
889 operand_str = operand->name();
890 }
891 }
892
893 if (operand_str) {
894 if (instr->operand_count() > 1) {
895 lines.push_back(StrFormat("<b>operand %d</b> = %s", i, *operand_str));
896 } else {
897 lines.push_back(StrFormat("<b>operand</b> = %s", *operand_str));
898 }
899 }
900 }
901 return StrJoin(lines, "<br/>");
902 }
903
GetInstructionColor(const HloInstruction * instr)904 ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
905 if (debug_options_.xla_hlo_graph_sharding_color()) {
906 if (!instr->has_sharding()) {
907 return kDashedBorder;
908 }
909 auto it = sharding_colors_.find(instr->sharding());
910 if (it != sharding_colors_.end()) {
911 return it->second;
912 }
913 ColorScheme color = static_cast<ColorScheme>(
914 kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
915 sharding_colors_.emplace(instr->sharding(), color);
916 return color;
917 }
918
919 // Choose different weights of orange for small vs large parameters. This
920 // distinction is often important, especially in fusion nodes.
921 auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange;
922
923 // Special case: If this instruction has a parameter merged into it, paint it
924 // the same color as a parameter. Unless the merged-in parameter is a
925 // parameter to a fusion node that is bound to a constant -- these aren't
926 // "real" parameters from the user's perspective.
927 if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) {
928 return operand->opcode() == HloOpcode::kParameter &&
929 ShouldMergeIntoUsers(operand) &&
930 TryGetFusionParameterConstant(operand) == nullptr;
931 })) {
932 return parameter_color;
933 }
934
935 // Pick different colors or shapes for instructions which are particularly
936 // expensive (eg, dot) and those which are unusual in some way or unique
937 // (eg, parameter).
938 switch (instr->opcode()) {
939 case HloOpcode::kAbs:
940 case HloOpcode::kAdd:
941 case HloOpcode::kAnd:
942 case HloOpcode::kAtan2:
943 case HloOpcode::kBitcastConvert:
944 case HloOpcode::kCeil:
945 case HloOpcode::kClamp:
946 case HloOpcode::kClz:
947 case HloOpcode::kCompare:
948 case HloOpcode::kComplex:
949 case HloOpcode::kConvert:
950 case HloOpcode::kCos:
951 case HloOpcode::kDivide:
952 case HloOpcode::kExp:
953 case HloOpcode::kExpm1:
954 case HloOpcode::kFloor:
955 case HloOpcode::kImag:
956 case HloOpcode::kIota:
957 case HloOpcode::kIsFinite:
958 case HloOpcode::kLog:
959 case HloOpcode::kLog1p:
960 case HloOpcode::kMaximum:
961 case HloOpcode::kMinimum:
962 case HloOpcode::kMultiply:
963 case HloOpcode::kNegate:
964 case HloOpcode::kNot:
965 case HloOpcode::kPopulationCount:
966 case HloOpcode::kOr:
967 case HloOpcode::kXor:
968 case HloOpcode::kPower:
969 case HloOpcode::kReal:
970 case HloOpcode::kRemainder:
971 case HloOpcode::kRng:
972 case HloOpcode::kRngGetAndUpdateState:
973 case HloOpcode::kRngBitGenerator:
974 case HloOpcode::kRoundNearestAfz:
975 case HloOpcode::kRsqrt:
976 case HloOpcode::kSelect:
977 case HloOpcode::kShiftLeft:
978 case HloOpcode::kShiftRightArithmetic:
979 case HloOpcode::kShiftRightLogical:
980 case HloOpcode::kLogistic:
981 case HloOpcode::kSign:
982 case HloOpcode::kSin:
983 case HloOpcode::kSlice:
984 case HloOpcode::kSort:
985 case HloOpcode::kSqrt:
986 case HloOpcode::kCbrt:
987 case HloOpcode::kSubtract:
988 case HloOpcode::kTanh:
989 // De-emphasize scalar-shaped elementwise ops -- they're generally
990 // uninteresting.
991 if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
992 return kWhite;
993 }
994 return kYellow;
995 case HloOpcode::kBitcast:
996 case HloOpcode::kGetTupleElement:
997 case HloOpcode::kTrace:
998 case HloOpcode::kAfterAll:
999 case HloOpcode::kAddDependency:
1000 case HloOpcode::kTuple:
1001 return kWhite;
1002 case HloOpcode::kBroadcast:
1003 // De-emphasize nodes which broadcast a scalar within a fusion node --
1004 // these are essentially free.
1005 if (instr->IsFused() &&
1006 ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) {
1007 return kWhite;
1008 }
1009 return kGreen;
1010 case HloOpcode::kConcatenate:
1011 case HloOpcode::kDynamicSlice:
1012 case HloOpcode::kGather:
1013 case HloOpcode::kPad:
1014 case HloOpcode::kReshape:
1015 case HloOpcode::kDynamicReshape:
1016 case HloOpcode::kReverse:
1017 case HloOpcode::kTupleSelect:
1018 case HloOpcode::kTranspose:
1019 // De-emphasize scalar-shaped data movement ops and all data movement ops
1020 // inside fusion nodes, both of which are essentially free.
1021 if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) {
1022 return kWhite;
1023 }
1024 return kGreen;
1025 case HloOpcode::kDynamicUpdateSlice:
1026 // Unlike the data-movement ops above, dynamic-update-slice is not ~free
1027 // inside of fusion nodes, so we de-emphasize it only if it's
1028 // scalar-shaped.
1029 if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
1030 return kWhite;
1031 }
1032 return kGreen;
1033 case HloOpcode::kScatter:
1034 // Do not de-emphasize Scatter, since it involves significant work.
1035 case HloOpcode::kCopy:
1036 case HloOpcode::kCopyStart:
1037 case HloOpcode::kCopyDone:
1038 // Emphasize copy nodes, which are either physical transposes (and thus
1039 // significant), or copies of read-only buffers (and thus dead weight).
1040 return kGreen;
1041 case HloOpcode::kConvolution:
1042 case HloOpcode::kDot:
1043 case HloOpcode::kFft:
1044 case HloOpcode::kTriangularSolve:
1045 case HloOpcode::kCholesky:
1046 return kDarkBlue;
1047 case HloOpcode::kReducePrecision:
1048 return kRed;
1049 case HloOpcode::kParameter:
1050 return parameter_color;
1051 case HloOpcode::kBatchNormGrad:
1052 case HloOpcode::kBatchNormInference:
1053 case HloOpcode::kBatchNormTraining:
1054 case HloOpcode::kReduce:
1055 case HloOpcode::kReduceWindow:
1056 case HloOpcode::kSelectAndScatter:
1057 return kPurple;
1058 case HloOpcode::kDomain:
1059 case HloOpcode::kFusion:
1060 case HloOpcode::kMap:
1061 case HloOpcode::kGetDimensionSize:
1062 case HloOpcode::kSetDimensionSize:
1063 return kGray;
1064 case HloOpcode::kAllGather:
1065 case HloOpcode::kAllReduce:
1066 case HloOpcode::kAllToAll:
1067 case HloOpcode::kCollectivePermute:
1068 case HloOpcode::kCollectivePermuteStart:
1069 case HloOpcode::kCollectivePermuteDone:
1070 case HloOpcode::kInfeed:
1071 case HloOpcode::kOutfeed:
1072 case HloOpcode::kPartitionId:
1073 case HloOpcode::kRecv:
1074 case HloOpcode::kRecvDone:
1075 case HloOpcode::kSend:
1076 case HloOpcode::kSendDone:
1077 case HloOpcode::kReplicaId:
1078 return kBrown;
1079 case HloOpcode::kCall:
1080 case HloOpcode::kConditional:
1081 case HloOpcode::kCustomCall:
1082 case HloOpcode::kWhile:
1083 return kDarkGreen;
1084 case HloOpcode::kConstant:
1085 LOG(FATAL) << "Constants don't get their own nodes in the graph.";
1086 }
1087 }
1088
GetInstructionNodeShape(const HloInstruction * instr)1089 string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) {
1090 // Give while loops a different shape so they're easier to pick out.
1091 switch (instr->opcode()) {
1092 case HloOpcode::kWhile:
1093 return "ellipse";
1094 default:
1095 return "rect";
1096 }
1097 }
1098
GetInstructionNodeLabel(const HloInstruction * instr)1099 string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
1100 // If we have a parameter, put the param number in the name.
1101 if (instr->opcode() == HloOpcode::kParameter) {
1102 return StrFormat("<b>Parameter %d</b>", instr->parameter_number());
1103 }
1104
1105 // The HLO instruction name contains usually the opcode, e.g. "%add.42" is
1106 // an add instruction. In this case we render just the name.
1107 if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) {
1108 return StrFormat("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
1109 }
1110 string extended_opcode =
1111 StrCat(HloOpcodeString(instr->opcode()),
1112 instr->opcode() != HloOpcode::kFusion
1113 ? ""
1114 : StrCat(":", xla::ToString(instr->fusion_kind())));
1115 // If the name does not contain the opcode, render both.
1116 return StrFormat("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
1117 HtmlLikeStringSanitize(instr->name()));
1118 }
1119
GetInstructionNodeMetadata(const HloInstruction * instr)1120 string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
1121 std::vector<string> lines;
1122 if (!instr->metadata().op_name().empty()) {
1123 lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name()));
1124 }
1125 if (!instr->metadata().op_type().empty()) {
1126 lines.push_back(StrFormat(
1127 "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type())));
1128 }
1129 if (!instr->metadata().source_file().empty() &&
1130 instr->metadata().source_line() != 0) {
1131 lines.push_back(StrFormat("source: %s:%d", instr->metadata().source_file(),
1132 instr->metadata().source_line()));
1133 }
1134
1135 return StrJoin(lines, "\n");
1136 }
1137
GetInstructionNodeBackendConfig(const HloInstruction * instr)1138 string HloDotDumper::GetInstructionNodeBackendConfig(
1139 const HloInstruction* instr) {
1140 if (!hlo_render_options_.show_backend_config ||
1141 instr->raw_backend_config_string().empty()) {
1142 return "";
1143 }
1144
1145 return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\"");
1146 }
1147
GetInstructionNodeExtraInfo(const HloInstruction * instr)1148 string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
1149 std::vector<string> lines;
1150
1151 // Get the instruction's extra attributes excluding the names of its
1152 // subcomputations, since those are drawn explicitly in the graph.
1153 for (const auto& line : instr->ExtraAttributesToString(
1154 HloPrintOptions().set_print_subcomputation_mode(
1155 HloPrintOptions::PrintSubcomputationMode::kOff))) {
1156 // Some instructions have giant replica group fields, so truncate the
1157 // replica group line length to 128.
1158 constexpr int kMaxReplicaGroupLen = 128;
1159 if (absl::StartsWith(line, "replica_groups=") &&
1160 line.length() > kMaxReplicaGroupLen) {
1161 lines.push_back(HtmlLikeStringSanitize(
1162 StrCat(line.substr(0, kMaxReplicaGroupLen - 3), "...")));
1163 } else {
1164 lines.push_back(HtmlLikeStringSanitize(line));
1165 }
1166 }
1167
1168 // Show the shape and layout of the instruction, unless it's an inlined fusion
1169 // node -- there the shape and layout is present in the output node.
1170 if (instr->opcode() != HloOpcode::kFusion ||
1171 !ShouldShowFusionSubcomputation(instr)) {
1172 // Show layout of instructions with more than one dimension. Don't show
1173 // layout on tuples or tensors with just one dimension (which only have one
1174 // possible layout) to avoid visual noise.
1175 bool shape_is_multidim = false;
1176 ShapeUtil::ForEachSubshape(instr->shape(),
1177 [&](const Shape& s, const ShapeIndex&) {
1178 shape_is_multidim |= s.dimensions_size() > 1;
1179 });
1180 string instr_shape;
1181 if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) {
1182 instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape());
1183 } else {
1184 instr_shape = ShapeUtil::HumanString(instr->shape());
1185 }
1186
1187 // Some instructions have giant tuples as their shapes, so truncate the
1188 // HLO's shape to kMaxShapeLen characters.
1189 constexpr int kMaxShapeLen = 64;
1190 if (instr_shape.length() > kMaxShapeLen) {
1191 instr_shape = StrCat(
1192 absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "...");
1193 }
1194 lines.push_back(HtmlLikeStringSanitize(instr_shape));
1195 }
1196 if (debug_options_.xla_hlo_graph_addresses()) {
1197 lines.push_back(StrFormat("[%p]", instr));
1198 }
1199 if (profile_ != nullptr) {
1200 double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr);
1201 double total_cycles_executed =
1202 profile_->total_cycles_executed(*instr->parent());
1203 if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
1204 lines.push_back(
1205 StrFormat("%% of cycles executed=%.2f",
1206 100 * hlo_cycles_executed / total_cycles_executed));
1207 }
1208 }
1209 return StrJoin(lines, "<br/>");
1210 }
1211
AddInstructionIncomingEdges(const HloInstruction * instr)1212 void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
1213 auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
1214 int64 operand_num, bool control_edge = false) {
1215 from = GetNodeForEdge(from);
1216
1217 if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
1218 IsFusedBroadcastOfConstantEffectiveScalar(from) ||
1219 ShouldMergeIntoUsers(from)) {
1220 return;
1221 }
1222 VLOG(2) << "Adding edge from " << from->name() << " to " << to->name()
1223 << " as " << next_edge_id_;
1224 edge_ids_.insert({{from, to}, next_edge_id_++});
1225
1226 string edge_label;
1227 if (instr->operand_count() > 1 && !control_edge) {
1228 edge_label =
1229 StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num);
1230 } else if (control_edge) {
1231 edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"";
1232 }
1233
1234 // We print "small" arrays using a hollow arrowhead and "large" arrays using
1235 // a filled arrowhead.
1236 constexpr char kEdgeFmt[] =
1237 R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
1238 edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to),
1239 (IsSmall(from) ? "empty" : "normal"),
1240 from->name(), to->name(), edge_label));
1241 };
1242
1243 // Add edges from instr's operands to instr. Parameters within fusion
1244 // expressions are handled specially -- we draw an edge from the corresponding
1245 // operand on the fusion node itself to the parameter.
1246 if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
1247 // Only add the edge if this is not the outermost computation; otherwise it
1248 // will lead from a node we're not drawing.
1249 if (instr->parent() != computation_) {
1250 const HloInstruction* fusion = instr->parent()->FusionInstruction();
1251 add_edge(fusion->operand(instr->parameter_number()), instr,
1252 /*operand_num=*/0);
1253 }
1254 } else {
1255 for (int64 i = 0; i < instr->operand_count(); ++i) {
1256 add_edge(instr->operand(i), instr, i);
1257 }
1258 for (const HloInstruction* pred : instr->control_predecessors()) {
1259 add_edge(pred, instr, /*operand_num=*/0, /*control_edge=*/true);
1260 }
1261 }
1262 }
1263
GetInstructionTrivialComputationStr(const HloInstruction * instr)1264 string HloDotDumper::GetInstructionTrivialComputationStr(
1265 const HloInstruction* instr) {
1266 // called_computations() on a fusion node "inherits" any called computations
1267 // of the fused root, which isn't what we want. Just ignore fusion nodes
1268 // here; they're handled separately.
1269 if (instr->opcode() == HloOpcode::kFusion) {
1270 return "";
1271 }
1272
1273 std::vector<string> lines;
1274 for (int64 i = 0; i < instr->called_computations().size(); ++i) {
1275 optional<string> computation_type =
1276 MatchTrivialComputation(instr->called_computations()[i]);
1277 if (!computation_type) {
1278 continue;
1279 }
1280 if (instr->called_computations().size() == 1) {
1281 lines.push_back(StrFormat("Subcomputation: <b>%s</b>",
1282 HtmlLikeStringSanitize(*computation_type)));
1283 } else {
1284 lines.push_back(StrFormat("Subcomputation %d: <b>%s</b>", i,
1285 HtmlLikeStringSanitize(*computation_type)));
1286 }
1287 }
1288 return StrJoin(lines, "<br/>");
1289 }
1290
// Returns the instruction that edges touching `instr` should attach to. A
// fusion instruction whose subcomputation is shown is replaced by the root of
// its fused expression, repeatedly, until a non-inlined instruction is found.
const HloInstruction* HloDotDumper::GetNodeForEdge(
    const HloInstruction* instr) {
  for (;;) {
    const bool is_inlined_fusion = instr->opcode() == HloOpcode::kFusion &&
                                   ShouldShowFusionSubcomputation(instr);
    if (!is_inlined_fusion) {
      return instr;
    }
    instr = instr->fused_expression_root();
  }
}
1299
// Gets a NodeFilter that includes roughly all instructions whose distance from
// root is <= radius.
//
// Args:
//   root: the instruction the neighborhood is centered on; it is highlighted
//     in the resulting filter.
//   radius: maximum traversal depth from `root`. A node reached at exactly
//     this depth is included but not expanded further.
//   boundary: instructions that are included when reached but whose operands,
//     called computations and users are not traversed.
//
// The returned filter captures its state by copy. Instructions that were not
// reached are hidden if they live in the same computation as `root` and shown
// as normal nodes otherwise.
NodeFilter MakeNodeRadiusAroundFilter(
    const HloInstruction* root, int64 radius,
    const absl::flat_hash_set<const HloInstruction*>& boundary) {
  // First, find the neighborhood of nodes with distance from root <= radius.
  // These nodes are our initial set of "normal" nodes.
  absl::flat_hash_map<const HloInstruction*, NodeFilterResult> nodes;
  std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
  worklist.push_back({root, 0});
  // The worklist is consumed front-to-back and extended at the back, so nodes
  // are visited in order of non-decreasing depth.
  while (!worklist.empty()) {
    const HloInstruction* instr;
    int64 depth;
    std::tie(instr, depth) = worklist.front();
    worklist.pop_front();

    // Record the node before the depth/boundary checks so that nodes at the
    // edge of the neighborhood are still part of the result.
    nodes[instr] = kNormalNode;
    if (depth == radius) {
      continue;
    }
    if (boundary.contains(instr)) {
      continue;
    }

    // Traverse into instr's operands.
    //
    // Don't traverse into tuples' operands unless the tuple is the root.
    // Usually a tuple is the bottommost node in the graph, and so its operands
    // are not interesting to the graph at hand.
    if (instr == root || instr->opcode() != HloOpcode::kTuple) {
      for (const HloInstruction* operand : instr->operands()) {
        // Only nodes that were already popped are in `nodes`, so an operand
        // may still be queued more than once before it is first visited.
        if (!nodes.contains(operand)) {
          worklist.push_back({operand, depth + 1});
        }
      }
    }

    // Traverse into instr's nested computations.
    for (const HloComputation* computation : instr->called_computations()) {
      worklist.push_back({computation->root_instruction(), depth + 1});
    }

    // Traverse into instr's users, unless:
    //
    //  - there are a ton of them, in which case they're probably not
    //    interesting (and anyway, rendering them all would make the graph
    //    unreadable), or
    //  - instr is a constant, in which case its users are probably not
    //    interesting.
    if (instr->opcode() == HloOpcode::kConstant) {
      continue;
    }
    constexpr int kMaxUsersToRender = 16;
    if (instr->user_count() > kMaxUsersToRender) {
      // If we're going to skip this node's users, style it as such.
      nodes[instr] = kSomeUsersOmitted;
      continue;
    }
    for (const HloInstruction* user : instr->users()) {
      if (!nodes.contains(user)) {
        worklist.push_back({user, depth + 1});
      }
    }
  }

  // Predicate used by the fix-up pass below to decide whether an operand or
  // user of a node will appear in the rendered graph.
  auto is_displayed = [&](const HloInstruction* instr) {
    // Constants are displayed inline with their users; they're never omitted.
    // Nodes in subcomputations are always shown.
    return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant ||
           instr->parent() != root->parent();
  };

  // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we
  // know which nodes will be included in the graph.
  for (auto& kv : nodes) {
    const HloInstruction* instr = kv.first;
    NodeFilterResult& filter_result = kv.second;
    const auto& operands = instr->operands();

    if (absl::c_any_of(operands, is_displayed) &&
        !absl::c_all_of(operands, is_displayed)) {
      // Mark nodes with some operands omitted appropriately.
      filter_result = kSomeOperandsOmitted;
    } else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) {
      // Mark nodes with *all* operands omitted appropriately.
      filter_result = kOmitNodeOperands;
    }

    // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
    // users made it into the graph.
    if (filter_result == kSomeUsersOmitted &&
        absl::c_all_of(instr->users(), is_displayed)) {
      filter_result = kNormalNode;
    }
  }

  // Highlight the root node.
  nodes[root] = kHighlightNode;

  // The lambda captures `nodes` and `root` by copy, so the filter remains
  // usable after this function returns.
  return NodeFilter([=](const HloInstruction* instr) {
    auto it = nodes.find(instr);
    if (it != nodes.end()) {
      return it->second;
    }
    // Show all nodes in subcomputations.
    if (instr->parent() != root->parent()) {
      return kNormalNode;
    }
    return kHideNode;
  });
}
1411
1412 // Gets a node filter that includes nodes on all paths from `from` to `to`. If
1413 // the all-paths set contains more than max_nodes elements, includes the nodes
1414 // on the shortest paths and sets hit_limit to true.
MakeNodeFromToFilter(const HloInstruction * from,const HloInstruction * to,int64 max_nodes,bool * hit_limit)1415 NodeFilter MakeNodeFromToFilter(const HloInstruction* from,
1416 const HloInstruction* to, int64 max_nodes,
1417 bool* hit_limit) {
1418 *hit_limit = false;
1419
1420 // Elements in the queue are paths through the graph.
1421 std::deque<std::vector<const HloInstruction*>> queue;
1422 queue.push_front({from});
1423
1424 // Compute the set of nodes we want to show using a slightly-modified
1425 // Djikstra's algorithm. The only real difference is, rather than stopping
1426 // when we find a (shortest) path, we continue until we've found max_nodes
1427 // nodes on some path.
1428 std::unordered_set<const HloInstruction*> visited;
1429 std::unordered_set<const HloInstruction*> to_display = {from, to};
1430 while (!queue.empty() && to_display.size() < max_nodes) {
1431 std::vector<const HloInstruction*> path = std::move(queue.front());
1432 queue.pop_front();
1433 if (!visited.insert(path.back()).second) {
1434 continue;
1435 }
1436
1437 for (const auto* user : path.back()->users()) {
1438 if (user == to) {
1439 auto it = path.begin();
1440 for (; it != path.end() && to_display.size() < max_nodes; ++it) {
1441 to_display.insert(*it);
1442 }
1443 if (it != path.end()) {
1444 *hit_limit = true;
1445 }
1446 } else if (!visited.count(user)) {
1447 auto new_path = path;
1448 new_path.push_back(user);
1449 queue.push_back(std::move(new_path));
1450 }
1451 }
1452 }
1453
1454 return NodeFilter([=](const HloInstruction* instr) {
1455 if (instr == from || instr == to) {
1456 return kHighlightNode;
1457 }
1458 return to_display.count(instr) ? kNormalNode : kHideNode;
1459 });
1460 }
1461
// Wraps `dot` in a self-contained HTML page and returns the page text.
//
// The result is the concatenation html_prefix + dot + html_suffix, which
// places the dot text verbatim inside a JavaScript template literal
// (`var data = ...`). The page's script loads viz.js and svg-pan-zoom from a
// CDN, renders the graph to SVG in the browser, and adds links for saving the
// HTML, SVG and DOT.
//
// NOTE(review): `dot` is inserted without any escaping, so this presumably
// relies on the dot text not containing characters that are special inside a
// JavaScript template literal -- confirm against callers.
string WrapDotInHtml(absl::string_view dot) {
  // Everything up to and including the opening backtick of the template
  // literal that holds the dot text.
  static const char html_prefix[] = R"html(
<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <style type="text/css">
    body {
      height: 100vh;
      margin: 0;
    }
  </style>
</head>
<body>
  <!-- Integrity hash is generated by https://www.srihash.org/ -->
  <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/viz.js"
     integrity="sha384-aD1MJYb0WKIUT+CtwJp5LTuV3U4pLAS6B/nUxL7ECimC2pN9N8vjlMr/yQCAkzxE"
     crossorigin="anonymous"></script>
  <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/full.render.js"
     integrity="sha384-bAixY275aIpCj6Te19y0MILZ4V+VEC8CVFujFEH+Lf7W+4XYYeYLwW5IBI6yQmMT"
     crossorigin="anonymous"></script>
  <script src="https://cdn.jsdelivr.net/npm/svg-pan-zoom@3.6.0/dist/svg-pan-zoom.min.js"
     integrity="sha384-3008WpYB2pOBvE7lwkrKf+qTmbTPGGPYxA9C1YVhvbPukns4ZFj7E98QPLkNW9dS"
     crossorigin="anonymous"></script>
  <div id="container" style="height:95vh; border:1px solid black; "></div>
  <script>
    var data = `
)html";

  // Closes the template literal and carries the script that extracts the
  // stylesheet from the dot text, renders the SVG, and adds the page controls.
  static const char html_suffix[] = R"html(
`;
    var cssregex = new RegExp('stylesheet=<([^]*)\n>\n', 'gm');
    var results = cssregex.exec(data)
    // graphviz has problem dealing with large stylesheets.
    // https://github.com/tensorflow/tensorflow/issues/17220#issuecomment-369228492
    // In order to avoid the problem, remove the stylesheet from the dot and
    // insert it directly info the rendered SVG.
    var dot_data = data;
    var css_data = ''
    if (results !== null) {
        css_data = results[1].replace(/\s*data:.*\s*,/,''); // Strip content-type field.
        // CSS inside DOT is URL-escaped, so we must unescape it
        // before we can insert it into SVG.
        css_data = unescape(css_data);
        dot_data = data.replace(cssregex, ''); // Remove the stylesheet
    }

    var render_start = performance.now()
    function add_controls(svg) {
        var htmlblob = new Blob([document.documentElement.innerHTML],
                                {type: 'text/html'});
        var savehtml = document.createElement('a');
        savehtml.setAttribute('href', URL.createObjectURL(htmlblob));
        savehtml.setAttribute('download', 'graph.html');
        savehtml.innerHTML = " [Save HTML+SVG] ";
        document.body.append(savehtml);
        var svgblob = new Blob([svg.outerHTML], {type: 'image/svg'});
        var savesvg = document.createElement('a');
        savesvg.setAttribute('href', URL.createObjectURL(svgblob));
        savesvg.setAttribute('download', 'graph.svg');
        savesvg.innerHTML = " [Save SVG] ";
        document.body.append(savesvg);
        var dotblob =  new Blob([data], {type: 'text/dot'});
        var savedot = document.createElement('a');
        savedot.setAttribute('href', URL.createObjectURL(dotblob));
        savedot.setAttribute('download', 'graph.dot');
        savedot.innerHTML = " [Save DOT] ";
        document.body.append(savedot);
        // Will get called after embed element was loaded
        var panzoom = svgPanZoom(svg, {
            zoomEnabled: true,
            controlIconsEnabled: true,
        });
        document.getElementsByTagName("BODY")[0].onresize = function() {
            panzoom.resize();
            panzoom.fit();
            panzoom.center();
        };
        var render_end = performance.now();
        var render_note = document.createElement('div')
        render_note.innerHTML = 'Rendering took '
                                + (render_end - render_start).toFixed(2) + "ms."
        document.body.append(render_note);
    }
    var svg = document.getElementById('graph')
    if (svg == null) {
        // Need to render SVG first.
        var viz = new Viz();
        viz.renderSVGElement(dot_data)
            .then(function(svg){
                var container = document.getElementById('container')
                var style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
                var node = document.createTextNode(css_data);
                style.appendChild(node);
                svg.setAttribute('width', '100%');
                svg.setAttribute('height', '100%');
                svg.setAttribute('id', 'graph');
                svg.appendChild(style);
                container.appendChild(svg);
                add_controls(svg);
            })
    } else {
        // HTML already has rendered SVG embedded, so we just need to add
        // controls.
        add_controls(svg);
    }
  </script>
</body>
</html>
)html";

  return absl::StrCat(html_prefix, dot, html_suffix);
}
1575
// Mutex guarding `url_renderer`.
tensorflow::mutex url_renderer_mu(tensorflow::LINKER_INITIALIZED);
// Pointer to the callback that is handed dot text and returns either a string
// or an error; WrapFusionExplorer treats the returned string as a URL. The
// pointer starts out null, and code that dereferences it must first make sure
// a renderer is present.
std::function<StatusOr<string>(absl::string_view)>* url_renderer
    TF_GUARDED_BY(url_renderer_mu) = nullptr;

// Storage for fusion visualization: (module_id, computation_id) -> sequence of
// dot dumps.
//
// The map is allocated with `new` and bound to a reference, so it is never
// destroyed. Access it only while holding `fusion_visualizer_state_mu`.
tensorflow::mutex fusion_visualizer_state_mu(tensorflow::LINKER_INITIALIZED);
static auto& fusion_visualizer_state TF_GUARDED_BY(fusion_visualizer_state_mu) =
    *new absl::flat_hash_map<std::pair<int64, int64>,
                             std::vector<std::string>>();
1586
1587 // Generates a key to the fusion visualizer state mapping.
1588 std::pair<int, int> FusionVisualizerStateKey(
1589 const HloComputation& computation) {
1590 return std::make_pair(computation.parent()->unique_id(),
1591 computation.unique_id());
1592 }
1593
// Generates a fusion explorer for the given computation using the data in
// fusion_visualizer_state and the URL renderer.  Precondition: url_renderer !=
// nullptr.
//
// Every dot graph stored for `computation` is passed through the URL renderer,
// and the resulting URLs are substituted into an HTML page that steps through
// them in an iframe (prev/next links, j/k keys). If rendering any graph fails,
// that error is returned instead of a page.
//
// Must be called with `url_renderer_mu` held; `fusion_visualizer_state_mu` is
// acquired internally and held for the remainder of the call.
StatusOr<std::string> WrapFusionExplorer(const HloComputation& computation)
    TF_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
  CHECK(url_renderer != nullptr);
  tensorflow::mutex_lock lock(fusion_visualizer_state_mu);
  // NOTE(review): operator[] looks like it default-inserts an empty entry when
  // nothing has been stored for this computation yet -- confirm that is
  // intended.
  const std::vector<std::string>& dot_graphs =
      fusion_visualizer_state[FusionVisualizerStateKey(computation)];
  std::vector<std::string> dot_urls;
  dot_urls.reserve(dot_graphs.size());
  for (const std::string& dot : dot_graphs) {
    TF_ASSIGN_OR_RETURN(std::string url, (*url_renderer)(dot));
    dot_urls.push_back(url);
  }

  // Page template. $URLS becomes a comma-separated list of double-quoted URLs
  // and $TITLE becomes "<module name>_<computation name>".
  return absl::StrReplaceAll(
      R"(
<!doctype html>
<style>
  html, body {height: 100%; text-align: center;}
  #display {height: 80%; width: 80%;}
</style>
<title>Fusion Explorer: $TITLE</title>
<iframe id='display' width=80% height=80%></iframe>
<p id='description'></p>
<p>
  <a id='prev' href='#'>Prev Step</a>
  <a id='next' href='#'>Next Step</a>
</p>
<p>
  Use j/k for keyboard navigation.
</p>
<script>
  var currId = -1;
  var urls = [$URLS];

  var setIframe = function() {
    document.getElementById('display').src = urls[currId];
  };

  var update = function(delta) {
    currId = (currId + delta + urls.length) % urls.length;
    document.getElementById('description').innerHTML = "Frame #"
      + (currId + 1) + " / " + urls.length;
    setIframe();
  };

  document.getElementById('prev').onclick = function() {
    update(-1);
    return false;
  };

  document.getElementById('next').onclick = function() {
    update(1);
    return false;
  };

  window.addEventListener("keydown", function (event) {
    if (event.defaultPrevented) {
      return;
    }
    if (event.key == "j") {
      update(1);
    } else if (event.key == "k") {
      update(-1);
    } else {
      return;
    }
    event.preventDefault();
  }, true);

  document.addEventListener("DOMContentLoaded", function() {
    update(1);
  });

</script>
  )",
      {{"$URLS", absl::StrJoin(dot_urls, ", ",
                               [&](std::string* out, const std::string& url) {
                                 absl::StrAppend(out, "\"", url, "\"");
                               })},
       {"$TITLE",
        absl::StrCat(computation.parent()->name(), "_", computation.name())}});
}
1679
1680 // Precondition: (url_renderer != nullptr || (format != kUrl
1681 // && format != kFusionVisualization)).
1682 //
1683 // (We specify this as a precondition rather than checking it in here and
1684 // returning an error because we want to fail quickly when there's no URL
1685 // renderer available, and this function runs only after we've done all the work
1686 // of producing dot for the graph.)
WrapDotInFormat(const HloComputation & computation,absl::string_view dot,RenderedGraphFormat format)1687 StatusOr<string> WrapDotInFormat(const HloComputation& computation,
1688 absl::string_view dot,
1689 RenderedGraphFormat format)
1690 TF_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
1691 switch (format) {
1692 case RenderedGraphFormat::kUrl:
1693 CHECK(url_renderer != nullptr)
1694 << "Should have checked url_renderer != null before calling.";
1695 return (*url_renderer)(dot);
1696 case RenderedGraphFormat::kHtml:
1697 return WrapDotInHtml(dot);
1698 case RenderedGraphFormat::kDot:
1699 return string(dot);
1700 case RenderedGraphFormat::kFusionVisualization:
1701 return WrapFusionExplorer(computation);
1702 }
1703 }
1704
1705 } // namespace
1706
RegisterGraphToURLRenderer(std::function<StatusOr<string> (absl::string_view)> renderer)1707 void RegisterGraphToURLRenderer(
1708 std::function<StatusOr<string>(absl::string_view)> renderer) {
1709 tensorflow::mutex_lock lock(url_renderer_mu);
1710 if (url_renderer != nullptr) {
1711 LOG(WARNING) << "Multiple calls to RegisterGraphToURLRenderer. Last call "
1712 "wins, but because order of initialization in C++ is "
1713 "nondeterministic, this may not be what you want.";
1714 }
1715 delete url_renderer;
1716 url_renderer = new std::function<StatusOr<string>(absl::string_view)>(
1717 std::move(renderer));
1718 }
1719
RegisterFusionState(const HloComputation & computation,absl::string_view label)1720 Status RegisterFusionState(const HloComputation& computation,
1721 absl::string_view label) {
1722 tensorflow::mutex_lock lock(fusion_visualizer_state_mu);
1723 TF_ASSIGN_OR_RETURN(
1724 string dot_graph,
1725 RenderGraph(computation,
1726 absl::StrCat(computation.parent()->name(), ", ",
1727 computation.name(), ", ", label),
1728 /*debug_options=*/{}, xla::RenderedGraphFormat::kDot,
1729 /*hlo_execution_profile=*/nullptr,
1730 /*hlo_render_options=*/{}));
1731 std::vector<std::string>& fusion_states =
1732 fusion_visualizer_state[FusionVisualizerStateKey(computation)];
1733 if (fusion_states.empty() || fusion_states.back() != dot_graph) {
1734 fusion_states.push_back(dot_graph);
1735 }
1736 return Status::OK();
1737 }
1738
RenderGraph(const HloComputation & computation,absl::string_view label,const DebugOptions & debug_options,RenderedGraphFormat format,const HloExecutionProfile * hlo_execution_profile,HloRenderOptions hlo_render_options)1739 StatusOr<string> RenderGraph(const HloComputation& computation,
1740 absl::string_view label,
1741 const DebugOptions& debug_options,
1742 RenderedGraphFormat format,
1743 const HloExecutionProfile* hlo_execution_profile,
1744 HloRenderOptions hlo_render_options) {
1745 tensorflow::mutex_lock lock(url_renderer_mu);
1746 if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1747 return Unavailable("Can't render as URL; no URL renderer was registered.");
1748 }
1749
1750 string rendered_dot =
1751 HloDotDumper(&computation, label, debug_options, hlo_render_options,
1752 hlo_execution_profile, NodeFilter())
1753 .Dump();
1754 return WrapDotInFormat(computation, rendered_dot, format);
1755 }
1756
RenderNeighborhoodAround(const HloInstruction & node,int radius,RenderedGraphFormat format,HloRenderOptions hlo_render_options,const absl::flat_hash_set<const HloInstruction * > & boundary)1757 StatusOr<string> RenderNeighborhoodAround(
1758 const HloInstruction& node, int radius, RenderedGraphFormat format,
1759 HloRenderOptions hlo_render_options,
1760 const absl::flat_hash_set<const HloInstruction*>& boundary) {
1761 tensorflow::mutex_lock lock(url_renderer_mu);
1762 if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1763 return FailedPrecondition(
1764 "Can't render as URL; no URL renderer was registered.");
1765 }
1766
1767 string label =
1768 StrCat("Neighborhood of ", radius, " nodes around ", node.name());
1769 string rendered_dot =
1770 HloDotDumper(node.parent(), label,
1771 node.GetModule()->config().debug_options(),
1772 hlo_render_options, /*profile=*/nullptr,
1773 MakeNodeRadiusAroundFilter(&node, radius, boundary))
1774 .Dump();
1775 return WrapDotInFormat(*node.parent(), rendered_dot, format);
1776 }
1777
RenderAllPathsFromTo(const HloInstruction & from,const HloInstruction & to,int64 max_nodes,RenderedGraphFormat format,HloRenderOptions hlo_render_options)1778 StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
1779 const HloInstruction& to, int64 max_nodes,
1780 RenderedGraphFormat format,
1781 HloRenderOptions hlo_render_options) {
1782 tensorflow::mutex_lock lock(url_renderer_mu);
1783 if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1784 return FailedPrecondition(
1785 "Can't render as URL; no URL renderer was registered.");
1786 }
1787
1788 CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!";
1789 auto debug_options = from.GetModule()->config().debug_options();
1790
1791 bool hit_limit = false;
1792 NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit);
1793 string label;
1794 if (!hit_limit) {
1795 label = StrCat("All paths from ", from.name(), " to ", to.name());
1796 } else {
1797 label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(),
1798 " to ", to.name(),
1799 "<br/><br/>***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN "
1800 "NODES***<br/><br/>");
1801 }
1802 string rendered_dot =
1803 HloDotDumper(from.parent(), label, debug_options, hlo_render_options,
1804 /*profile=*/nullptr, filter)
1805 .Dump();
1806 return WrapDotInFormat(*from.parent(), rendered_dot, format);
1807 }
1808
1809 } // namespace xla
1810