1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
17
18 #include <unistd.h>
19
20 #include <algorithm>
21 #include <atomic>
22 #include <deque>
23 #include <map>
24 #include <memory>
25 #include <optional>
26 #include <queue>
27 #include <string>
28 #include <tuple>
29 #include <vector>
30
31 #include "absl/container/flat_hash_map.h"
32 #include "absl/container/flat_hash_set.h"
33 #include "absl/strings/match.h"
34 #include "absl/strings/str_cat.h"
35 #include "absl/strings/str_format.h"
36 #include "absl/strings/str_join.h"
37 #include "absl/strings/str_replace.h"
38 #include "tensorflow/compiler/xla/layout_util.h"
39 #include "tensorflow/compiler/xla/literal.h"
40 #include "tensorflow/compiler/xla/primitive_util.h"
41 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
42 #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
43 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
44 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
45 #include "tensorflow/compiler/xla/service/hlo_module.h"
46 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
47 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
48 #include "tensorflow/compiler/xla/shape_util.h"
49 #include "tensorflow/compiler/xla/types.h"
50 #include "tensorflow/compiler/xla/util.h"
51 #include "tensorflow/compiler/xla/window_util.h"
52 #include "tensorflow/core/lib/core/status.h"
53 #include "tensorflow/core/lib/gtl/map_util.h"
54 #include "tensorflow/core/lib/io/zlib_compression_options.h"
55 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
56 #include "tensorflow/core/lib/strings/numbers.h"
57 #include "tensorflow/core/platform/base64.h"
58 #include "tensorflow/core/platform/env.h"
59 #include "tensorflow/core/platform/protobuf.h"
60 #include "tensorflow/core/platform/regexp.h"
61 #include "tensorflow/stream_executor/dnn.h"
62
63 namespace xla {
64 namespace {
65
66 using absl::StrAppend;
67 using absl::StrCat;
68 using absl::StrFormat;
69 using absl::StrJoin;
70 using std::nullopt;
71 using std::optional;
72
// Describes how the graph renderer should treat a given HloInstruction:
// draw it normally, hide it, highlight it, or gray it out.
enum NodeFilterResult {
  // Treat the node like any other.
  kNormalNode,
  // Leave the node out of the graph.
  kHideNode,
  // Make the node easy to find in the final graph.
  kHighlightNode,
  // "Gray out" the node to signal that some of its operands were omitted.
  kSomeOperandsOmitted,
  // Styled like kSomeOperandsOmitted, but additionally the node is not
  // connected to its operands, even when they are present in the graph.
  kOmitNodeOperands,
  // Styled like kSomeOperandsOmitted, but signals that some of the node's
  // *users* (rather than its operands) were omitted.
  kSomeUsersOmitted,
};
90
91 // NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
92 // It lets callers tell the graph-drawing routines which nodes they want to be
93 // shown, hidden, or highlighted.
94 class NodeFilter {
95 public:
__anon758a858c0202(const HloInstruction*) 96 NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}
97
NodeFilter(std::function<NodeFilterResult (const HloInstruction * instr)> filter)98 explicit NodeFilter(
99 std::function<NodeFilterResult(const HloInstruction* instr)> filter)
100 : filter_(std::move(filter)) {}
101
Show(const HloInstruction * instr) const102 bool Show(const HloInstruction* instr) const {
103 return filter_(instr) != kHideNode;
104 }
Highlight(const HloInstruction * instr) const105 bool Highlight(const HloInstruction* instr) const {
106 return filter_(instr) == kHighlightNode;
107 }
OmitOperands(const HloInstruction * instr) const108 bool OmitOperands(const HloInstruction* instr) const {
109 return filter_(instr) == kOmitNodeOperands;
110 }
SomeOrAllOperandsOmitted(const HloInstruction * instr) const111 bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const {
112 auto result = filter_(instr);
113 return result == kOmitNodeOperands || result == kSomeOperandsOmitted;
114 }
Deemphasized(const HloInstruction * instr) const115 bool Deemphasized(const HloInstruction* instr) const {
116 auto result = filter_(instr);
117 return result == kOmitNodeOperands || result == kSomeOperandsOmitted ||
118 result == kSomeUsersOmitted;
119 }
120
121 private:
122 std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
123 };
124
125 // We arbitrarily set this as the boundary between "large" and "small"
126 // instructions.
IsSmall(const HloInstruction * instr)127 bool IsSmall(const HloInstruction* instr) {
128 if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE_TYPE) ||
129 ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
130 return true;
131 }
132 return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
133 }
134
// The color schemes a node can be drawn in; consumed by NodeColorAttributes.
enum ColorScheme {
  kBlue,
  kBrown,
  kDarkBlue,
  kDarkGreen,
  kDarkOrange,
  kDarkRed,
  kGray,
  kGreen,
  kOrange,
  kPurple,
  kRed,
  kWhite,
  kYellow,

  // Gives the node a dashed border and gray text on a white background,
  // suggesting that this is an "unimportant" node.
  kDashedBorder,
};
155
// The graphviz attribute values that together make up one color scheme.
struct NodeColors {
  const char* style;         // value for the node's `style` attribute
  const char* fill_color;    // value for `fillcolor`
  const char* stroke_color;  // value for `color` (the outline)
  const char* font_color;    // value for `fontcolor`
};
163
NodeColorsForScheme(ColorScheme color)164 NodeColors NodeColorsForScheme(ColorScheme color) {
165 switch (color) {
166 case kBlue:
167 return NodeColors{"filled", "#bbdefb", "#8aacc8", "black"};
168 case kBrown:
169 return NodeColors{"filled", "#bcaaa4", "#8c7b75", "black"};
170 case kDarkBlue:
171 return NodeColors{"filled", "#1565c0", "#003c8f", "white"};
172 case kDarkGreen:
173 return NodeColors{"filled", "#2e7d32", "#005005", "white"};
174 case kDarkOrange:
175 // This is more of a "medium" orange, made to look close to kOrange;
176 // there's probably room for a darker weight if desired.
177 return NodeColors{"filled", "#ffb74d", "#c88719", "black"};
178 case kDarkRed:
179 return NodeColors{"filled", "#b71c1c", "#7f0000", "white"};
180 case kGray:
181 return NodeColors{"filled", "#cfd8dc", "#9ea7aa", "black"};
182 case kGreen:
183 return NodeColors{"filled", "#c8e6c9", "#97b498", "black"};
184 case kOrange:
185 return NodeColors{"filled", "#ffe0b2", "#cbae82", "black"};
186 case kPurple:
187 return NodeColors{"filled", "#e1bee7", "#af8eb5", "black"};
188 case kRed:
189 return NodeColors{"filled", "#ffcdd2", "#cb9ca1", "black"};
190 case kWhite:
191 return NodeColors{"filled", "white", "black", "black"};
192 case kYellow:
193 return NodeColors{"filled", "#fff9c4", "#cbc693", "black"};
194 case kDashedBorder:
195 // "filled,dashed" looks the same as "dashed", since we have a white
196 // background. But we use "filled,dashed" so that when you hover over
197 // any part of the node (not just the text inside the node), our css
198 // :hover rule is triggered.
199 return NodeColors{"filled,dashed", "white", "#757575", "#757575"};
200 }
201 }
202
203 // Given a ColorScheme, returns an attribute string for a node of that color.
204 // Sets the node's style and fill/stroke/text colors.
205 //
206 // Colors are from https://material.io/color.
NodeColorAttributes(ColorScheme color)207 std::string NodeColorAttributes(ColorScheme color) {
208 NodeColors node_colors = NodeColorsForScheme(color);
209
210 return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")",
211 node_colors.style, node_colors.font_color,
212 node_colors.stroke_color, node_colors.fill_color);
213 }
214
215 // Replaces <> with <>, so that this string is safe(er) for use in a
216 // graphviz HTML-like string.
HtmlLikeStringSanitize(absl::string_view s)217 std::string HtmlLikeStringSanitize(absl::string_view s) {
218 return absl::StrReplaceAll(s, {{"<", "<"}, {">", ">"}});
219 }
220
IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction * instr)221 bool IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction* instr) {
222 namespace m = match;
223 return instr->parent()->IsFusionComputation() &&
224 Match(instr, m::Broadcast(m::ConstantEffectiveScalar()));
225 }
226
227 // Tries to generates a human-readable one-word description of the given
228 // computation.
229 //
230 // Currently we support:
231 //
232 // "return param0 + param1;" --> "add"
233 // "return param0 * param1;" --> "multiply"
234 // "return min(param0, param1);" --> "min"
235 // "return max(param0, param1);" --> "max"
236 // "return xor(param0, param1);" --> "xor"
237 // "return and(param0, param1);" --> "and"
238 // "return or(param0, param1);" --> "or"
239 // "return param0 <= param1;" --> "less-or-equal"
240 // "return param0 >= param1;" --> "greater-or-equal"
241 // "return param0 > param1;" --> "greater-than"
242 // "return param0 < param1;" --> "less-than"
243 // "return param0 == param1;" --> "equal-to"
244 // "return param0 != param1;" --> "not-equal-to"
245 //
246 // where param0 and param1 are effective scalars. For the ops that are
247 // commutative, we also support them with param0 and param1 swapped.
248 //
249 // This is useful primarily for reduce and map nodes. These take a
250 // subcomputation which is almost always one of the above, and pattern matching
251 // it to a short string lets us tell the user what the subcomputation is without
252 // drawing it as a graph.
MatchTrivialComputation(const HloComputation * computation)253 optional<std::string> MatchTrivialComputation(
254 const HloComputation* computation) {
255 namespace m = match;
256
257 if (computation->instruction_count() != 3) {
258 return nullopt;
259 }
260 HloInstruction* root = computation->root_instruction();
261 const HloInstruction *param0, *param1;
262 if (!Match(root, m::Op()
263 .WithNumOperands(2)
264 .WithShape(m::Shape().IsEffectiveScalar())
265 .WithBinaryOperandsAnyOrder(
266 m::Parameter(¶m0, 0)
267 .WithShape(m::Shape().IsEffectiveScalar()),
268 m::Parameter(¶m1, 1)
269 .WithShape(m::Shape().IsEffectiveScalar())))) {
270 return nullopt;
271 }
272
273 // If the params are reversed (i.e. operand0 is param1 and operand1 is
274 // param0), check that the operation being performed is commutative.
275 if (root->operand(0) == param1) {
276 CHECK_EQ(root->operand(1), param0);
277 if (root->opcode() == HloOpcode()) {
278 switch (root->comparison_direction()) {
279 case ComparisonDirection::kLe:
280 case ComparisonDirection::kGe:
281 case ComparisonDirection::kGt:
282 case ComparisonDirection::kLt:
283 return nullopt;
284 default:
285 break;
286 }
287 }
288 }
289
290 // If we recognize the root's opcode, we've successfully pattern-matched!
291 switch (root->opcode()) {
292 case HloOpcode::kAdd:
293 return "add";
294 case HloOpcode::kMultiply:
295 return "multiply";
296 case HloOpcode::kMinimum:
297 return "min";
298 case HloOpcode::kMaximum:
299 return "max";
300 case HloOpcode::kXor:
301 return "xor";
302 case HloOpcode::kAnd:
303 return "and";
304 case HloOpcode::kOr:
305 return "or";
306 case HloOpcode::kCompare: {
307 switch (root->comparison_direction()) {
308 case ComparisonDirection::kLe:
309 return "less-or-equal";
310 case ComparisonDirection::kGe:
311 return "greater-or-equal";
312 case ComparisonDirection::kGt:
313 return "greater-than";
314 case ComparisonDirection::kLt:
315 return "less-than";
316 case ComparisonDirection::kEq:
317 return "equal-to";
318 case ComparisonDirection::kNe:
319 return "not-equal-to";
320 }
321 }
322 default:
323 return nullopt;
324 }
325 }
326
327 // Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax).
328 class HloDotDumper {
329 public:
HloDotDumper(const HloComputation * computation,absl::string_view label,const DebugOptions & debug_options,HloRenderOptions hlo_render_options,const HloExecutionProfile * profile,NodeFilter filter)330 HloDotDumper(const HloComputation* computation, absl::string_view label,
331 const DebugOptions& debug_options,
332 HloRenderOptions hlo_render_options,
333 const HloExecutionProfile* profile, NodeFilter filter)
334 : computation_(computation),
335 label_(label),
336 debug_options_(debug_options),
337 hlo_render_options_(hlo_render_options),
338 profile_(profile),
339 filter_(std::move(filter)) {}
340
341 std::string Dump();
342
343 // Returns a CSS id assigned to the instruction, if that exists.
CssIdForInstruction(const HloInstruction & instr)344 std::optional<std::string> CssIdForInstruction(const HloInstruction& instr) {
345 if (instr.opcode() == HloOpcode::kFusion) {
346 // For fusion we render it as a subcomputation.
347 auto it = cluster_ids_.find(instr.called_computations()[0]);
348 if (it == cluster_ids_.end()) {
349 return std::nullopt;
350 }
351 return StrCat("#a_clust", it->second, " path");
352 }
353 auto it = node_ids_.find(&instr);
354 if (it == node_ids_.end()) {
355 return std::nullopt;
356 }
357 return StrCat("#node", it->second, " polygon");
358 }
359
360 private:
361 // Returns the dot graph identifier for the given instruction.
InstructionId(const HloInstruction * instruction)362 std::string InstructionId(const HloInstruction* instruction) {
363 return StrCat(reinterpret_cast<uint64_t>(instruction));
364 }
365
366 // Returns the dot graph identifier for the given computation.
SubcomputationId(const HloComputation * computation)367 std::string SubcomputationId(const HloComputation* computation) {
368 return StrCat("cluster_", reinterpret_cast<uint64_t>(computation));
369 }
370
371 // Generates graph header/footer. These should be called *after* dumping all
372 // of the instructions and subcomputations for the graph, as they both use
373 // data generated while dumping the graph.
374 std::string Header();
375 std::string Footer();
376
377 bool ShouldShowSubcomputation(const HloComputation* subcomp);
378 bool ShouldShowFusionSubcomputation(const HloInstruction* instr);
379
380 // We omit some nodes from the graph, instead drawing them inlined into the
381 // nodes that use them.
382 bool ShouldMergeIntoUsers(const HloInstruction* instr) const;
383
384 std::string DumpSubcomputation(const HloComputation* subcomp,
385 const HloInstruction* parent_instr);
386 std::string DumpComputation(const HloComputation* comp);
387 std::string DumpRootTag();
388 std::string DumpInstruction(const HloInstruction* instr);
389 ColorScheme GetInstructionColor(const HloInstruction* instr);
390 std::string GetInstructionNodeShape(const HloInstruction* instr);
391 std::string GetInstructionNodeLabel(const HloInstruction* instr);
392 std::string GetInstructionNodeMetadata(const HloInstruction* instr);
393 std::string GetInstructionNodeBackendConfig(const HloInstruction* instr);
394 std::string GetInstructionNodeExtraInfo(const HloInstruction* instr);
395 std::string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
396 void AddInstructionIncomingEdges(const HloInstruction* instr);
397
398 // For most instructions, GetNodeForEdge(instr) returns instr.
399 //
400 // The exception is fusion nodes. For these, we walk up the chain of nested
401 // fusion nodes starting at instr until we reach a node that either (a) isn't
402 // a fusion node, or (b) is a fusion node for which
403 // ShouldShowFusionSubcomputation is false.
404 //
405 // We do this because fusion nodes are expanded inline -- if
406 // ShouldShowFusionSubcomputation is true, the fusion node won't be present in
407 // the graph.
408 //
409 // In general when you want to draw an edge from A to B, you should actually
410 // draw an edge from GetNodeForEdge(A).
411 const HloInstruction* GetNodeForEdge(const HloInstruction* instr);
412
413 // If instr has just one computation and it's trivial (e.g. "return param0 +
414 // param1"), returns a string you can put into the node's body that names the
415 // subcomputation, e.g. "Subcomputation: <b>add</b>".
416 std::string GetInstructionTrivialComputationStr(const HloInstruction* instr);
417
418 const HloComputation* computation_; // never null
419 const std::string label_; // overall name for the graph
420 const DebugOptions& debug_options_;
421 const HloRenderOptions hlo_render_options_;
422 const HloExecutionProfile* profile_; // may be null
423 const NodeFilter filter_;
424
425 // Each HloInstruction dumped gets a monotonically-increasing node ID. This
426 // must start at 1, because that's where graphviz's accounting starts.
427 int64_t next_node_id_ = 1;
428 absl::flat_hash_map<const HloInstruction*, int64_t> node_ids_;
429
430 // The "root" tag doesn't have an associated HloInstruction pointer, so we
431 // need to store it outside the map.
432 int64_t root_node_id_;
433
434 // Each (from, to) edge gets a monotonically-increasing ID. This is a
435 // multimap because it's possible for the same edge to appear multiple times
436 // in the graph (e.g. x^2 may be represented as mul(x, x)).
437 int64_t next_edge_id_ = 1;
438 std::unordered_multimap<
439 std::pair<const HloInstruction*, const HloInstruction*>, int64_t,
440 absl::Hash<std::pair<const HloInstruction*, const HloInstruction*>>>
441 edge_ids_;
442
443 // Each HloComputation that's emitted gets a monotonically-increasing ID.
444 int64_t next_cluster_id_ = 1;
445 absl::flat_hash_map<const HloComputation*, int64_t> cluster_ids_;
446
447 // Edges to print from Footer(). Edges come at the end because graphviz is
448 // unhappy if an edge from a subcomputation to a node in the outer computation
449 // appears before both the inner computation and the destination node are
450 // defined.
451 std::vector<std::string> edges_;
452
453 // When coloring by sharding information, we track the sharding string
454 // representation to color association, by round-robin the color schemes.
455 absl::flat_hash_map<HloSharding, ColorScheme> sharding_colors_;
456 int64_t next_shard_color_ = 0;
457 };
458
Dump()459 std::string HloDotDumper::Dump() {
460 std::string body;
461 StrAppend(&body, DumpComputation(computation_));
462 StrAppend(&body, DumpRootTag());
463
464 // By contract, Header() and Footer() have to be called after we've dumped all
465 // our instructions, because they use state generated during that process.
466 std::string g = Header();
467 StrAppend(&g, body);
468 StrAppend(&g, Footer());
469 return g;
470 }
471
Header()472 std::string HloDotDumper::Header() {
473 constexpr char fmt[] = R"(digraph G {
474 rankdir = TB;
475 compound = true;
476 label = <<b>%s</b>>;
477 labelloc = t;
478 // Disable the tooltip. Interestingly, "" doesn't work!
479 tooltip = " ";
480 // DOT graphs accept a stylesheet as a URI. So naturally, an inline
481 // stylesheet is a data URI!
482 stylesheet=<
483 data:text/css,
484 @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
485 svg text {
486 font-family: 'Roboto';
487 font-size: 12px;
488 }
489
490 %s
491 >
492
493 )";
494
495 VLOG(3) << "Generating Header";
496
497 std::string graph_label =
498 StrCat(label_, "<br/>Computation ", computation_->name());
499 if (computation_->IsFusionComputation()) {
500 StrAppend(&graph_label, " (in fusion instruction ",
501 computation_->FusionInstruction()->name(), ")");
502 }
503 if (profile_ != nullptr) {
504 auto cycles = profile_->total_cycles_executed(*computation_);
505 absl::StrAppendFormat(&graph_label, "<br/>total cycles = %d (%s)", cycles,
506 tensorflow::strings::HumanReadableNum(cycles));
507 }
508
509 // Create CSS rules that say, when you hover over the given node or cluster,
510 // turn the given edge the given color.
511 //
512 // We rely on a few properties of how graphviz generates SVGs:
513 //
514 // - Nodes are named "nodeN", where N corresponds to the 1-based index of
515 // the node in our DOT (i.e. the first node in the DOT is "node1", etc.).
516 // Edges are similarly named "edgeN", and clusters are named "clustN".
517 // - Nodes come before their in- and out-edges in the SVG. We need this
518 // because the "X ~ Y" CSS selector finds a sibling of X that *comes
519 // after X in the DOM* and matches Y.
520 std::vector<std::string> edge_css_rules;
521 const char* kBlue = "#1976d2";
522 const char* kRed = "#d32f2f";
523 for (const auto& kv : edge_ids_) {
524 const HloInstruction* from_node = kv.first.first;
525 const HloInstruction* to_node = kv.first.second;
526 int64_t edge_id = kv.second;
527
528 auto add_hover_css_rule = [&](std::string elem_type, int64_t elem_id,
529 const char* color) {
530 // One could imagine other ways of writing this CSS rule that involve
531 // less duplication, but this way seems to be relatively performant.
532 edge_css_rules.push_back(
533 StrFormat(" #%s%d:hover ~ #edge%d text { fill: %s; }\n"
534 " #%s%d:hover ~ #edge%d path { "
535 "stroke: %s; stroke-width: .2em; }\n"
536 " #%s%d:hover ~ #edge%d polygon { "
537 "fill: %s; stroke: %s; stroke-width: .2em; }\n",
538 elem_type, elem_id, edge_id, color, //
539 elem_type, elem_id, edge_id, color, //
540 elem_type, elem_id, edge_id, color, color));
541 };
542
543 // The "to_node" value may be a NULL, indicating that this points to the
544 // "root" tag rather than a normal node.
545 int64_t from_node_id =
546 tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1);
547 if (from_node_id == -1) {
548 LOG(FATAL) << from_node->name() << " was added to edges but not to nodes";
549 }
550 int64_t to_node_id =
551 to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1)
552 : root_node_id_;
553 if (to_node != nullptr && to_node_id == -1) {
554 LOG(FATAL) << to_node->name() << " was added to edges but not to nodes";
555 }
556
557 add_hover_css_rule("node", from_node_id, kBlue);
558 add_hover_css_rule("node", to_node_id, kRed);
559
560 if (to_node) {
561 VLOG(3) << "Adding css for edge " << edge_id << " from node "
562 << from_node->name() << " to node " << to_node->name();
563 } else {
564 VLOG(3) << "Adding css for edge " << edge_id << " from node "
565 << from_node->name() << " to root tag";
566 }
567
568 // If this edge crosses a fusion cluster boundary, highlight it when the
569 // cluster is hovered over.
570 if (to_node) {
571 if (from_node->IsFused() &&
572 from_node->parent()->root_instruction() == from_node) {
573 int64_t cluster_id = cluster_ids_.at(from_node->parent());
574 add_hover_css_rule("clust", cluster_id, kBlue);
575 }
576 if (to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter) {
577 int64_t cluster_id = cluster_ids_.at(to_node->parent());
578 add_hover_css_rule("clust", cluster_id, kRed);
579 }
580 }
581 }
582
583 // Browsers require that we URI-encode the contents of our data URI. (It
584 // seems this was a relatively recent change?) In practice, this means that we
585 // need to escape '#'.
586 return StrFormat(
587 fmt, graph_label,
588 absl::StrReplaceAll(StrJoin(edge_css_rules, "\n"), {{"#", "%23"}}));
589 }
590
Footer()591 std::string HloDotDumper::Footer() {
592 return StrCat(StrJoin(edges_, "\n"), "\n}");
593 }
594
ShouldShowFusionSubcomputation(const HloInstruction * instr)595 bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
596 CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
597 return ShouldShowSubcomputation(instr->fused_instructions_computation());
598 }
599
ShouldShowSubcomputation(const HloComputation * subcomp)600 bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
601 if (subcomp->IsFusionComputation()) {
602 const HloInstruction* fusion = subcomp->FusionInstruction();
603 if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion) ||
604 !hlo_render_options_.show_fusion_subcomputations) {
605 return false;
606 }
607 }
608
609 // Don't show trivial subcomputations on non-fusion nodes -- these are inlined
610 // into the graph.
611 if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) {
612 return false;
613 }
614
615 // Show the subcomputation if we're showing any of its members.
616 return absl::c_any_of(
617 subcomp->instructions(),
618 [&](const HloInstruction* instr) { return filter_.Show(instr); });
619 }
620
DumpSubcomputation(const HloComputation * subcomp,const HloInstruction * parent_instr)621 std::string HloDotDumper::DumpSubcomputation(
622 const HloComputation* subcomp, const HloInstruction* parent_instr) {
623 VLOG(2) << "Dumping subcomputation " << subcomp->name();
624 // Add an edge from the subcomputation to its parent node. If subcomp
625 // belongs to a fusion node, it's drawn in place of the fusion instruction,
626 // so there's no need to link those.
627 if (parent_instr->opcode() != HloOpcode::kFusion) {
628 const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
629 VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
630 << " as " << next_edge_id_;
631 edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
632 constexpr char edge_fmt[] =
633 R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
634 edges_.push_back(StrFormat(
635 edge_fmt, InstructionId(from), InstructionId(parent_instr),
636 SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
637 }
638
639 // Have we already dumped this subcomputation? If so, generating the edge
640 // linking it and parent_instr is all we want to do in this function.
641 if (cluster_ids_.find(subcomp) != cluster_ids_.end()) {
642 return "";
643 }
644
645 cluster_ids_[subcomp] = next_cluster_id_++;
646
647 std::string id = SubcomputationId(subcomp);
648
649 std::string subcomp_label, style;
650 if (parent_instr->opcode() == HloOpcode::kFusion) {
651 subcomp_label =
652 StrFormat("Fused expression for <b>%s</b><br/>%s",
653 HtmlLikeStringSanitize(parent_instr->name()),
654 HtmlLikeStringSanitize(parent_instr->ToCategory()));
655 std::string extra_info = GetInstructionNodeExtraInfo(parent_instr);
656 if (!extra_info.empty()) {
657 StrAppend(&subcomp_label, "<br/>", extra_info);
658 }
659 std::string node_backend_config =
660 GetInstructionNodeBackendConfig(parent_instr);
661 if (!node_backend_config.empty()) {
662 StrAppend(&subcomp_label, "<br/>", node_backend_config);
663 }
664
665 bool highlight = filter_.Highlight(parent_instr);
666 const char* fillcolor;
667 const char* strokecolor;
668 if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) {
669 // Use the sharding color, if the node isn't highlighted.
670 NodeColors node_colors =
671 NodeColorsForScheme(GetInstructionColor(parent_instr));
672 fillcolor = node_colors.fill_color;
673 strokecolor = node_colors.stroke_color;
674 } else {
675 // Subcomputation's fill/stroke color is light/dark red/gray, depending on
676 // whether or not the subcomputation's fusion node is highlighted.
677 fillcolor = highlight ? "#ffcdd2" : "#f5f5f5";
678 strokecolor = highlight ? "#b71c1c" : "#c2c2c2";
679 }
680 style =
681 StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")",
682 fillcolor, strokecolor);
683 } else {
684 subcomp_label = StrFormat("Subcomputation for <b>%s</b><br/>%s",
685 HtmlLikeStringSanitize(parent_instr->name()),
686 HtmlLikeStringSanitize(subcomp->name()));
687 style = "style=rounded; color=black;";
688 }
689
690 std::string comp_body = DumpComputation(subcomp);
691
692 constexpr char computation_fmt[] = R"(subgraph %s {
693 %s
694 label = <%s>;
695 labelloc = t;
696 tooltip = " ";
697 %s
698 } // %s
699
700 )";
701 return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id);
702 }
703
DumpComputation(const HloComputation * comp)704 std::string HloDotDumper::DumpComputation(const HloComputation* comp) {
705 std::string g;
706 for (const auto* instr : comp->instructions()) {
707 if (!filter_.Show(instr)) {
708 continue;
709 }
710
711 // Dump subcomputations within instr.
712 for (const HloComputation* subcomp : instr->called_computations()) {
713 if (ShouldShowSubcomputation(subcomp)) {
714 StrAppend(&g, DumpSubcomputation(subcomp, instr));
715 }
716 }
717
718 StrAppend(&g, DumpInstruction(instr));
719 }
720 return g;
721 }
722
DumpRootTag()723 std::string HloDotDumper::DumpRootTag() {
724 const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());
725
726 // We didn't display constants or broadcasts of effective scalars within
727 // fusions as separate nodes; so if the root is a constant/broadcast of
728 // scalar, we don't add root tag or edge for it.
729 if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
730 IsFusedBroadcastOfConstantEffectiveScalar(from)) {
731 return "";
732 }
733
734 auto from_id = InstructionId(from);
735
736 // The ID of the root computation is otherwise unused, so it makes a good ID
737 // to use for the root-tag node. However, the edge_ids_ map requires a
738 // HloInstruction* pointer for the 'to' value, so we use a NULL value there
739 // (rather than a pointer type-cast) to make it obvious if it is erroneously
740 // dereferenced.
741 HloInstruction* to = nullptr;
742 auto to_id = SubcomputationId(computation_);
743
744 std::string node_body = "ROOT";
745 std::string node_shape = "circle";
746 ColorScheme color = kBrown;
747
748 VLOG(2) << "Adding root tag as node " << next_node_id_;
749 root_node_id_ = next_node_id_++;
750
751 VLOG(2) << "Adding edge from " << from->name() << " to root tag as "
752 << next_edge_id_;
753 edge_ids_.insert({{from, to}, next_edge_id_++});
754 edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id));
755
756 return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
757 "\n",
758 to_id, node_body, node_shape, NodeColorAttributes(color));
759 }
760
TryGetFusionParameterConstant(const HloInstruction * instr)761 static const HloConstantInstruction* TryGetFusionParameterConstant(
762 const HloInstruction* instr) {
763 if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
764 return nullptr;
765 }
766 const HloInstruction* fusion = instr->parent()->FusionInstruction();
767 const HloInstruction* operand = fusion->operand(instr->parameter_number());
768 return DynCast<HloConstantInstruction>(operand);
769 }
770
ShouldMergeIntoUsers(const HloInstruction * instr) const771 bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
772 // If a node:
773 //
774 // - is a get-tuple-element that isn't the root of the computation, or
775 // - is a parameter of a fusion node which is bound to a constant, or
776 // - all of:
777 // - is a tuple-shaped parameter, and
778 // - is not a parameter to a fusion node, and
779 // - has at least kMinUsersToOmit users shown, and
780 // - all of the shown users are get-tuple-elements,
781 //
782 // then we omit it from the graph, merging it with its users.
783 //
784 // This helps us handle the common case where a while loop body has one big
785 // tuple-shaped parameter.
786 if ((instr->opcode() == HloOpcode::kGetTupleElement &&
787 instr != instr->parent()->root_instruction()) ||
788 TryGetFusionParameterConstant(instr) != nullptr) {
789 return true;
790 }
791 const int kMinUsersToOmit = 3;
792 return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() &&
793 !instr->IsFused() &&
794 absl::c_count_if(instr->users(),
795 [&](const HloInstruction* user) {
796 return filter_.Show(user);
797 }) > kMinUsersToOmit &&
798 absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
799 return !filter_.Show(user) ||
800 user->opcode() == HloOpcode::kGetTupleElement;
801 });
802 }
803
DumpInstruction(const HloInstruction * instr)804 std::string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
805 // We don't display constants or broadcasts of effective scalar constants
806 // within fusions as separate nodes; they're merged into their users.
807 if ((instr->opcode() == HloOpcode::kConstant ||
808 IsFusedBroadcastOfConstantEffectiveScalar(instr)) &&
809 instr != instr->parent()->root_instruction()) {
810 return "";
811 }
812 // Skip this node if it's merged into its users.
813 if (ShouldMergeIntoUsers(instr)) {
814 return "";
815 }
816 // Omit the fusion node if its subcomputation is drawn, since the
817 // subcomputation will be drawn inline.
818 if (instr->opcode() == HloOpcode::kFusion &&
819 ShouldShowFusionSubcomputation(instr)) {
820 return "";
821 }
822
823 VLOG(2) << "Adding node " << instr->name() << " as " << next_node_id_;
824 node_ids_[instr] = next_node_id_++;
825
826 ColorScheme color = GetInstructionColor(instr);
827 std::string node_shape = GetInstructionNodeShape(instr);
828 std::string node_label = GetInstructionNodeLabel(instr);
829 std::string node_metadata = GetInstructionNodeMetadata(instr);
830 std::string node_backend_config = GetInstructionNodeBackendConfig(instr);
831 std::string extra_info = GetInstructionNodeExtraInfo(instr);
832 std::string inlined_constants = GetInstructionNodeInlinedOperands(instr);
833 std::string trivial_subcomputation =
834 GetInstructionTrivialComputationStr(instr);
835 AddInstructionIncomingEdges(instr);
836
837 if (!debug_options_.xla_hlo_graph_sharding_color()) {
838 // Override the node's styling if it should be (de-)emphasized.
839 if (filter_.Deemphasized(instr)) {
840 color = kDashedBorder;
841 }
842 if (filter_.Highlight(instr)) {
843 node_shape = "diamond";
844 color = kDarkRed;
845 }
846 }
847 // Build the text that will be displayed inside the node.
848 std::string node_body = node_label;
849 for (const std::string& s : {trivial_subcomputation, extra_info,
850 inlined_constants, node_backend_config}) {
851 if (!s.empty()) {
852 StrAppend(&node_body, "<br/>", s);
853 }
854 }
855
856 return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)"
857 "\n",
858 InstructionId(instr), node_body, node_shape, node_metadata,
859 NodeColorAttributes(color));
860 }
861
GetInstructionNodeInlinedOperands(const HloInstruction * instr)862 std::string HloDotDumper::GetInstructionNodeInlinedOperands(
863 const HloInstruction* instr) {
864 // The constant's shape is a parameter because, in the case of a broadcasted
865 // scalar constant, we want to show the broadcasted shape, not the constant's
866 // scalar shape.
867 auto stringify_constant = [](const HloConstantInstruction* constant,
868 const Shape& shape) {
869 // If the shape has a dimension of size zero, print it as e.g.
870 // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(),
871 // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which
872 // is just noise.
873 if (ShapeUtil::IsZeroElementArray(shape)) {
874 return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape()));
875 }
876
877 // Print the literal value of constants with <= K elements. Note that we
878 // use `constant->shape()` rather than `shape`, because if `constant` is a
879 // scalar that's broadcasted into `shape`, we want to print the constant.
880 optional<int64_t> elem_count;
881 if (shape.IsArray()) {
882 elem_count = ShapeUtil::ElementsIn(constant->shape());
883 }
884 // Allow HloDotDumper to print HloInstruction reconstructed from HloProto
885 // collected from profiling tools. Those constants may not have a valid
886 // literal.
887 if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
888 // In addition to our check that the constant doesn't have too many
889 // elements, also check that the stringified constant isn't too long. For
890 // example, 8 small ints is okay, but 8 long floats takes up a lot of
891 // horizontal space and probably isn't interesting.
892 std::string literal_str = constant->literal().ToStringWithoutShape();
893 if (literal_str.size() <= 64) {
894 return StrFormat("%s %s", shape.ToString(), literal_str);
895 }
896 }
897
898 // Otherwise, print e.g. "%constant.42 (s32[100])".
899 std::string constant_name;
900 if (absl::StartsWith(constant->name(), "constant")) {
901 constant_name = constant->name();
902 } else {
903 constant_name = StrCat("constant ", constant->name());
904 }
905 return StrFormat("%s %s", constant_name, ShapeUtil::HumanString(shape));
906 };
907
908 std::vector<std::string> lines;
909 for (int64_t i = 0; i < instr->operand_count(); ++i) {
910 const HloInstruction* operand = instr->operand(i);
911 optional<std::string> operand_str;
912 if (const auto* constant_operand =
913 DynCast<HloConstantInstruction>(operand)) {
914 operand_str =
915 stringify_constant(constant_operand, constant_operand->shape());
916 } else if (IsFusedBroadcastOfConstantEffectiveScalar(operand)) {
917 operand_str = stringify_constant(
918 Cast<HloConstantInstruction>(operand->operand(0)), operand->shape());
919 } else if (ShouldMergeIntoUsers(operand)) {
920 // Special case: If the operand is a parameter to a fusion node and it
921 // always has a constant value, display it like a regular constant.
922 //
923 // For other parameters, use the parameter number rather than the proper
924 // name, because that's generally how people think of the node.
925 if (operand->opcode() == HloOpcode::kParameter) {
926 if (const HloConstantInstruction* constant =
927 TryGetFusionParameterConstant(operand)) {
928 operand_str = stringify_constant(constant, constant->shape());
929 } else {
930 operand_str = StrFormat("Parameter %d", operand->parameter_number());
931 }
932 } else if (operand->opcode() == HloOpcode::kGetTupleElement) {
933 operand_str =
934 StrFormat("tuple-element %d of %s %s", operand->tuple_index(),
935 operand->operand(0)->name(),
936 ShapeUtil::HumanStringWithLayout(operand->shape()));
937 } else {
938 operand_str = operand->name();
939 }
940 }
941
942 if (operand_str) {
943 if (instr->operand_count() > 1) {
944 lines.push_back(StrFormat("<b>operand %d</b> = %s", i, *operand_str));
945 } else {
946 lines.push_back(StrFormat("<b>operand</b> = %s", *operand_str));
947 }
948 }
949 }
950
951 // Special case: fused parameter is fed from a get-tuple-element. If
952 // so, name the tuple index.
953 if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
954 const HloInstruction* param_input =
955 instr->parent()->FusionInstruction()->operand(
956 instr->parameter_number());
957 if (param_input->opcode() == HloOpcode::kGetTupleElement) {
958 lines.push_back(
959 StrFormat("tuple-element %d of %s %s", param_input->tuple_index(),
960 param_input->operand(0)->name(),
961 ShapeUtil::HumanStringWithLayout(param_input->shape())));
962 }
963 }
964
965 return StrJoin(lines, "<br/>");
966 }
967
GetInstructionColor(const HloInstruction * instr)968 ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
969 if (debug_options_.xla_hlo_graph_sharding_color()) {
970 if (!instr->has_sharding()) {
971 return kDashedBorder;
972 }
973 auto it = sharding_colors_.find(instr->sharding());
974 if (it != sharding_colors_.end()) {
975 return it->second;
976 }
977 ColorScheme color = static_cast<ColorScheme>(
978 kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
979 sharding_colors_.emplace(instr->sharding(), color);
980 return color;
981 }
982
983 // Choose different weights of orange for small vs large parameters. This
984 // distinction is often important, especially in fusion nodes.
985 auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange;
986
987 // Special case: If this instruction has a parameter merged into it, paint it
988 // the same color as a parameter. Unless the merged-in parameter is a
989 // parameter to a fusion node that is bound to a constant -- these aren't
990 // "real" parameters from the user's perspective.
991 if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) {
992 return operand->opcode() == HloOpcode::kParameter &&
993 ShouldMergeIntoUsers(operand) &&
994 TryGetFusionParameterConstant(operand) == nullptr;
995 })) {
996 return parameter_color;
997 }
998
999 // Pick different colors or shapes for instructions which are particularly
1000 // expensive (eg, dot) and those which are unusual in some way or unique
1001 // (eg, parameter).
1002 switch (instr->opcode()) {
1003 case HloOpcode::kAbs:
1004 case HloOpcode::kAdd:
1005 case HloOpcode::kAnd:
1006 case HloOpcode::kAtan2:
1007 case HloOpcode::kBitcastConvert:
1008 case HloOpcode::kCeil:
1009 case HloOpcode::kClamp:
1010 case HloOpcode::kClz:
1011 case HloOpcode::kCompare:
1012 case HloOpcode::kComplex:
1013 case HloOpcode::kConvert:
1014 case HloOpcode::kCos:
1015 case HloOpcode::kDivide:
1016 case HloOpcode::kExp:
1017 case HloOpcode::kExpm1:
1018 case HloOpcode::kFloor:
1019 case HloOpcode::kImag:
1020 case HloOpcode::kIota:
1021 case HloOpcode::kIsFinite:
1022 case HloOpcode::kLog:
1023 case HloOpcode::kLog1p:
1024 case HloOpcode::kMaximum:
1025 case HloOpcode::kMinimum:
1026 case HloOpcode::kMultiply:
1027 case HloOpcode::kNegate:
1028 case HloOpcode::kNot:
1029 case HloOpcode::kPopulationCount:
1030 case HloOpcode::kOr:
1031 case HloOpcode::kXor:
1032 case HloOpcode::kPower:
1033 case HloOpcode::kReal:
1034 case HloOpcode::kRemainder:
1035 case HloOpcode::kRng:
1036 case HloOpcode::kRngGetAndUpdateState:
1037 case HloOpcode::kRngBitGenerator:
1038 case HloOpcode::kRoundNearestAfz:
1039 case HloOpcode::kRoundNearestEven:
1040 case HloOpcode::kRsqrt:
1041 case HloOpcode::kSelect:
1042 case HloOpcode::kShiftLeft:
1043 case HloOpcode::kShiftRightArithmetic:
1044 case HloOpcode::kShiftRightLogical:
1045 case HloOpcode::kLogistic:
1046 case HloOpcode::kSign:
1047 case HloOpcode::kSin:
1048 case HloOpcode::kSlice:
1049 case HloOpcode::kSort:
1050 case HloOpcode::kSqrt:
1051 case HloOpcode::kCbrt:
1052 case HloOpcode::kSubtract:
1053 case HloOpcode::kTanh:
1054 // De-emphasize scalar-shaped elementwise ops -- they're generally
1055 // uninteresting.
1056 if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
1057 return kWhite;
1058 }
1059 return kYellow;
1060 case HloOpcode::kBitcast:
1061 case HloOpcode::kGetTupleElement:
1062 case HloOpcode::kAfterAll:
1063 case HloOpcode::kAddDependency:
1064 case HloOpcode::kTuple:
1065 case HloOpcode::kOptimizationBarrier:
1066 return kWhite;
1067 case HloOpcode::kConstant:
1068 // Constants aren't usually shown as their own nodes, but they'll be
1069 // present if e.g. they're the root of a computation.
1070 return kWhite;
1071 case HloOpcode::kBroadcast:
1072 // De-emphasize nodes which broadcast a scalar within a fusion node --
1073 // these are essentially free.
1074 if (instr->IsFused() &&
1075 ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) {
1076 return kWhite;
1077 }
1078 return kGreen;
1079 case HloOpcode::kConcatenate:
1080 case HloOpcode::kDynamicSlice:
1081 case HloOpcode::kGather:
1082 case HloOpcode::kPad:
1083 case HloOpcode::kReshape:
1084 case HloOpcode::kDynamicReshape:
1085 case HloOpcode::kReverse:
1086 case HloOpcode::kTranspose:
1087 // De-emphasize scalar-shaped data movement ops and all data movement ops
1088 // inside fusion nodes, both of which are essentially free.
1089 if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) {
1090 return kWhite;
1091 }
1092 return kGreen;
1093 case HloOpcode::kDynamicUpdateSlice:
1094 // Unlike the data-movement ops above, dynamic-update-slice is not ~free
1095 // inside of fusion nodes, so we de-emphasize it only if it's
1096 // scalar-shaped.
1097 if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
1098 return kWhite;
1099 }
1100 return kGreen;
1101 case HloOpcode::kCopy:
1102 case HloOpcode::kCopyStart:
1103 case HloOpcode::kCopyDone:
1104 // Emphasize copy nodes, which are either physical transposes (and thus
1105 // significant), or copies of read-only buffers (and thus dead weight).
1106 return kGreen;
1107 case HloOpcode::kAsyncStart:
1108 case HloOpcode::kAsyncUpdate:
1109 case HloOpcode::kAsyncDone:
1110 return GetInstructionColor(instr->async_wrapped_instruction());
1111 case HloOpcode::kConvolution:
1112 case HloOpcode::kDot:
1113 case HloOpcode::kFft:
1114 case HloOpcode::kTriangularSolve:
1115 case HloOpcode::kCholesky:
1116 return kDarkBlue;
1117 case HloOpcode::kReducePrecision:
1118 return kRed;
1119 case HloOpcode::kParameter:
1120 return parameter_color;
1121 case HloOpcode::kBatchNormGrad:
1122 case HloOpcode::kBatchNormInference:
1123 case HloOpcode::kBatchNormTraining:
1124 case HloOpcode::kReduce:
1125 case HloOpcode::kReduceWindow:
1126 case HloOpcode::kScatter: // scatter is a kind of reduction
1127 case HloOpcode::kSelectAndScatter:
1128 return kPurple;
1129 case HloOpcode::kDomain:
1130 case HloOpcode::kFusion:
1131 case HloOpcode::kMap:
1132 case HloOpcode::kGetDimensionSize:
1133 case HloOpcode::kSetDimensionSize:
1134 return kGray;
1135 case HloOpcode::kAllGather:
1136 case HloOpcode::kAllGatherStart:
1137 case HloOpcode::kAllGatherDone:
1138 case HloOpcode::kAllReduce:
1139 case HloOpcode::kReduceScatter:
1140 case HloOpcode::kAllReduceStart:
1141 case HloOpcode::kAllReduceDone:
1142 case HloOpcode::kAllToAll:
1143 case HloOpcode::kCollectivePermute:
1144 case HloOpcode::kCollectivePermuteStart:
1145 case HloOpcode::kCollectivePermuteDone:
1146 case HloOpcode::kInfeed:
1147 case HloOpcode::kOutfeed:
1148 case HloOpcode::kPartitionId:
1149 case HloOpcode::kRecv:
1150 case HloOpcode::kRecvDone:
1151 case HloOpcode::kSend:
1152 case HloOpcode::kSendDone:
1153 case HloOpcode::kReplicaId:
1154 return kBrown;
1155 case HloOpcode::kCall:
1156 case HloOpcode::kConditional:
1157 case HloOpcode::kCustomCall:
1158 case HloOpcode::kWhile:
1159 return kDarkGreen;
1160 }
1161 }
1162
GetInstructionNodeShape(const HloInstruction * instr)1163 std::string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) {
1164 // Give while loops a different shape so they're easier to pick out.
1165 switch (instr->opcode()) {
1166 case HloOpcode::kWhile:
1167 return "ellipse";
1168 default:
1169 return "rect";
1170 }
1171 }
1172
GetInstructionNodeLabel(const HloInstruction * instr)1173 std::string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
1174 // If we have a parameter, put the param number in the name.
1175 if (instr->opcode() == HloOpcode::kParameter) {
1176 return StrFormat("<b>Parameter %d</b>", instr->parameter_number());
1177 }
1178
1179 // The HLO instruction name contains usually the opcode, e.g. "%add.42" is
1180 // an add instruction. In this case we render just the name.
1181 if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) {
1182 return StrFormat("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
1183 }
1184 std::string extended_opcode =
1185 StrCat(HloOpcodeString(instr->opcode()),
1186 instr->opcode() != HloOpcode::kFusion
1187 ? ""
1188 : StrCat(":", xla::ToString(instr->fusion_kind())));
1189 // If the name does not contain the opcode, render both.
1190 return StrFormat("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
1191 HtmlLikeStringSanitize(instr->name()));
1192 }
1193
GetInstructionNodeMetadata(const HloInstruction * instr)1194 std::string HloDotDumper::GetInstructionNodeMetadata(
1195 const HloInstruction* instr) {
1196 std::vector<std::string> lines;
1197 if (!instr->metadata().op_name().empty()) {
1198 lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name()));
1199 }
1200 if (!instr->metadata().op_type().empty()) {
1201 lines.push_back(StrFormat(
1202 "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type())));
1203 }
1204 if (!instr->metadata().source_file().empty() &&
1205 instr->metadata().source_line() != 0) {
1206 lines.push_back(StrFormat("source: %s:%d", instr->metadata().source_file(),
1207 instr->metadata().source_line()));
1208 }
1209
1210 return StrJoin(lines, "\n");
1211 }
1212
1213 static std::vector<std::pair<std::string, std::string>>
ExtractCudnnConvBackendConfigProps(const gpu::CudnnConvBackendConfig & config)1214 ExtractCudnnConvBackendConfigProps(const gpu::CudnnConvBackendConfig& config) {
1215 std::vector<std::pair<std::string, std::string>> props;
1216 if (config.conv_result_scale() != 1) {
1217 props.emplace_back("conv_result_scale", StrCat(config.conv_result_scale()));
1218 }
1219 if (config.side_input_scale() != 0 && config.side_input_scale() != 1) {
1220 props.emplace_back("side_input_scale", StrCat(config.side_input_scale()));
1221 }
1222 props.emplace_back(
1223 "activation_mode",
1224 se::dnn::ActivationModeString(
1225 static_cast<se::dnn::ActivationMode>(config.activation_mode())));
1226
1227 props.emplace_back("algo",
1228 se::dnn::AlgorithmDesc(config.algorithm()).ToString());
1229
1230 // Skip workspace size; it's already explicit in the graph in the output shape
1231 // of the conv.
1232 return props;
1233 }
1234
1235 static std::vector<std::pair<std::string, std::string>>
ExtractGemmBackendConfigProps(const gpu::GemmBackendConfig & config,const HloInstruction * instr)1236 ExtractGemmBackendConfigProps(const gpu::GemmBackendConfig& config,
1237 const HloInstruction* instr) {
1238 std::vector<std::pair<std::string, std::string>> props;
1239 if (primitive_util::IsComplexType(instr->shape().element_type())) {
1240 if (config.alpha_real() != 1 || config.alpha_imag() != 1) {
1241 props.emplace_back("alpha_real", StrCat(config.alpha_real()));
1242 props.emplace_back("alpha_imag", StrCat(config.alpha_real()));
1243 }
1244 } else {
1245 if (config.alpha_real() != 1) {
1246 props.emplace_back("alpha", StrCat(config.alpha_real()));
1247 }
1248 }
1249 if (config.beta() != 0 && config.beta() != 1) {
1250 props.emplace_back("beta", StrCat(config.beta()));
1251 }
1252 props.emplace_back(
1253 "", absl::StrReplaceAll(
1254 DotDimensionNumbersToString(config.dot_dimension_numbers()),
1255 {{", ", "<br/>"}}));
1256 if (config.algorithm_case() == gpu::GemmBackendConfig::kSelectedAlgorithm) {
1257 props.emplace_back("algorithm", StrCat(config.selected_algorithm()));
1258 }
1259 return props;
1260 }
1261
GetInstructionNodeBackendConfig(const HloInstruction * instr)1262 std::string HloDotDumper::GetInstructionNodeBackendConfig(
1263 const HloInstruction* instr) {
1264 // custom-calls for convs and gemms have backend-configs with fields that are
1265 // semantically significant. Print these configs unconditionally.
1266 //
1267 // (We could elide the semantically-insignificant fields when
1268 // !show_backend_config, but this is simpler, and it's not too noisy.)
1269 std::vector<std::pair<std::string, std::string>> props;
1270 if (gpu::IsCustomCallToDnnConvolution(*instr)) {
1271 StatusOr<gpu::CudnnConvBackendConfig> config =
1272 instr->backend_config<gpu::CudnnConvBackendConfig>();
1273 if (config.ok()) {
1274 props = ExtractCudnnConvBackendConfigProps(*config);
1275 }
1276 } else if (instr->IsCustomCall(gpu::kGemmCallTarget)) {
1277 StatusOr<gpu::GemmBackendConfig> config =
1278 instr->backend_config<gpu::GemmBackendConfig>();
1279 if (config.ok()) {
1280 // gemm strides are generally uninteresting (derived from the instruction
1281 // shape), so we hide them by default.
1282 props = ExtractGemmBackendConfigProps(*config, instr);
1283 }
1284 }
1285
1286 if (!props.empty()) {
1287 // Put a linebreak before the backend-config properties if there's more than
1288 // one. Makes it easier to see.
1289 return StrCat((props.size() > 1 ? "<br/>" : ""),
1290 StrJoin(props, "<br/>",
1291 [](std::string* out,
1292 const std::pair<std::string, std::string>& kv) {
1293 if (!kv.first.empty()) {
1294 return StrAppend(out, kv.first, "=", kv.second);
1295 }
1296 StrAppend(out, kv.second);
1297 }));
1298 }
1299
1300 if (!hlo_render_options_.show_backend_config ||
1301 instr->raw_backend_config_string().empty()) {
1302 return "";
1303 }
1304
1305 return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\"");
1306 }
1307
GetInstructionNodeExtraInfo(const HloInstruction * instr)1308 std::string HloDotDumper::GetInstructionNodeExtraInfo(
1309 const HloInstruction* instr) {
1310 std::vector<std::string> lines;
1311
1312 // Get the instruction's extra attributes excluding the names of its
1313 // subcomputations, since those are drawn explicitly in the graph.
1314 for (const auto& line : instr->ExtraAttributesToString(
1315 HloPrintOptions().set_print_subcomputation_mode(
1316 HloPrintOptions::PrintSubcomputationMode::kOff))) {
1317 // Some instructions have giant device identifier fields, so truncate their
1318 // length to 128.
1319 constexpr int kMaxDeviceIdFieldLen = 128;
1320 if ((absl::StartsWith(line, "replica_groups=") ||
1321 absl::StartsWith(line, "source_target_pairs=")) &&
1322 line.length() > kMaxDeviceIdFieldLen) {
1323 lines.push_back(HtmlLikeStringSanitize(
1324 StrCat(line.substr(0, kMaxDeviceIdFieldLen - 3), "...")));
1325 } else {
1326 lines.push_back(HtmlLikeStringSanitize(line));
1327 }
1328 }
1329
1330 // Show the shape and layout of the instruction, unless it's an inlined fusion
1331 // node -- there the shape and layout is present in the output node.
1332 if (instr->opcode() != HloOpcode::kFusion ||
1333 !ShouldShowFusionSubcomputation(instr)) {
1334 // Show layout of instructions with more than one dimension. Don't show
1335 // layout on tuples or tensors with just one dimension (which only have one
1336 // possible layout) to avoid visual noise.
1337 bool shape_is_multidim = false;
1338 ShapeUtil::ForEachSubshape(instr->shape(),
1339 [&](const Shape& s, const ShapeIndex&) {
1340 shape_is_multidim |= s.dimensions_size() > 1;
1341 });
1342 std::string instr_shape;
1343 if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) {
1344 instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape());
1345 } else {
1346 instr_shape = ShapeUtil::HumanString(instr->shape());
1347 }
1348
1349 // Some instructions have giant tuples as their shapes, so truncate the
1350 // HLO's shape to kMaxShapeLen characters.
1351 constexpr int kMaxShapeLen = 64;
1352 if (instr_shape.length() > kMaxShapeLen) {
1353 instr_shape = StrCat(
1354 absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "...");
1355 }
1356 lines.push_back(HtmlLikeStringSanitize(instr_shape));
1357 }
1358 if (debug_options_.xla_hlo_graph_addresses()) {
1359 lines.push_back(StrFormat("[%p]", instr));
1360 }
1361 if (profile_ != nullptr) {
1362 double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr);
1363 double total_cycles_executed =
1364 profile_->total_cycles_executed(*instr->parent());
1365 if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
1366 lines.push_back(
1367 StrFormat("%% of cycles executed=%.2f",
1368 100 * hlo_cycles_executed / total_cycles_executed));
1369 }
1370 }
1371 return StrJoin(lines, "<br/>");
1372 }
1373
AddInstructionIncomingEdges(const HloInstruction * instr)1374 void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
1375 auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
1376 int64_t operand_num, bool control_edge = false) {
1377 from = GetNodeForEdge(from);
1378
1379 if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
1380 IsFusedBroadcastOfConstantEffectiveScalar(from) ||
1381 ShouldMergeIntoUsers(from)) {
1382 return;
1383 }
1384 VLOG(2) << "Adding edge from " << from->name() << " to " << to->name()
1385 << " as " << next_edge_id_;
1386 edge_ids_.insert({{from, to}, next_edge_id_++});
1387
1388 std::string edge_label;
1389 if (control_edge) {
1390 edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"";
1391 } else if (instr->operand_count() > 1) {
1392 edge_label =
1393 StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num);
1394 }
1395
1396 // We print "small" arrays using a hollow arrowhead and "large" arrays using
1397 // a filled arrowhead.
1398 constexpr char kEdgeFmt[] =
1399 R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
1400 edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to),
1401 (IsSmall(from) ? "empty" : "normal"),
1402 from->name(), to->name(), edge_label));
1403 };
1404
1405 // Add edges from instr's operands to instr. Parameters within fusion
1406 // expressions are handled specially -- we draw an edge from the corresponding
1407 // operand on the fusion node itself to the parameter.
1408 if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
1409 // Only add the edge if this is not the outermost computation; otherwise it
1410 // will lead from a node we're not drawing.
1411 if (instr->parent() != computation_) {
1412 const HloInstruction* fusion = instr->parent()->FusionInstruction();
1413 add_edge(fusion->operand(instr->parameter_number()), instr,
1414 /*operand_num=*/0);
1415 }
1416 } else {
1417 for (int64_t i = 0; i < instr->operand_count(); ++i) {
1418 add_edge(instr->operand(i), instr, i);
1419 }
1420 for (const HloInstruction* pred : instr->control_predecessors()) {
1421 add_edge(pred, instr, /*operand_num=*/0, /*control_edge=*/true);
1422 }
1423 }
1424 }
1425
GetInstructionTrivialComputationStr(const HloInstruction * instr)1426 std::string HloDotDumper::GetInstructionTrivialComputationStr(
1427 const HloInstruction* instr) {
1428 // called_computations() on a fusion node "inherits" any called computations
1429 // of the fused root, which isn't what we want. Just ignore fusion nodes
1430 // here; they're handled separately.
1431 if (instr->opcode() == HloOpcode::kFusion) {
1432 return "";
1433 }
1434
1435 std::vector<std::string> lines;
1436 for (int64_t i = 0; i < instr->called_computations().size(); ++i) {
1437 optional<std::string> computation_type =
1438 MatchTrivialComputation(instr->called_computations()[i]);
1439 if (!computation_type) {
1440 continue;
1441 }
1442 if (instr->called_computations().size() == 1) {
1443 lines.push_back(StrFormat("Subcomputation: <b>%s</b>",
1444 HtmlLikeStringSanitize(*computation_type)));
1445 } else {
1446 lines.push_back(StrFormat("Subcomputation %d: <b>%s</b>", i,
1447 HtmlLikeStringSanitize(*computation_type)));
1448 }
1449 }
1450 return StrJoin(lines, "<br/>");
1451 }
1452
// Maps `instr` to the instruction whose node an edge should attach to: a
// get-tuple-element is replaced by its operand, and a fusion whose
// subcomputation is shown is replaced by its fused expression root (repeatedly,
// in case that root is itself such a fusion).
const HloInstruction* HloDotDumper::GetNodeForEdge(
    const HloInstruction* instr) {
  // Skip over a get-tuple-element node (one level only).
  const HloInstruction* node =
      instr->opcode() == HloOpcode::kGetTupleElement ? instr->operand(0)
                                                     : instr;
  for (;;) {
    const bool is_inlined_fusion = node->opcode() == HloOpcode::kFusion &&
                                   ShouldShowFusionSubcomputation(node);
    if (!is_inlined_fusion) {
      return node;
    }
    node = node->fused_expression_root();
  }
}
1465
1466 // Gets a NodeFilter that includes roughly all instructions whose distance from
1467 // root is <= radius.
MakeNodeRadiusAroundFilter(const HloInstruction * root,int64_t radius,const absl::flat_hash_set<const HloInstruction * > & boundary)1468 NodeFilter MakeNodeRadiusAroundFilter(
1469 const HloInstruction* root, int64_t radius,
1470 const absl::flat_hash_set<const HloInstruction*>& boundary) {
1471 // First, find the neighborhood of nodes with distance from root <= radius.
1472 // These nodes are our initial set of "normal" nodes.
1473 absl::flat_hash_map<const HloInstruction*, NodeFilterResult> nodes;
1474 std::deque<std::pair<const HloInstruction*, /*depth*/ int64_t>> worklist;
1475 worklist.push_back({root, 0});
1476 while (!worklist.empty()) {
1477 const HloInstruction* instr;
1478 int64_t depth;
1479 std::tie(instr, depth) = worklist.front();
1480 worklist.pop_front();
1481
1482 nodes[instr] = kNormalNode;
1483 if (depth == radius) {
1484 continue;
1485 }
1486 if (boundary.contains(instr)) {
1487 continue;
1488 }
1489
1490 // Traverse into instr's operands.
1491 //
1492 // Don't traverse into tuples' operands unless the tuple is the root.
1493 // Usually a tuple is the bottommost node in the graph, and so its operands
1494 // are not interesting to the graph at hand.
1495 if (instr == root || instr->opcode() != HloOpcode::kTuple) {
1496 for (const HloInstruction* operand : instr->operands()) {
1497 if (!nodes.contains(operand)) {
1498 worklist.push_back({operand, depth + 1});
1499 }
1500 }
1501 }
1502
1503 // Traverse into instr's nested computations.
1504 for (const HloComputation* computation : instr->called_computations()) {
1505 worklist.push_back({computation->root_instruction(), depth + 1});
1506 }
1507
1508 // Traverse into instr's users, unless:
1509 //
1510 // - there are a ton of them, in which case they're probably not
1511 // interesting (and anyway, rendering them all would make the graph
1512 // unreadable), or
1513 // - instr is a constant, in which case its users are probably not
1514 // interesting.
1515 if (instr->opcode() == HloOpcode::kConstant) {
1516 continue;
1517 }
1518 constexpr int kMaxUsersToRender = 16;
1519 if (instr->user_count() > kMaxUsersToRender) {
1520 // If we're going to skip this node's users, style it as such.
1521 nodes[instr] = kSomeUsersOmitted;
1522 continue;
1523 }
1524 for (const HloInstruction* user : instr->users()) {
1525 if (!nodes.contains(user)) {
1526 worklist.push_back({user, depth + 1});
1527 }
1528 }
1529 }
1530
1531 auto is_displayed = [&](const HloInstruction* instr) {
1532 // Constants are displayed inline with their users; they're never omitted.
1533 // Nodes in subcomputations are always shown.
1534 return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant ||
1535 instr->parent() != root->parent();
1536 };
1537
1538 // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we
1539 // know which nodes will be included in the graph.
1540 for (auto& kv : nodes) {
1541 const HloInstruction* instr = kv.first;
1542 NodeFilterResult& filter_result = kv.second;
1543 const auto& operands = instr->operands();
1544
1545 if (absl::c_any_of(operands, is_displayed) &&
1546 !absl::c_all_of(operands, is_displayed)) {
1547 // Mark nodes with some operands omitted appropriately.
1548 filter_result = kSomeOperandsOmitted;
1549 } else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) {
1550 // Mark nodes with *all* operands omitted appropriately.
1551 filter_result = kOmitNodeOperands;
1552 }
1553
1554 // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
1555 // users made it into the graph.
1556 if (filter_result == kSomeUsersOmitted &&
1557 absl::c_all_of(instr->users(), is_displayed)) {
1558 filter_result = kNormalNode;
1559 }
1560 }
1561
1562 // Highlight the root node.
1563 nodes[root] = kHighlightNode;
1564
1565 return NodeFilter([=](const HloInstruction* instr) {
1566 auto it = nodes.find(instr);
1567 if (it != nodes.end()) {
1568 return it->second;
1569 }
1570 // Show all nodes in subcomputations.
1571 if (instr->parent() != root->parent()) {
1572 return kNormalNode;
1573 }
1574 return kHideNode;
1575 });
1576 }
1577
1578 // Gets a node filter that includes nodes on all paths from `from` to `to`. If
1579 // the all-paths set contains more than max_nodes elements, includes the nodes
1580 // on the shortest paths and sets hit_limit to true.
MakeNodeFromToFilter(const HloInstruction * from,const HloInstruction * to,int64_t max_nodes,bool * hit_limit)1581 NodeFilter MakeNodeFromToFilter(const HloInstruction* from,
1582 const HloInstruction* to, int64_t max_nodes,
1583 bool* hit_limit) {
1584 *hit_limit = false;
1585
1586 // Elements in the queue are paths through the graph.
1587 std::deque<std::vector<const HloInstruction*>> queue;
1588 queue.push_front({from});
1589
1590 // Compute the set of nodes we want to show using a slightly-modified
1591 // Djikstra's algorithm. The only real difference is, rather than stopping
1592 // when we find a (shortest) path, we continue until we've found max_nodes
1593 // nodes on some path.
1594 absl::flat_hash_set<const HloInstruction*> visited;
1595 absl::flat_hash_set<const HloInstruction*> to_display = {from, to};
1596 while (!queue.empty() && to_display.size() < max_nodes) {
1597 std::vector<const HloInstruction*> path = std::move(queue.front());
1598 queue.pop_front();
1599 if (!visited.insert(path.back()).second) {
1600 continue;
1601 }
1602
1603 for (const auto* user : path.back()->users()) {
1604 if (user == to) {
1605 auto it = path.begin();
1606 for (; it != path.end() && to_display.size() < max_nodes; ++it) {
1607 to_display.insert(*it);
1608 }
1609 if (it != path.end()) {
1610 *hit_limit = true;
1611 }
1612 } else if (!visited.count(user)) {
1613 auto new_path = path;
1614 new_path.push_back(user);
1615 queue.push_back(std::move(new_path));
1616 }
1617 }
1618 }
1619
1620 return NodeFilter([=](const HloInstruction* instr) {
1621 if (instr == from || instr == to) {
1622 return kHighlightNode;
1623 }
1624 return to_display.count(instr) ? kNormalNode : kHideNode;
1625 });
1626 }
1627
// HTML <script> tags that WrapDotInHtml substitutes for the "$JS_INCLUDE"
// placeholder of its page template. They load viz.js, its full.render.js
// companion, svg-pan-zoom and @hpcc-js/wasm from cdn.jsdelivr.net, each with a
// subresource-integrity hash.
// NOTE(review): the @hpcc-js/wasm URL carries no version while its integrity
// hash is fixed -- confirm the hash still matches what the CDN serves.
static const char* kRenderDotJS = R"(
<!-- Integrity hash is generated by https://www.srihash.org/ -->
<script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/viz.js"
integrity="sha384-aD1MJYb0WKIUT+CtwJp5LTuV3U4pLAS6B/nUxL7ECimC2pN9N8vjlMr/yQCAkzxE"
crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/full.render.js"
integrity="sha384-bAixY275aIpCj6Te19y0MILZ4V+VEC8CVFujFEH+Lf7W+4XYYeYLwW5IBI6yQmMT"
crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/svg-pan-zoom@3.6.0/dist/svg-pan-zoom.min.js"
integrity="sha384-3008WpYB2pOBvE7lwkrKf+qTmbTPGGPYxA9C1YVhvbPukns4ZFj7E98QPLkNW9dS"
crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/@hpcc-js/wasm/dist/index.min.js"
integrity="sha384-X+8WXyWZ+W2gUHiSSj0aePAkE77Fl6eZ+QIByw+Ii8LzWEJ/W8bI8M4RkneDAJ4D"
crossorigin="anonymous"></script>
)";
1643
// Wraps DOT text in a standalone HTML page that renders it in the browser.
//
// The result is html_prefix + dot + html_suffix. The prefix ends right after
// opening a JavaScript template literal (var data = `), so the DOT text
// becomes the contents of that literal, and the suffix closes it. The script
// in the suffix strips the stylesheet out of the DOT text, renders the rest to
// SVG with Viz, re-attaches the stylesheet as an SVG <style> element, and adds
// pan/zoom controls plus "save" links.
//
// NOTE(review): `dot` is inserted verbatim. A backtick or "${" inside the DOT
// text would appear to end or alter the template literal -- confirm that
// generated DOT never contains them.
std::string WrapDotInHtml(absl::string_view dot) {
  // Page head and the opening of the script; "$JS_INCLUDE" is replaced by the
  // <script> tags in kRenderDotJS.
  std::string html_prefix =
      absl::StrReplaceAll(R"html(
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style type="text/css">
body {
height: 100vh;
margin: 0;
}
</style>
</head>
<body>
$JS_INCLUDE
<div id="container" style="height:95vh; border:1px solid black; "></div>
<script>
var data = `
)html",
                          {{"$JS_INCLUDE", kRenderDotJS}});

  // Closes the template literal opened by the prefix and carries the
  // rendering script.
  static const char html_suffix[] = R"html(
`;
var cssregex = new RegExp('stylesheet=<([^]*)\n>\n', 'gm');
var results = cssregex.exec(data)
// graphviz has problem dealing with large stylesheets.
// https://github.com/tensorflow/tensorflow/issues/17220#issuecomment-369228492
// In order to avoid the problem, remove the stylesheet from the dot and
// insert it directly info the rendered SVG.
var dot_data = data;
var css_data = ''
if (results !== null) {
css_data = results[1].replace(/\s*data:.*\s*,/,''); // Strip content-type field.
// CSS inside DOT is URL-escaped, so we must unescape it
// before we can insert it into SVG.
css_data = unescape(css_data);
dot_data = data.replace(cssregex, ''); // Remove the stylesheet
}

var render_start = performance.now()
function add_controls(svg) {
var htmlblob = new Blob([document.documentElement.innerHTML],
{type: 'text/html'});
var savehtml = document.createElement('a');
savehtml.setAttribute('href', URL.createObjectURL(htmlblob));
savehtml.setAttribute('download', 'graph.html');
savehtml.innerHTML = " [Save HTML+SVG] ";
document.body.append(savehtml);
var svgblob = new Blob([svg.outerHTML], {type: 'image/svg'});
var savesvg = document.createElement('a');
savesvg.setAttribute('href', URL.createObjectURL(svgblob));
savesvg.setAttribute('download', 'graph.svg');
savesvg.innerHTML = " [Save SVG] ";
document.body.append(savesvg);
var dotblob = new Blob([data], {type: 'text/dot'});
var savedot = document.createElement('a');
savedot.setAttribute('href', URL.createObjectURL(dotblob));
savedot.setAttribute('download', 'graph.dot');
savedot.innerHTML = " [Save DOT] ";
document.body.append(savedot);
// Will get called after embed element was loaded
var panzoom = svgPanZoom(svg, {
zoomEnabled: true,
controlIconsEnabled: true,
});
document.getElementsByTagName("BODY")[0].onresize = function() {
panzoom.resize();
panzoom.fit();
panzoom.center();
};
var render_end = performance.now();
var render_note = document.createElement('div')
render_note.innerHTML = 'Rendering took '
+ (render_end - render_start).toFixed(2) + "ms."
document.body.append(render_note);
}
var svg = document.getElementById('graph')
if (svg == null) {
// Need to render SVG first.
var viz = new Viz();
viz.renderSVGElement(dot_data)
.then(function(svg){
var container = document.getElementById('container')
var style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
var node = document.createTextNode(css_data);
style.appendChild(node);
svg.setAttribute('width', '100%');
svg.setAttribute('height', '100%');
svg.setAttribute('id', 'graph');
svg.appendChild(style);
container.appendChild(svg);
add_controls(svg);
})
} else {
// HTML already has rendered SVG embedded, so we just need to add
// controls.
add_controls(svg);
}
</script>
</body>
</html>
)html";

  return absl::StrCat(html_prefix, dot, html_suffix);
}
1750
// Guards `url_renderer`.
absl::Mutex url_renderer_mu(absl::kConstInit);
// Callback that converts DOT text into a URL (or an error status). It is
// nullptr until RegisterGraphToURLRenderer installs one, and is invoked by
// WrapDotInFormat for RenderedGraphFormat::kUrl.
std::function<StatusOr<std::string>(absl::string_view)>* url_renderer
    ABSL_GUARDED_BY(url_renderer_mu) = nullptr;
1754
// Mutex guarding the fusion visualization storage (fusion_visualizer_states),
// which maps (module_id, computation_id) -> sequence of fusion states.
absl::Mutex fusion_visualizer_state_mu(absl::kConstInit);
1758 namespace {
1759
1760 // Fusion state: a sequence of rendered graphs in DOT formats with explanations.
1761 // Rendered graphs can be shared across frames, hence the storage indirection.
1762 struct FusionVisualizerProgress {
1763 // Creates a frame with a new rendered graph.
AddStatexla::__anon758a858c0111::__anon758a858c0f11::FusionVisualizerProgress1764 void AddState(absl::string_view dot, absl::string_view explanation,
1765 std::optional<std::string> to_highlight) {
1766 if (dot_graphs.empty() || dot_graphs.back() != dot) {
1767 dot_graphs.push_back(std::string(dot));
1768 }
1769 frames.push_back({static_cast<int>(dot_graphs.size() - 1),
1770 std::string(explanation), to_highlight.value_or("")});
1771 }
1772
1773 std::vector<std::string> dot_graphs;
1774
1775 struct FusionFrame {
1776 int dot_graph;
1777 std::string label;
1778 std::string to_highlight;
1779 };
1780
1781 std::vector<FusionFrame> frames;
1782 };
1783
1784 } // namespace
1785
// Fusion visualizer progress per computation, keyed by a pair of 64-bit ids
// (see FusionVisualizerStateKey). The map is allocated with `new` and bound to
// a reference; every access must hold fusion_visualizer_state_mu.
static auto& fusion_visualizer_states
    TF_GUARDED_BY(fusion_visualizer_state_mu) = *new absl::flat_hash_map<
        std::pair<int64_t, int64_t>, FusionVisualizerProgress>();
1789
1790 // Generates a key to the fusion visualizer state mapping.
1791 static std::pair<int, int> FusionVisualizerStateKey(
1792 const HloComputation& computation) {
1793 return std::make_pair(computation.parent()->unique_id(),
1794 computation.unique_id());
1795 }
1796
1797 // Precondition: (url_renderer != nullptr || format != kUrl).
1798 //
1799 // (We specify this as a precondition rather than checking it in here and
1800 // returning an error because we want to fail quickly when there's no URL
1801 // renderer available, and this function runs only after we've done all the work
1802 // of producing dot for the graph.)
1803 StatusOr<std::string> WrapDotInFormat(const HloComputation& computation,
1804 absl::string_view dot,
1805 RenderedGraphFormat format)
1806 ABSL_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
1807 switch (format) {
1808 case RenderedGraphFormat::kUrl:
1809 CHECK(url_renderer != nullptr)
1810 << "Should have checked url_renderer != null before calling.";
1811 return (*url_renderer)(dot);
1812 case RenderedGraphFormat::kHtml:
1813 return WrapDotInHtml(dot);
1814 case RenderedGraphFormat::kDot:
1815 return std::string(dot);
1816 }
1817 }
1818
1819 } // namespace
1820
1821 // Compress with zlib + b64 encode.
1822 static StatusOr<std::string> CompressAndEncode(absl::string_view input) {
1823 class WritableStringFile : public tensorflow::WritableFile {
1824 public:
1825 explicit WritableStringFile(std::string* data) : data_(data){};
1826 ~WritableStringFile() override = default;
1827
1828 Status Append(absl::string_view data) override {
1829 absl::StrAppend(data_, data);
1830 return OkStatus();
1831 }
1832
1833 Status Close() override { return OkStatus(); }
1834 Status Flush() override { return OkStatus(); }
1835 Status Sync() override { return OkStatus(); }
1836
1837 private:
1838 std::string* data_;
1839 };
1840
1841 std::string compressed;
1842 WritableStringFile f(&compressed);
1843
1844 auto gz_opts = tensorflow::io::ZlibCompressionOptions::GZIP();
1845 tensorflow::io::ZlibOutputBuffer gz_file(&f, gz_opts.input_buffer_size,
1846 gz_opts.output_buffer_size, gz_opts);
1847 TF_RETURN_IF_ERROR(gz_file.Init());
1848 TF_RETURN_IF_ERROR(gz_file.Append(input));
1849 TF_RETURN_IF_ERROR(gz_file.Close());
1850
1851 std::string encoded;
1852 TF_RETURN_IF_ERROR(tensorflow::Base64Encode(compressed, &encoded));
1853 return absl::StrReplaceAll(encoded, {{"_", "/"}, {"-", "+"}});
1854 }
1855
1856 static std::string EscapeJSONString(absl::string_view raw) {
1857 return absl::StrCat(
1858 "\"",
1859 absl::StrReplaceAll(raw, {{"\n", "\\n"}, {"\"", "\\\""}, {"\\", "\\\\"}}),
1860 "\"");
1861 }
1862
WrapFusionExplorer(const HloComputation & computation)1863 StatusOr<std::string> WrapFusionExplorer(const HloComputation& computation) {
1864 absl::MutexLock lock(&fusion_visualizer_state_mu);
1865 using absl::StrAppend;
1866 using absl::StrFormat;
1867 using absl::StrJoin;
1868 const FusionVisualizerProgress& visualizer_progress =
1869 fusion_visualizer_states[FusionVisualizerStateKey(computation)];
1870 if (visualizer_progress.frames.empty()) {
1871 return InternalError("Empty");
1872 }
1873
1874 std::string dot_graphs =
1875 StrFormat("[%s]", StrJoin(visualizer_progress.dot_graphs, ", ",
1876 [&](std::string* out, const std::string& dot) {
1877 StrAppend(out, EscapeJSONString(dot));
1878 }));
1879
1880 std::string frames = StrJoin(
1881 visualizer_progress.frames, ", ", [&](std::string* out, const auto& p) {
1882 StrAppend(out, StrFormat("[%d, %s, %s]", p.dot_graph,
1883 EscapeJSONString(p.label),
1884 EscapeJSONString(p.to_highlight)));
1885 });
1886
1887 TF_ASSIGN_OR_RETURN(std::string dot_graphs_compressed,
1888 CompressAndEncode(dot_graphs));
1889
1890 return absl::StrReplaceAll(
1891 R"(
1892 <!DOCTYPE html>
1893 <html>
1894 <head>
1895 <meta charset="utf-8">
1896 <style>
1897 html, body {height: 100%; text-align: center;}
1898 #rendered {height: 70%; width: 80%; border:1px solid black; margin: auto; }
1899 #label {width: 80%; margin: auto;}
1900 #performance_note { font-size: small; color: gray; }
1901 #frames_list {
1902 list-style: none; text-align: left; height: 20%; overflow: scroll;
1903 }
1904 #frames_list li { padding: 0.2em; margin: 0.2em; }
1905 .selected { background-color: #e0e0e0; }
1906 .selected a { color: black; text-decoration: none; }
1907 #rendered svg { height: 100% !important; width: 100% !important; }
1908 </style>
1909 </head>
1910 <body>
1911 <script src="https://www.gstatic.com/external_hosted/hpcc_js_wasm/index.min.js"
1912 integrity="sha384-LigJPbR3TOfU/Xbb+PjiN1dGJYPweLk7kiGnaMgmxnUmKWaCFKbb5tH6iLlyVhPZ"
1913 crossorigin="anonymous"></script>
1914 <script src="https://www.gstatic.com/external_hosted/svg_pan_zoom/svg-pan-zoom.js">
1915 </script>
1916
1917 <title>Fusion Explorer: $TITLE</title>
1918 <div id='rendered'><center>Loading...</center></div>
1919 <ul id='frames_list'></ul>
1920 <p>Use j/k for keyboard navigation.</p>
1921 <p id='performance_note'>Loading data...</p>
1922 <script>
1923 <!--
1924 const renderCache = {};
1925
1926 const cssregex = new RegExp('stylesheet=<([^]*)\n>\n', 'gm');
1927 const hpccWasm = window["@hpcc-js/wasm"];
1928
1929 const getIdFromHash = () => {
1930 let hash = window.location.hash;
1931 if (hash.indexOf('frame') == -1) {
1932 return 0;
1933 }
1934 return parseInt(window.location.hash.substring('#frame'.length, window.location.hash.length));
1935 }
1936
1937 const renderCurrentFrame = () => {
1938 if (!window.loaded) { return; }
1939 const frames_list = document.getElementById('frames_list');
1940 const currId = getIdFromHash();
1941
1942 for (let selected of frames_list.getElementsByClassName('selected')) {
1943 selected.classList.remove('selected');
1944 }
1945
1946 const selected = frames_list.children[currId];
1947 selected.classList.add('selected');
1948 selected.scrollIntoView();
1949
1950 const frame = frames[currId];
1951 const dot_ptr = frame[0];
1952 let dot_txt = window.dots[dot_ptr];
1953 const label = frame[1];
1954 document.getElementById('performance_note').innerText = "Rendering...";
1955 const results = cssregex.exec(dot_txt)
1956 let css_data = ''
1957 if (results !== null) {
1958 css_data = results[1].replace(/\s*data:.*\s*,/,''); // Strip content-type field.
1959 // CSS inside DOT is URL-escaped, so we must unescape it
1960 // before we can insert it into SVG.
1961 css_data = unescape(css_data);
1962 dot_txt = dot_txt.replace(cssregex, ''); // Remove the stylesheet
1963 }
1964
1965 let render_start = performance.now();
1966 const render_callback = svg => {
1967 renderCache[dot_ptr] = svg;
1968 var area = document.getElementById('rendered');
1969 area.innerHTML = `${svg}<style>${css_data}</style>`;
1970 var panzoom = svgPanZoom(area.children[0], {
1971 zoomEnabled: true, controlIconsEnabled: true, });
1972 var to_highlight = frame[2].length ?
1973 document.querySelector(`${frame[2]}`) : null;
1974 if (to_highlight) {
1975 to_highlight.style.setProperty('fill', 'red');
1976 }
1977 document.getElementById('performance_note').innerText =
1978 `Rendering took ${(performance.now() - render_start).toFixed(2)}ms`;
1979 };
1980 if (renderCache[dot_ptr]) {
1981 render_callback(renderCache[dot_ptr]);
1982 } else {
1983 hpccWasm.graphviz.layout(dot_txt, "svg", "dot").then(render_callback);
1984 }
1985 };
1986
1987 const update = (delta) => {
1988 let currId = getIdFromHash();
1989 currId = (currId + delta + frames.length) % frames.length;
1990 window.location.hash = `#frame${currId}`
1991 };
1992
1993 const renderFrameList = () => {
1994 const currId = getIdFromHash();
1995 const frames_list = document.getElementById('frames_list');
1996 for (let i=0; i<frames.length; i++) {
1997 const f = frames[i];
1998 let frame_descr = f[1];
1999 const rendered = document.createElement("li");
2000 if (frame_descr == "") {
2001 frame_descr = "Unnamed state";
2002 }
2003 rendered.innerHTML = `<a href="#frame${i}">${frame_descr}</a>`;
2004 if (i == currId) {
2005 rendered.classList.add('selected');
2006 }
2007 frames_list.appendChild(rendered);
2008 }
2009 };
2010
2011 const decompress = async function(compressed) {
2012 const ds = new DecompressionStream('gzip');
2013 const in_fetch = await fetch(`data:application/octet-stream;base64,${compressed}`);
2014 const in_blob = await in_fetch.blob();
2015 const out_stream = in_blob.stream().pipeThrough(ds);
2016 const out_blob = await new Response(out_stream).blob();
2017 return await out_blob.text();
2018 }
2019
2020 const dots_compressed = "$DOTS";
2021 const frames = [$FRAMES];
2022 let loaded = false;
2023
2024 window.addEventListener('hashchange', () => {
2025 renderCurrentFrame();
2026 });
2027
2028 window.addEventListener("keydown", (event) => {
2029 if (event.defaultPrevented) {
2030 return;
2031 }
2032 if (event.key == "j") {
2033 update(1);
2034 } else if (event.key == "k") {
2035 update(-1);
2036 } else {
2037 return;
2038 }
2039 event.preventDefault();
2040 }, true);
2041
2042 document.addEventListener("DOMContentLoaded", () => {
2043 decompress(dots_compressed).then(text => {
2044 window.dots = JSON.parse(text);
2045 window.loaded = true;
2046 renderFrameList();
2047 renderCurrentFrame();
2048 });
2049 });
2050
2051 //-->
2052 </script>
2053 </body>
2054 </html>
2055 )",
2056 {{"$DOTS", dot_graphs_compressed},
2057 {"$FRAMES", frames},
2058 {"$TITLE",
2059 absl::StrCat(computation.parent()->name(), "_", computation.name())}});
2060 }
2061
2062 void RegisterGraphToURLRenderer(
2063 std::function<StatusOr<std::string>(absl::string_view)> renderer) {
2064 absl::MutexLock lock(&url_renderer_mu);
2065 if (url_renderer != nullptr) {
2066 LOG(WARNING) << "Multiple calls to RegisterGraphToURLRenderer. Last call "
2067 "wins, but because order of initialization in C++ is "
2068 "nondeterministic, this may not be what you want.";
2069 }
2070 delete url_renderer;
2071 url_renderer = new std::function<StatusOr<std::string>(absl::string_view)>(
2072 std::move(renderer));
2073 }
2074
2075 void RegisterFusionState(const HloComputation& computation,
2076 absl::string_view label,
2077 const HloInstruction& consumer,
2078 const HloInstruction* producer) {
2079 absl::MutexLock lock(&fusion_visualizer_state_mu);
2080 FusionVisualizerProgress& fusion_progress =
2081 fusion_visualizer_states[FusionVisualizerStateKey(computation)];
2082
2083 // Radius size in which to render.
2084 static constexpr int kRenderRadius = 4;
2085
2086 absl::flat_hash_set<const HloInstruction*> render_boundary;
2087 for (const HloInstruction* user : consumer.users()) {
2088 render_boundary.insert(user);
2089 }
2090
2091 HloDotDumper dumper(
2092 consumer.parent(),
2093 StrCat("Rendering of ", kRenderRadius, " nodes around fusion consumer"),
2094 consumer.GetModule()->config().debug_options(), {}, /*profile=*/nullptr,
2095 MakeNodeRadiusAroundFilter(&consumer, kRenderRadius, render_boundary));
2096 std::string dot_txt = dumper.Dump();
2097
2098 std::optional<std::string> producer_to_highlight;
2099 if (producer) {
2100 producer_to_highlight = dumper.CssIdForInstruction(*producer);
2101 }
2102
2103 fusion_progress.AddState(dot_txt, label, producer_to_highlight);
2104 }
2105
2106 StatusOr<std::string> RenderGraph(
2107 const HloComputation& computation, absl::string_view label,
2108 const DebugOptions& debug_options, RenderedGraphFormat format,
2109 const HloExecutionProfile* hlo_execution_profile,
2110 HloRenderOptions hlo_render_options) {
2111 absl::MutexLock lock(&url_renderer_mu);
2112 if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
2113 return Unavailable("Can't render as URL; no URL renderer was registered.");
2114 }
2115
2116 std::string rendered_dot =
2117 HloDotDumper(&computation, label, debug_options, hlo_render_options,
2118 hlo_execution_profile, NodeFilter())
2119 .Dump();
2120 return WrapDotInFormat(computation, rendered_dot, format);
2121 }
2122
2123 StatusOr<std::string> RenderNeighborhoodAround(
2124 const HloInstruction& node, int radius, RenderedGraphFormat format,
2125 HloRenderOptions hlo_render_options,
2126 const absl::flat_hash_set<const HloInstruction*>& boundary) {
2127 absl::MutexLock lock(&url_renderer_mu);
2128 if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
2129 return FailedPrecondition(
2130 "Can't render as URL; no URL renderer was registered.");
2131 }
2132
2133 std::string label =
2134 StrCat("Neighborhood of ", radius, " nodes around ", node.name());
2135 std::string rendered_dot =
2136 HloDotDumper(node.parent(), label,
2137 node.GetModule()->config().debug_options(),
2138 hlo_render_options, /*profile=*/nullptr,
2139 MakeNodeRadiusAroundFilter(&node, radius, boundary))
2140 .Dump();
2141 return WrapDotInFormat(*node.parent(), rendered_dot, format);
2142 }
2143
2144 StatusOr<std::string> RenderAllPathsFromTo(
2145 const HloInstruction& from, const HloInstruction& to, int64_t max_nodes,
2146 RenderedGraphFormat format, HloRenderOptions hlo_render_options) {
2147 absl::MutexLock lock(&url_renderer_mu);
2148 if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
2149 return FailedPrecondition(
2150 "Can't render as URL; no URL renderer was registered.");
2151 }
2152
2153 CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!";
2154 auto debug_options = from.GetModule()->config().debug_options();
2155
2156 bool hit_limit = false;
2157 NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit);
2158 std::string label;
2159 if (!hit_limit) {
2160 label = StrCat("All paths from ", from.name(), " to ", to.name());
2161 } else {
2162 label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(),
2163 " to ", to.name(),
2164 "<br/><br/>***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN "
2165 "NODES***<br/><br/>");
2166 }
2167 std::string rendered_dot =
2168 HloDotDumper(from.parent(), label, debug_options, hlo_render_options,
2169 /*profile=*/nullptr, filter)
2170 .Dump();
2171 return WrapDotInFormat(*from.parent(), rendered_dot, format);
2172 }
2173
2174 } // namespace xla
2175