1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
17
18 #include <unistd.h>
19
20 #include <algorithm>
21 #include <atomic>
22 #include <deque>
23 #include <map>
24 #include <memory>
25 #include <queue>
26 #include <string>
27 #include <tuple>
28 #include <vector>
29
30 #include "absl/container/flat_hash_map.h"
31 #include "absl/strings/match.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_format.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/str_replace.h"
36 #include "absl/types/optional.h"
37 #include "tensorflow/compiler/xla/layout_util.h"
38 #include "tensorflow/compiler/xla/literal.h"
39 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
40 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
41 #include "tensorflow/compiler/xla/service/hlo_module.h"
42 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
43 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
44 #include "tensorflow/compiler/xla/shape_util.h"
45 #include "tensorflow/compiler/xla/types.h"
46 #include "tensorflow/compiler/xla/util.h"
47 #include "tensorflow/compiler/xla/window_util.h"
48 #include "tensorflow/core/lib/core/status.h"
49 #include "tensorflow/core/lib/gtl/map_util.h"
50 #include "tensorflow/core/lib/io/path.h"
51 #include "tensorflow/core/lib/strings/numbers.h"
52 #include "tensorflow/core/platform/env.h"
53 #include "tensorflow/core/platform/mutex.h"
54 #include "tensorflow/core/platform/protobuf.h"
55 #include "tensorflow/core/platform/regexp.h"
56
57 namespace xla {
58 namespace {
59
60 using absl::nullopt;
61 using absl::optional;
62 using absl::StrAppend;
63 using absl::StrCat;
64 using absl::StrFormat;
65 using absl::StrJoin;
66
67 // Used to indicate how we should treat a given HLOInstruction in the graph.
68 // should we treat it like normal, hide it, and so on?
69 enum NodeFilterResult {
70 kNormalNode,
71 kHideNode,
72 // Make the node easy to find in the final graph.
73 kHighlightNode,
74 // "Gray out" the node to indicate that some of its operands have been
75 // omitted.
76 kSomeOperandsOmitted,
77 // Style the node the same as kSomeOperandsOmitted, but also don't connect it
78 // to its operands, even if they're present in the graph.
79 kOmitNodeOperands,
80 // Same style as kSomeOperandsOmitted, but used to indicate that some of the
81 // node's *users* have been omitted.
82 kSomeUsersOmitted,
83 };
84
85 // NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
86 // It lets callers tell the graph-drawing routines which nodes they want to be
87 // shown, hidden, or highlighted.
88 class NodeFilter {
89 public:
__anon973ba4530202(const HloInstruction*) 90 NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}
91
NodeFilter(std::function<NodeFilterResult (const HloInstruction * instr)> filter)92 explicit NodeFilter(
93 std::function<NodeFilterResult(const HloInstruction* instr)> filter)
94 : filter_(std::move(filter)) {}
95
Show(const HloInstruction * instr) const96 bool Show(const HloInstruction* instr) const {
97 return filter_(instr) != kHideNode;
98 }
Highlight(const HloInstruction * instr) const99 bool Highlight(const HloInstruction* instr) const {
100 return filter_(instr) == kHighlightNode;
101 }
OmitOperands(const HloInstruction * instr) const102 bool OmitOperands(const HloInstruction* instr) const {
103 return filter_(instr) == kOmitNodeOperands;
104 }
SomeOrAllOperandsOmitted(const HloInstruction * instr) const105 bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const {
106 auto result = filter_(instr);
107 return result == kOmitNodeOperands || result == kSomeOperandsOmitted;
108 }
Deemphasized(const HloInstruction * instr) const109 bool Deemphasized(const HloInstruction* instr) const {
110 auto result = filter_(instr);
111 return result == kOmitNodeOperands || result == kSomeOperandsOmitted ||
112 result == kSomeUsersOmitted;
113 }
114
115 private:
116 std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
117 };
118
119 // We arbitrarily set this as the boundary between "large" and "small"
120 // instructions.
IsSmall(const HloInstruction * instr)121 bool IsSmall(const HloInstruction* instr) {
122 if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE_TYPE) ||
123 ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
124 return true;
125 }
126 return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
127 }
128
// Node color schemes, used by NodeColorAttributes.
enum ColorScheme {
  kBlue,
  kBrown,
  kDarkBlue,
  kDarkGreen,
  kDarkOrange,
  kDarkRed,
  kGray,
  kGreen,
  kOrange,
  kPurple,
  kRed,
  kWhite,
  kYellow,

  // Causes the node's border to be a dashed line, and its content to be gray
  // text on a white background, suggesting that this is an "unimportant" node.
  kDashedBorder,
};

// Graphviz attributes/colors that make up a color scheme.
struct NodeColors {
  const char* style;         // graphviz "style" attribute
  const char* fill_color;    // graphviz "fillcolor" attribute
  const char* stroke_color;  // graphviz "color" attribute (the border)
  const char* font_color;    // graphviz "fontcolor" attribute
};

// Returns the style and colors for `color`. Colors are from
// https://material.io/color.
NodeColors NodeColorsForScheme(ColorScheme color) {
  switch (color) {
    case kBlue:
      return NodeColors{"filled", "#bbdefb", "#8aacc8", "black"};
    case kBrown:
      return NodeColors{"filled", "#bcaaa4", "#8c7b75", "black"};
    case kDarkBlue:
      return NodeColors{"filled", "#1565c0", "#003c8f", "white"};
    case kDarkGreen:
      return NodeColors{"filled", "#2e7d32", "#005005", "white"};
    case kDarkOrange:
      // This is more of a "medium" orange, made to look close to kOrange;
      // there's probably room for a darker weight if desired.
      return NodeColors{"filled", "#ffb74d", "#c88719", "black"};
    case kDarkRed:
      return NodeColors{"filled", "#b71c1c", "#7f0000", "white"};
    case kGray:
      return NodeColors{"filled", "#cfd8dc", "#9ea7aa", "black"};
    case kGreen:
      return NodeColors{"filled", "#c8e6c9", "#97b498", "black"};
    case kOrange:
      return NodeColors{"filled", "#ffe0b2", "#cbae82", "black"};
    case kPurple:
      return NodeColors{"filled", "#e1bee7", "#af8eb5", "black"};
    case kRed:
      return NodeColors{"filled", "#ffcdd2", "#cb9ca1", "black"};
    case kWhite:
      return NodeColors{"filled", "white", "black", "black"};
    case kYellow:
      return NodeColors{"filled", "#fff9c4", "#cbc693", "black"};
    case kDashedBorder:
      // "filled,dashed" looks the same as "dashed", since we have a white
      // background. But we use "filled,dashed" so that when you hover over
      // any part of the node (not just the text inside the node), our css
      // :hover rule is triggered.
      return NodeColors{"filled,dashed", "white", "#757575", "#757575"};
  }
  // Only reachable for a value that is not one of the enumerators. Without
  // this return, such a value would fall off the end of a non-void function
  // (undefined behavior; also a -Wreturn-type warning). Fall back to the plain
  // white scheme.
  return NodeColors{"filled", "white", "black", "black"};
}
196
197 // Given a ColorScheme, returns an attribute string for a node of that color.
198 // Sets the node's style and fill/stroke/text colors.
199 //
200 // Colors are from https://material.io/color.
NodeColorAttributes(ColorScheme color)201 string NodeColorAttributes(ColorScheme color) {
202 NodeColors node_colors = NodeColorsForScheme(color);
203
204 return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")",
205 node_colors.style, node_colors.font_color,
206 node_colors.stroke_color, node_colors.fill_color);
207 }
208
209 // Replaces <> with <>, so that this string is safe(er) for use in a
210 // graphviz HTML-like string.
HtmlLikeStringSanitize(absl::string_view s)211 string HtmlLikeStringSanitize(absl::string_view s) {
212 return absl::StrReplaceAll(s, {{"<", "<"}, {">", ">"}});
213 }
214
IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction * instr)215 bool IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction* instr) {
216 namespace m = match;
217 return instr->parent()->IsFusionComputation() &&
218 Match(instr, m::Broadcast(m::ConstantEffectiveScalar()));
219 }
220
221 // Tries to generates a human-readable one-word description of the given
222 // computation.
223 //
224 // Currently we support:
225 //
226 // "return param0 + param1;" --> "add"
227 // "return param0 * param1;" --> "multiply"
228 // "return min(param0, param1);" --> "min"
229 // "return max(param0, param1);" --> "max"
230 // "return param0 <= param1;" --> "less-or-equal"
231 // "return param0 >= param1;" --> "greater-or-equal"
232 // "return param0 > param1;" --> "greater-than"
233 // "return param0 < param1;" --> "less-than"
234 // "return param0 == param1;" --> "equal-to"
235 // "return param0 != param1;" --> "not-equal-to"
236 //
237 // where param0 and param1 are effective scalars. For the ops that are
238 // commutative, we also support them with param0 and param1 swapped.
239 //
240 // This is useful primarily for reduce and map nodes. These take a
241 // subcomputation which is almost always one of the above, and pattern matching
242 // it to a short string lets us tell the user what the subcomputation is without
243 // drawing it as a graph.
MatchTrivialComputation(const HloComputation * computation)244 optional<string> MatchTrivialComputation(const HloComputation* computation) {
245 namespace m = match;
246
247 if (computation->instruction_count() != 3) {
248 return nullopt;
249 }
250 HloInstruction* root = computation->root_instruction();
251 const HloInstruction *param0, *param1;
252 if (!Match(root, m::Op()
253 .WithNumOperands(2)
254 .WithShape(m::Shape().IsEffectiveScalar())
255 .WithBinaryOperandsAnyOrder(
256 m::Parameter(¶m0, 0)
257 .WithShape(m::Shape().IsEffectiveScalar()),
258 m::Parameter(¶m1, 1)
259 .WithShape(m::Shape().IsEffectiveScalar())))) {
260 return nullopt;
261 }
262
263 // If the params are reversed (i.e. operand0 is param1 and operand1 is
264 // param0), check that the operation being performed is commutative.
265 if (root->operand(0) == param1) {
266 CHECK_EQ(root->operand(1), param0);
267 if (root->opcode() == HloOpcode()) {
268 switch (root->comparison_direction()) {
269 case ComparisonDirection::kLe:
270 case ComparisonDirection::kGe:
271 case ComparisonDirection::kGt:
272 case ComparisonDirection::kLt:
273 return nullopt;
274 default:
275 break;
276 }
277 }
278 }
279
280 // If we recognize the root's opcode, we've successfully pattern-matched!
281 switch (root->opcode()) {
282 case HloOpcode::kAdd:
283 return "add";
284 case HloOpcode::kMultiply:
285 return "multiply";
286 case HloOpcode::kMinimum:
287 return "min";
288 case HloOpcode::kMaximum:
289 return "max";
290 case HloOpcode::kCompare: {
291 switch (root->comparison_direction()) {
292 case ComparisonDirection::kLe:
293 return "less-or-equal";
294 case ComparisonDirection::kGe:
295 return "greater-or-equal";
296 case ComparisonDirection::kGt:
297 return "greater-than";
298 case ComparisonDirection::kLt:
299 return "less-than";
300 case ComparisonDirection::kEq:
301 return "equal-to";
302 case ComparisonDirection::kNe:
303 return "not-equal-to";
304 }
305 }
306 default:
307 return nullopt;
308 }
309 }
310
311 // Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax).
312 class HloDotDumper {
313 public:
HloDotDumper(const HloComputation * computation,absl::string_view label,const DebugOptions & debug_options,HloRenderOptions hlo_render_options,const HloExecutionProfile * profile,NodeFilter filter)314 HloDotDumper(const HloComputation* computation, absl::string_view label,
315 const DebugOptions& debug_options,
316 HloRenderOptions hlo_render_options,
317 const HloExecutionProfile* profile, NodeFilter filter)
318 : computation_(computation),
319 label_(label),
320 debug_options_(debug_options),
321 hlo_render_options_(hlo_render_options),
322 profile_(profile),
323 filter_(std::move(filter)) {}
324
325 string Dump();
326
327 private:
328 // Returns the dot graph identifier for the given instruction.
InstructionId(const HloInstruction * instruction)329 string InstructionId(const HloInstruction* instruction) {
330 return StrCat(reinterpret_cast<uint64>(instruction));
331 }
332
333 // Returns the dot graph identifier for the given computation.
SubcomputationId(const HloComputation * computation)334 string SubcomputationId(const HloComputation* computation) {
335 return StrCat("cluster_", reinterpret_cast<uint64>(computation));
336 }
337
338 // Generates graph header/footer. These should be called *after* dumping all
339 // of the instructions and subcomputations for the graph, as they both use
340 // data generated while dumping the graph.
341 string Header();
342 string Footer();
343
344 bool ShouldShowSubcomputation(const HloComputation* subcomp);
345 bool ShouldShowFusionSubcomputation(const HloInstruction* instr);
346
347 // We omit some nodes from the graph, instead drawing them inlined into the
348 // nodes that use them.
349 bool ShouldMergeIntoUsers(const HloInstruction* instr) const;
350
351 string DumpSubcomputation(const HloComputation* subcomp,
352 const HloInstruction* parent_instr);
353 string DumpComputation(const HloComputation* comp);
354 string DumpRootTag();
355 string DumpInstruction(const HloInstruction* instr);
356 ColorScheme GetInstructionColor(const HloInstruction* instr);
357 string GetInstructionNodeShape(const HloInstruction* instr);
358 string GetInstructionNodeLabel(const HloInstruction* instr);
359 string GetInstructionNodeMetadata(const HloInstruction* instr);
360 string GetInstructionNodeBackendConfig(const HloInstruction* instr);
361 string GetInstructionNodeExtraInfo(const HloInstruction* instr);
362 string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
363 void AddInstructionIncomingEdges(const HloInstruction* instr);
364
365 // For most instructions, GetNodeForEdge(instr) returns instr.
366 //
367 // The exception is fusion nodes. For these, we walk up the chain of nested
368 // fusion nodes starting at instr until we reach a node that either (a) isn't
369 // a fusion node, or (b) is a fusion node for which
370 // ShouldShowFusionSubcomputation is false.
371 //
372 // We do this because fusion nodes are expanded inline -- if
373 // ShouldShowFusionSubcomputation is true, the fusion node won't be present in
374 // the graph.
375 //
376 // In general when you want to draw an edge from A to B, you should actually
377 // draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B).
378 const HloInstruction* GetNodeForEdge(const HloInstruction* instr);
379
380 // If instr has just one computation and it's trivial (e.g. "return param0 +
381 // param1"), returns a string you can put into the node's body that names the
382 // subcomputation, e.g. "Subcomputation: <b>add</b>".
383 string GetInstructionTrivialComputationStr(const HloInstruction* instr);
384
385 const HloComputation* computation_; // never null
386 const string label_; // overall name for the graph
387 const DebugOptions& debug_options_;
388 const HloRenderOptions hlo_render_options_;
389 const HloExecutionProfile* profile_; // may be null
390 const NodeFilter filter_;
391
392 // Each HloInstruction dumped gets a monotonically-increasing node ID. This
393 // must start at 1, because that's where graphviz's accounting starts.
394 int64 next_node_id_ = 1;
395 absl::flat_hash_map<const HloInstruction*, int64> node_ids_;
396
397 // The "root" tag doesn't have an associated HloInstruction pointer, so we
398 // need to store it outside the map.
399 int64 root_node_id_;
400
401 // Each (from, to) edge gets a monotonically-increasing ID. This is a
402 // multimap because it's possible for the same edge to appear multiple times
403 // in the graph (e.g. x^2 may be represented as mul(x, x)).
404 int64 next_edge_id_ = 1;
405 std::unordered_multimap<
406 std::pair<const HloInstruction*, const HloInstruction*>, int64,
407 tensorflow::hash<std::pair<const HloInstruction*, const HloInstruction*>>>
408 edge_ids_;
409
410 // Each HloComputation that's emitted gets a monotonically-increasing ID.
411 int64 next_cluster_id_ = 1;
412 absl::flat_hash_map<const HloComputation*, int64> cluster_ids_;
413
414 // Edges to print from Footer(). Edges come at the end because graphviz is
415 // unhappy if an edge from a subcomputation to a node in the outer computation
416 // appears before both the inner computation and the destination node are
417 // defined.
418 std::vector<string> edges_;
419
420 // When coloring by sharding information, we track the sharding string
421 // representation to color association, by round-robin the color schemes.
422 absl::flat_hash_map<HloSharding, ColorScheme, HloSharding::Hasher>
423 sharding_colors_;
424 int64 next_shard_color_ = 0;
425 };
426
Dump()427 string HloDotDumper::Dump() {
428 string body;
429 StrAppend(&body, DumpComputation(computation_));
430 StrAppend(&body, DumpRootTag());
431
432 // By contract, Header() and Footer() have to be called after we've dumped all
433 // our instructions, because they use state generated during that process.
434 string g = Header();
435 StrAppend(&g, body);
436 StrAppend(&g, Footer());
437 return g;
438 }
439
Header()440 string HloDotDumper::Header() {
441 constexpr char fmt[] = R"(digraph G {
442 rankdir = TB;
443 compound = true;
444 label = <<b>%s</b>>;
445 labelloc = t;
446 // Disable the tooltip. Interestingly, "" doesn't work!
447 tooltip = " ";
448 // DOT graphs accept a stylesheet as a URI. So naturally, an inline
449 // stylesheet is a data URI!
450 stylesheet=<
451 data:text/css,
452 @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
453 svg text {
454 font-family: 'Roboto';
455 font-size: 12px;
456 }
457
458 %s
459 >
460
461 )";
462
463 VLOG(3) << "Generating Header";
464
465 string graph_label =
466 StrCat(label_, "<br/>Computation ", computation_->name());
467 if (computation_->IsFusionComputation()) {
468 StrAppend(&graph_label, " (in fusion instruction ",
469 computation_->FusionInstruction()->name(), ")");
470 }
471 if (profile_ != nullptr) {
472 auto cycles = profile_->total_cycles_executed(*computation_);
473 absl::StrAppendFormat(&graph_label, "<br/>total cycles = %d (%s)", cycles,
474 tensorflow::strings::HumanReadableNum(cycles));
475 }
476
477 // Create CSS rules that say, when you hover over the given node or cluster,
478 // turn the given edge the given color.
479 //
480 // We rely on a few properties of how graphviz generates SVGs:
481 //
482 // - Nodes are named "nodeN", where N corresponds to the 1-based index of
483 // the node in our DOT (i.e. the first node in the DOT is "node1", etc.).
484 // Edges are similarly named "edgeN", and clusters are named "clustN".
485 // - Nodes come before their in- and out-edges in the SVG. We need this
486 // because the "X ~ Y" CSS selector finds a sibling of X that *comes
487 // after X in the DOM* and matches Y.
488 std::vector<string> edge_css_rules;
489 const char* kBlue = "#1976d2";
490 const char* kRed = "#d32f2f";
491 for (const auto& kv : edge_ids_) {
492 const HloInstruction* from_node = kv.first.first;
493 const HloInstruction* to_node = kv.first.second;
494 int64_t edge_id = kv.second;
495
496 auto add_hover_css_rule = [&](string elem_type, int64_t elem_id,
497 const char* color) {
498 // One could imagine other ways of writing this CSS rule that involve
499 // less duplication, but this way seems to be relatively performant.
500 edge_css_rules.push_back(
501 StrFormat(" #%s%d:hover ~ #edge%d text { fill: %s; }\n"
502 " #%s%d:hover ~ #edge%d path { "
503 "stroke: %s; stroke-width: .2em; }\n"
504 " #%s%d:hover ~ #edge%d polygon { "
505 "fill: %s; stroke: %s; stroke-width: .2em; }\n",
506 elem_type, elem_id, edge_id, color, //
507 elem_type, elem_id, edge_id, color, //
508 elem_type, elem_id, edge_id, color, color));
509 };
510
511 // The "to_node" value may be a NULL, indicating that this points to the
512 // "root" tag rather than a normal node.
513 int64_t from_node_id =
514 tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1);
515 if (from_node_id == -1) {
516 LOG(FATAL) << from_node->name() << " was added to edges but not to nodes";
517 }
518 int64_t to_node_id =
519 to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1)
520 : root_node_id_;
521 if (to_node != nullptr && to_node_id == -1) {
522 LOG(FATAL) << to_node->name() << " was added to edges but not to nodes";
523 }
524
525 add_hover_css_rule("node", from_node_id, kBlue);
526 add_hover_css_rule("node", to_node_id, kRed);
527
528 if (to_node) {
529 VLOG(3) << "Adding css for edge " << edge_id << " from node "
530 << from_node->name() << " to node " << to_node->name();
531 } else {
532 VLOG(3) << "Adding css for edge " << edge_id << " from node "
533 << from_node->name() << " to root tag";
534 }
535
536 // If this edge crosses a fusion cluster boundary, highlight it when the
537 // cluster is hovered over.
538 if (to_node) {
539 if (from_node->IsFused() &&
540 from_node->parent()->root_instruction() == from_node) {
541 int64_t cluster_id = cluster_ids_.at(from_node->parent());
542 add_hover_css_rule("clust", cluster_id, kBlue);
543 }
544 if (to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter) {
545 int64_t cluster_id = cluster_ids_.at(to_node->parent());
546 add_hover_css_rule("clust", cluster_id, kRed);
547 }
548 }
549 }
550
551 // Browsers require that we URI-encode the contents of our data URI. (It
552 // seems this was a relatively recent change?) In practice, this means that we
553 // need to escape '#'.
554 return StrFormat(
555 fmt, graph_label,
556 absl::StrReplaceAll(StrJoin(edge_css_rules, "\n"), {{"#", "%23"}}));
557 }
558
Footer()559 string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); }
560
ShouldShowFusionSubcomputation(const HloInstruction * instr)561 bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
562 CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
563 return ShouldShowSubcomputation(instr->fused_instructions_computation());
564 }
565
ShouldShowSubcomputation(const HloComputation * subcomp)566 bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
567 if (subcomp->IsFusionComputation()) {
568 const HloInstruction* fusion = subcomp->FusionInstruction();
569 if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion) ||
570 !hlo_render_options_.show_fusion_subcomputations) {
571 return false;
572 }
573 }
574
575 // Don't show trivial subcomputations on non-fusion nodes -- these are inlined
576 // into the graph.
577 if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) {
578 return false;
579 }
580
581 // Show the subcomputation if we're showing any of its members.
582 return absl::c_any_of(
583 subcomp->instructions(),
584 [&](const HloInstruction* instr) { return filter_.Show(instr); });
585 }
586
DumpSubcomputation(const HloComputation * subcomp,const HloInstruction * parent_instr)587 string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
588 const HloInstruction* parent_instr) {
589 VLOG(2) << "Dumping subcomputation " << subcomp->name();
590 // Add an edge from the subcomputation to its parent node. If subcomp
591 // belongs to a fusion node, it's drawn in place of the fusion instruction,
592 // so there's no need to link those.
593 if (parent_instr->opcode() != HloOpcode::kFusion) {
594 const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
595 VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
596 << " as " << next_edge_id_;
597 edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
598 constexpr char edge_fmt[] =
599 R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
600 edges_.push_back(StrFormat(
601 edge_fmt, InstructionId(from), InstructionId(parent_instr),
602 SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
603 }
604
605 // Have we already dumped this subcomputation? If so, generating the edge
606 // linking it and parent_instr is all we want to do in this function.
607 if (cluster_ids_.find(subcomp) != cluster_ids_.end()) {
608 return "";
609 }
610
611 cluster_ids_[subcomp] = next_cluster_id_++;
612
613 string id = SubcomputationId(subcomp);
614
615 string subcomp_label, style;
616 if (parent_instr->opcode() == HloOpcode::kFusion) {
617 subcomp_label =
618 StrFormat("Fused expression for <b>%s</b><br/>%s",
619 HtmlLikeStringSanitize(parent_instr->name()),
620 HtmlLikeStringSanitize(parent_instr->ToCategory()));
621 string extra_info = GetInstructionNodeExtraInfo(parent_instr);
622 if (!extra_info.empty()) {
623 StrAppend(&subcomp_label, "<br/>", extra_info);
624 }
625 string node_backend_config = GetInstructionNodeBackendConfig(parent_instr);
626 if (!node_backend_config.empty()) {
627 StrAppend(&subcomp_label, "<br/>", node_backend_config);
628 }
629
630 bool highlight = filter_.Highlight(parent_instr);
631 const char* fillcolor;
632 const char* strokecolor;
633 if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) {
634 // Use the sharding color, if the node isn't highlighted.
635 NodeColors node_colors =
636 NodeColorsForScheme(GetInstructionColor(parent_instr));
637 fillcolor = node_colors.fill_color;
638 strokecolor = node_colors.stroke_color;
639 } else {
640 // Subcomputation's fill/stroke color is light/dark red/gray, depending on
641 // whether or not the subcomputation's fusion node is highlighted.
642 fillcolor = highlight ? "#ffcdd2" : "#f5f5f5";
643 strokecolor = highlight ? "#b71c1c" : "#c2c2c2";
644 }
645 style =
646 StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")",
647 fillcolor, strokecolor);
648 } else {
649 subcomp_label = StrFormat("Subcomputation for <b>%s</b><br/>%s",
650 HtmlLikeStringSanitize(parent_instr->name()),
651 HtmlLikeStringSanitize(subcomp->name()));
652 style = "style=rounded; color=black;";
653 }
654
655 string comp_body = DumpComputation(subcomp);
656
657 constexpr char computation_fmt[] = R"(subgraph %s {
658 %s
659 label = <%s>;
660 labelloc = t;
661 tooltip = " ";
662 %s
663 } // %s
664
665 )";
666 return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id);
667 }
668
DumpComputation(const HloComputation * comp)669 string HloDotDumper::DumpComputation(const HloComputation* comp) {
670 string g;
671 for (const auto* instr : comp->instructions()) {
672 if (!filter_.Show(instr)) {
673 continue;
674 }
675
676 // Dump subcomputations within instr.
677 for (const HloComputation* subcomp : instr->called_computations()) {
678 if (ShouldShowSubcomputation(subcomp)) {
679 StrAppend(&g, DumpSubcomputation(subcomp, instr));
680 }
681 }
682
683 StrAppend(&g, DumpInstruction(instr));
684 }
685 return g;
686 }
687
DumpRootTag()688 string HloDotDumper::DumpRootTag() {
689 const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());
690
691 // We didn't display constants or broadcasts of effective scalars within
692 // fusions as separate nodes; so if the root is a constant/broadcast of
693 // scalar, we don't add root tag or edge for it.
694 if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
695 IsFusedBroadcastOfConstantEffectiveScalar(from)) {
696 return "";
697 }
698
699 auto from_id = InstructionId(from);
700
701 // The ID of the root computation is otherwise unused, so it makes a good ID
702 // to use for the root-tag node. However, the edge_ids_ map requires a
703 // HloInstruction* pointer for the 'to' value, so we use a NULL value there
704 // (rather than a pointer type-cast) to make it obvious if it is erroneously
705 // dereferenced.
706 HloInstruction* to = nullptr;
707 auto to_id = SubcomputationId(computation_);
708
709 string node_body = "ROOT";
710 string node_shape = "circle";
711 ColorScheme color = kBrown;
712
713 VLOG(2) << "Adding root tag as node " << next_node_id_;
714 root_node_id_ = next_node_id_++;
715
716 VLOG(2) << "Adding edge from " << from->name() << " to root tag as "
717 << next_edge_id_;
718 edge_ids_.insert({{from, to}, next_edge_id_++});
719 edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id));
720
721 return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
722 "\n",
723 to_id, node_body, node_shape, NodeColorAttributes(color));
724 }
725
TryGetFusionParameterConstant(const HloInstruction * instr)726 static const HloConstantInstruction* TryGetFusionParameterConstant(
727 const HloInstruction* instr) {
728 if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
729 return nullptr;
730 }
731 const HloInstruction* fusion = instr->parent()->FusionInstruction();
732 const HloInstruction* operand = fusion->operand(instr->parameter_number());
733 return DynCast<HloConstantInstruction>(operand);
734 }
735
ShouldMergeIntoUsers(const HloInstruction * instr) const736 bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
737 // If a node:
738 //
739 // - is a parameter of a fusion node which is bound to a constant,
740 //
741 // or
742 //
743 // - is a tuple-shaped parameter, and
744 // - is not a parameter to a fusion node, and
745 // - has at least kMinUsersToOmit users shown, and
746 // - all of the shown users are get-tuple-elements,
747 //
748 // then we omit it from the graph, merging it with its users.
749 //
750 // This helps us handle the common case where a while loop body has one big
751 // tuple-shaped parameter.
752 if (TryGetFusionParameterConstant(instr) != nullptr) {
753 return true;
754 }
755 const int kMinUsersToOmit = 3;
756 return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() &&
757 !instr->IsFused() &&
758 absl::c_count_if(instr->users(),
759 [&](const HloInstruction* user) {
760 return filter_.Show(user);
761 }) > kMinUsersToOmit &&
762 absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
763 return !filter_.Show(user) ||
764 user->opcode() == HloOpcode::kGetTupleElement;
765 });
766 }
767
DumpInstruction(const HloInstruction * instr)768 string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
769 // We don't display constants or broadcasts of effective scalar constants
770 // within fusions as separate nodes; they're merged into their users.
771 if (instr->opcode() == HloOpcode::kConstant ||
772 IsFusedBroadcastOfConstantEffectiveScalar(instr)) {
773 return "";
774 }
775 // Skip this node if it's merged into its users.
776 if (ShouldMergeIntoUsers(instr)) {
777 return "";
778 }
779 // Omit the fusion node if its subcomputation is drawn, since the
780 // subcomputation will be drawn inline.
781 if (instr->opcode() == HloOpcode::kFusion &&
782 ShouldShowFusionSubcomputation(instr)) {
783 return "";
784 }
785
786 VLOG(2) << "Adding node " << instr->name() << " as " << next_node_id_;
787 node_ids_[instr] = next_node_id_++;
788
789 ColorScheme color = GetInstructionColor(instr);
790 string node_shape = GetInstructionNodeShape(instr);
791 string node_label = GetInstructionNodeLabel(instr);
792 string node_metadata = GetInstructionNodeMetadata(instr);
793 string node_backend_config = GetInstructionNodeBackendConfig(instr);
794 string extra_info = GetInstructionNodeExtraInfo(instr);
795 string inlined_constants = GetInstructionNodeInlinedOperands(instr);
796 string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
797 AddInstructionIncomingEdges(instr);
798
799 if (!debug_options_.xla_hlo_graph_sharding_color()) {
800 // Override the node's styling if it should be (de-)emphasized.
801 if (filter_.Deemphasized(instr)) {
802 color = kDashedBorder;
803 }
804 if (filter_.Highlight(instr)) {
805 node_shape = "diamond";
806 color = kDarkRed;
807 }
808 }
809 // Build the text that will be displayed inside the node.
810 string node_body = node_label;
811 for (const string& s : {trivial_subcomputation, node_backend_config,
812 extra_info, inlined_constants}) {
813 if (!s.empty()) {
814 StrAppend(&node_body, "<br/>", s);
815 }
816 }
817
818 return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)"
819 "\n",
820 InstructionId(instr), node_body, node_shape, node_metadata,
821 NodeColorAttributes(color));
822 }
823
GetInstructionNodeInlinedOperands(const HloInstruction * instr)824 string HloDotDumper::GetInstructionNodeInlinedOperands(
825 const HloInstruction* instr) {
826 // The constant's shape is a parameter because, in the case of a broadcasted
827 // scalar constant, we want to show the broadcasted shape, not the constant's
828 // scalar shape.
829 auto stringify_constant = [](const HloConstantInstruction* constant,
830 const Shape& shape) {
831 // If the shape has a dimension of size zero, print it as e.g.
832 // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(),
833 // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which
834 // is just noise.
835 if (ShapeUtil::IsZeroElementArray(shape)) {
836 return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape()));
837 }
838
839 // Print the literal value of constants with <= K elements. Note that we
840 // use `constant->shape()` rather than `shape`, because if `constant` is a
841 // scalar that's broadcasted into `shape`, we want to print the constant.
842 optional<int64> elem_count;
843 if (shape.IsArray()) {
844 elem_count = ShapeUtil::ElementsIn(constant->shape());
845 }
846 // Allow HloDotDumper to print HloInstruction reconstructed from HloProto
847 // collected from profiling tools. Those constants may not have a valid
848 // literal.
849 if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
850 return StrFormat("%s %s", shape.ToString(),
851 constant->literal().ToStringWithoutShape());
852 }
853
854 // Otherwise, print e.g. "%constant.42 (s32[100])".
855 string constant_name;
856 if (absl::StartsWith(constant->name(), "constant")) {
857 constant_name = constant->name();
858 } else {
859 constant_name = StrCat("constant ", constant->name());
860 }
861 return StrFormat("%s %s", constant_name, ShapeUtil::HumanString(shape));
862 };
863
864 std::vector<string> lines;
865 for (int64_t i = 0; i < instr->operand_count(); ++i) {
866 const HloInstruction* operand = instr->operand(i);
867 optional<string> operand_str;
868 if (const auto* constant_operand =
869 DynCast<HloConstantInstruction>(operand)) {
870 operand_str =
871 stringify_constant(constant_operand, constant_operand->shape());
872 } else if (IsFusedBroadcastOfConstantEffectiveScalar(operand)) {
873 operand_str = stringify_constant(
874 Cast<HloConstantInstruction>(operand->operand(0)), operand->shape());
875 } else if (ShouldMergeIntoUsers(operand)) {
876 // Special case: If the operand is a parameter to a fusion node and it
877 // always has a constant value, display it like a regular constant.
878 //
879 // For other parameters, use the parameter number rather than the proper
880 // name, because that's generally how people think of the node.
881 if (operand->opcode() == HloOpcode::kParameter) {
882 if (const HloConstantInstruction* constant =
883 TryGetFusionParameterConstant(operand)) {
884 operand_str = stringify_constant(constant, constant->shape());
885 } else {
886 operand_str = StrFormat("Parameter %d", operand->parameter_number());
887 }
888 } else {
889 operand_str = operand->name();
890 }
891 }
892
893 if (operand_str) {
894 if (instr->operand_count() > 1) {
895 lines.push_back(StrFormat("<b>operand %d</b> = %s", i, *operand_str));
896 } else {
897 lines.push_back(StrFormat("<b>operand</b> = %s", *operand_str));
898 }
899 }
900 }
901 return StrJoin(lines, "<br/>");
902 }
903
// Picks the fill/border color scheme for an instruction's node.
//
// When the xla_hlo_graph_sharding_color debug option is set, color encodes
// sharding only: instructions without sharding get kDashedBorder, and each
// distinct sharding is assigned the next color from the enumerators in
// [kBlue, kDashedBorder) the first time it is seen, cached in
// sharding_colors_. (Assumes every enumerator in that range is a usable
// color -- TODO confirm against the ColorScheme definition.)
//
// Otherwise the color is chosen from the opcode, with parameter-like
// instructions painted orange.
ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
  if (debug_options_.xla_hlo_graph_sharding_color()) {
    if (!instr->has_sharding()) {
      return kDashedBorder;
    }
    auto it = sharding_colors_.find(instr->sharding());
    if (it != sharding_colors_.end()) {
      return it->second;
    }
    // First time we see this sharding: cycle through the colors and remember
    // the choice so every instruction with the same sharding matches.
    ColorScheme color = static_cast<ColorScheme>(
        kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
    sharding_colors_.emplace(instr->sharding(), color);
    return color;
  }

  // Choose different weights of orange for small vs large parameters. This
  // distinction is often important, especially in fusion nodes.
  auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange;

  // Special case: If this instruction has a parameter merged into it, paint it
  // the same color as a parameter. Unless the merged-in parameter is a
  // parameter to a fusion node that is bound to a constant -- these aren't
  // "real" parameters from the user's perspective.
  if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) {
        return operand->opcode() == HloOpcode::kParameter &&
               ShouldMergeIntoUsers(operand) &&
               TryGetFusionParameterConstant(operand) == nullptr;
      })) {
    return parameter_color;
  }

  // Pick different colors or shapes for instructions which are particularly
  // expensive (eg, dot) and those which are unusual in some way or unique
  // (eg, parameter).
  //
  // NOTE(review): there is no `default:` label; presumably every HloOpcode is
  // listed so that -Wswitch flags newly added opcodes -- confirm before
  // adding a default.
  switch (instr->opcode()) {
    // Elementwise and other cheap per-element ops.
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConvert:
    case HloOpcode::kCos:
    case HloOpcode::kDivide:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIota:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kRemainder:
    case HloOpcode::kRng:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kRngBitGenerator:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSelect:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kLogistic:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
      // De-emphasize scalar-shaped elementwise ops -- they're generally
      // uninteresting.
      if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
        return kWhite;
      }
      return kYellow;
    // Bookkeeping ops that move no data of their own.
    case HloOpcode::kBitcast:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTrace:
    case HloOpcode::kAfterAll:
    case HloOpcode::kAddDependency:
    case HloOpcode::kTuple:
      return kWhite;
    case HloOpcode::kBroadcast:
      // De-emphasize nodes which broadcast a scalar within a fusion node --
      // these are essentially free.
      if (instr->IsFused() &&
          ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) {
        return kWhite;
      }
      return kGreen;
    // Data-movement ops.
    case HloOpcode::kConcatenate:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kGather:
    case HloOpcode::kPad:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kTupleSelect:
    case HloOpcode::kTranspose:
      // De-emphasize scalar-shaped data movement ops and all data movement ops
      // inside fusion nodes, both of which are essentially free.
      if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) {
        return kWhite;
      }
      return kGreen;
    case HloOpcode::kDynamicUpdateSlice:
      // Unlike the data-movement ops above, dynamic-update-slice is not ~free
      // inside of fusion nodes, so we de-emphasize it only if it's
      // scalar-shaped.
      if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
        return kWhite;
      }
      return kGreen;
    case HloOpcode::kScatter:
      // Do not de-emphasize Scatter, since it involves significant work.
      // (Intentionally falls through to the copy cases below.)
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
      // Emphasize copy nodes, which are either physical transposes (and thus
      // significant), or copies of read-only buffers (and thus dead weight).
      return kGreen;
    // Expensive compute ops.
    case HloOpcode::kConvolution:
    case HloOpcode::kDot:
    case HloOpcode::kFft:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
      return kDarkBlue;
    case HloOpcode::kReducePrecision:
      return kRed;
    case HloOpcode::kParameter:
      return parameter_color;
    // Reductions and batch-norm.
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
      return kPurple;
    case HloOpcode::kDomain:
    case HloOpcode::kFusion:
    case HloOpcode::kMap:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
      return kGray;
    // Collectives, host transfers and device/partition queries.
    case HloOpcode::kAllGather:
    case HloOpcode::kAllGatherStart:
    case HloOpcode::kAllGatherDone:
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kAllReduceDone:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kCollectivePermuteStart:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kPartitionId:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kReplicaId:
      return kBrown;
    // Control flow and opaque calls.
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kCustomCall:
    case HloOpcode::kWhile:
      return kDarkGreen;
    case HloOpcode::kConstant:
      // DumpInstruction returns early for constants, so reaching this is a
      // programming error. LOG(FATAL) terminates, so no return is needed.
      LOG(FATAL) << "Constants don't get their own nodes in the graph.";
  }
}
1093
GetInstructionNodeShape(const HloInstruction * instr)1094 string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) {
1095 // Give while loops a different shape so they're easier to pick out.
1096 switch (instr->opcode()) {
1097 case HloOpcode::kWhile:
1098 return "ellipse";
1099 default:
1100 return "rect";
1101 }
1102 }
1103
GetInstructionNodeLabel(const HloInstruction * instr)1104 string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
1105 // If we have a parameter, put the param number in the name.
1106 if (instr->opcode() == HloOpcode::kParameter) {
1107 return StrFormat("<b>Parameter %d</b>", instr->parameter_number());
1108 }
1109
1110 // The HLO instruction name contains usually the opcode, e.g. "%add.42" is
1111 // an add instruction. In this case we render just the name.
1112 if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) {
1113 return StrFormat("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
1114 }
1115 string extended_opcode =
1116 StrCat(HloOpcodeString(instr->opcode()),
1117 instr->opcode() != HloOpcode::kFusion
1118 ? ""
1119 : StrCat(":", xla::ToString(instr->fusion_kind())));
1120 // If the name does not contain the opcode, render both.
1121 return StrFormat("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
1122 HtmlLikeStringSanitize(instr->name()));
1123 }
1124
GetInstructionNodeMetadata(const HloInstruction * instr)1125 string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
1126 std::vector<string> lines;
1127 if (!instr->metadata().op_name().empty()) {
1128 lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name()));
1129 }
1130 if (!instr->metadata().op_type().empty()) {
1131 lines.push_back(StrFormat(
1132 "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type())));
1133 }
1134 if (!instr->metadata().source_file().empty() &&
1135 instr->metadata().source_line() != 0) {
1136 lines.push_back(StrFormat("source: %s:%d", instr->metadata().source_file(),
1137 instr->metadata().source_line()));
1138 }
1139
1140 return StrJoin(lines, "\n");
1141 }
1142
GetInstructionNodeBackendConfig(const HloInstruction * instr)1143 string HloDotDumper::GetInstructionNodeBackendConfig(
1144 const HloInstruction* instr) {
1145 if (!hlo_render_options_.show_backend_config ||
1146 instr->raw_backend_config_string().empty()) {
1147 return "";
1148 }
1149
1150 return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\"");
1151 }
1152
GetInstructionNodeExtraInfo(const HloInstruction * instr)1153 string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
1154 std::vector<string> lines;
1155
1156 // Get the instruction's extra attributes excluding the names of its
1157 // subcomputations, since those are drawn explicitly in the graph.
1158 for (const auto& line : instr->ExtraAttributesToString(
1159 HloPrintOptions().set_print_subcomputation_mode(
1160 HloPrintOptions::PrintSubcomputationMode::kOff))) {
1161 // Some instructions have giant replica group fields, so truncate the
1162 // replica group line length to 128.
1163 constexpr int kMaxReplicaGroupLen = 128;
1164 if (absl::StartsWith(line, "replica_groups=") &&
1165 line.length() > kMaxReplicaGroupLen) {
1166 lines.push_back(HtmlLikeStringSanitize(
1167 StrCat(line.substr(0, kMaxReplicaGroupLen - 3), "...")));
1168 } else {
1169 lines.push_back(HtmlLikeStringSanitize(line));
1170 }
1171 }
1172
1173 // Show the shape and layout of the instruction, unless it's an inlined fusion
1174 // node -- there the shape and layout is present in the output node.
1175 if (instr->opcode() != HloOpcode::kFusion ||
1176 !ShouldShowFusionSubcomputation(instr)) {
1177 // Show layout of instructions with more than one dimension. Don't show
1178 // layout on tuples or tensors with just one dimension (which only have one
1179 // possible layout) to avoid visual noise.
1180 bool shape_is_multidim = false;
1181 ShapeUtil::ForEachSubshape(instr->shape(),
1182 [&](const Shape& s, const ShapeIndex&) {
1183 shape_is_multidim |= s.dimensions_size() > 1;
1184 });
1185 string instr_shape;
1186 if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) {
1187 instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape());
1188 } else {
1189 instr_shape = ShapeUtil::HumanString(instr->shape());
1190 }
1191
1192 // Some instructions have giant tuples as their shapes, so truncate the
1193 // HLO's shape to kMaxShapeLen characters.
1194 constexpr int kMaxShapeLen = 64;
1195 if (instr_shape.length() > kMaxShapeLen) {
1196 instr_shape = StrCat(
1197 absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "...");
1198 }
1199 lines.push_back(HtmlLikeStringSanitize(instr_shape));
1200 }
1201 if (debug_options_.xla_hlo_graph_addresses()) {
1202 lines.push_back(StrFormat("[%p]", instr));
1203 }
1204 if (profile_ != nullptr) {
1205 double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr);
1206 double total_cycles_executed =
1207 profile_->total_cycles_executed(*instr->parent());
1208 if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
1209 lines.push_back(
1210 StrFormat("%% of cycles executed=%.2f",
1211 100 * hlo_cycles_executed / total_cycles_executed));
1212 }
1213 }
1214 return StrJoin(lines, "<br/>");
1215 }
1216
AddInstructionIncomingEdges(const HloInstruction * instr)1217 void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
1218 auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
1219 int64_t operand_num, bool control_edge = false) {
1220 from = GetNodeForEdge(from);
1221
1222 if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
1223 IsFusedBroadcastOfConstantEffectiveScalar(from) ||
1224 ShouldMergeIntoUsers(from)) {
1225 return;
1226 }
1227 VLOG(2) << "Adding edge from " << from->name() << " to " << to->name()
1228 << " as " << next_edge_id_;
1229 edge_ids_.insert({{from, to}, next_edge_id_++});
1230
1231 string edge_label;
1232 if (instr->operand_count() > 1 && !control_edge) {
1233 edge_label =
1234 StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num);
1235 } else if (control_edge) {
1236 edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"";
1237 }
1238
1239 // We print "small" arrays using a hollow arrowhead and "large" arrays using
1240 // a filled arrowhead.
1241 constexpr char kEdgeFmt[] =
1242 R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
1243 edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to),
1244 (IsSmall(from) ? "empty" : "normal"),
1245 from->name(), to->name(), edge_label));
1246 };
1247
1248 // Add edges from instr's operands to instr. Parameters within fusion
1249 // expressions are handled specially -- we draw an edge from the corresponding
1250 // operand on the fusion node itself to the parameter.
1251 if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
1252 // Only add the edge if this is not the outermost computation; otherwise it
1253 // will lead from a node we're not drawing.
1254 if (instr->parent() != computation_) {
1255 const HloInstruction* fusion = instr->parent()->FusionInstruction();
1256 add_edge(fusion->operand(instr->parameter_number()), instr,
1257 /*operand_num=*/0);
1258 }
1259 } else {
1260 for (int64_t i = 0; i < instr->operand_count(); ++i) {
1261 add_edge(instr->operand(i), instr, i);
1262 }
1263 for (const HloInstruction* pred : instr->control_predecessors()) {
1264 add_edge(pred, instr, /*operand_num=*/0, /*control_edge=*/true);
1265 }
1266 }
1267 }
1268
GetInstructionTrivialComputationStr(const HloInstruction * instr)1269 string HloDotDumper::GetInstructionTrivialComputationStr(
1270 const HloInstruction* instr) {
1271 // called_computations() on a fusion node "inherits" any called computations
1272 // of the fused root, which isn't what we want. Just ignore fusion nodes
1273 // here; they're handled separately.
1274 if (instr->opcode() == HloOpcode::kFusion) {
1275 return "";
1276 }
1277
1278 std::vector<string> lines;
1279 for (int64_t i = 0; i < instr->called_computations().size(); ++i) {
1280 optional<string> computation_type =
1281 MatchTrivialComputation(instr->called_computations()[i]);
1282 if (!computation_type) {
1283 continue;
1284 }
1285 if (instr->called_computations().size() == 1) {
1286 lines.push_back(StrFormat("Subcomputation: <b>%s</b>",
1287 HtmlLikeStringSanitize(*computation_type)));
1288 } else {
1289 lines.push_back(StrFormat("Subcomputation %d: <b>%s</b>", i,
1290 HtmlLikeStringSanitize(*computation_type)));
1291 }
1292 }
1293 return StrJoin(lines, "<br/>");
1294 }
1295
// Returns the node an edge from `instr` should start at. Fusion nodes whose
// subcomputation is drawn inline have no node of their own, so the edge
// starts at the fused expression root instead (repeatedly, for nested
// fusions).
const HloInstruction* HloDotDumper::GetNodeForEdge(
    const HloInstruction* instr) {
  const HloInstruction* node = instr;
  for (;;) {
    const bool is_inlined_fusion = node->opcode() == HloOpcode::kFusion &&
                                   ShouldShowFusionSubcomputation(node);
    if (!is_inlined_fusion) {
      return node;
    }
    node = node->fused_expression_root();
  }
}
1304
1305 // Gets a NodeFilter that includes roughly all instructions whose distance from
1306 // root is <= radius.
MakeNodeRadiusAroundFilter(const HloInstruction * root,int64_t radius,const absl::flat_hash_set<const HloInstruction * > & boundary)1307 NodeFilter MakeNodeRadiusAroundFilter(
1308 const HloInstruction* root, int64_t radius,
1309 const absl::flat_hash_set<const HloInstruction*>& boundary) {
1310 // First, find the neighborhood of nodes with distance from root <= radius.
1311 // These nodes are our initial set of "normal" nodes.
1312 absl::flat_hash_map<const HloInstruction*, NodeFilterResult> nodes;
1313 std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
1314 worklist.push_back({root, 0});
1315 while (!worklist.empty()) {
1316 const HloInstruction* instr;
1317 int64_t depth;
1318 std::tie(instr, depth) = worklist.front();
1319 worklist.pop_front();
1320
1321 nodes[instr] = kNormalNode;
1322 if (depth == radius) {
1323 continue;
1324 }
1325 if (boundary.contains(instr)) {
1326 continue;
1327 }
1328
1329 // Traverse into instr's operands.
1330 //
1331 // Don't traverse into tuples' operands unless the tuple is the root.
1332 // Usually a tuple is the bottommost node in the graph, and so its operands
1333 // are not interesting to the graph at hand.
1334 if (instr == root || instr->opcode() != HloOpcode::kTuple) {
1335 for (const HloInstruction* operand : instr->operands()) {
1336 if (!nodes.contains(operand)) {
1337 worklist.push_back({operand, depth + 1});
1338 }
1339 }
1340 }
1341
1342 // Traverse into instr's nested computations.
1343 for (const HloComputation* computation : instr->called_computations()) {
1344 worklist.push_back({computation->root_instruction(), depth + 1});
1345 }
1346
1347 // Traverse into instr's users, unless:
1348 //
1349 // - there are a ton of them, in which case they're probably not
1350 // interesting (and anyway, rendering them all would make the graph
1351 // unreadable), or
1352 // - instr is a constant, in which case its users are probably not
1353 // interesting.
1354 if (instr->opcode() == HloOpcode::kConstant) {
1355 continue;
1356 }
1357 constexpr int kMaxUsersToRender = 16;
1358 if (instr->user_count() > kMaxUsersToRender) {
1359 // If we're going to skip this node's users, style it as such.
1360 nodes[instr] = kSomeUsersOmitted;
1361 continue;
1362 }
1363 for (const HloInstruction* user : instr->users()) {
1364 if (!nodes.contains(user)) {
1365 worklist.push_back({user, depth + 1});
1366 }
1367 }
1368 }
1369
1370 auto is_displayed = [&](const HloInstruction* instr) {
1371 // Constants are displayed inline with their users; they're never omitted.
1372 // Nodes in subcomputations are always shown.
1373 return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant ||
1374 instr->parent() != root->parent();
1375 };
1376
1377 // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we
1378 // know which nodes will be included in the graph.
1379 for (auto& kv : nodes) {
1380 const HloInstruction* instr = kv.first;
1381 NodeFilterResult& filter_result = kv.second;
1382 const auto& operands = instr->operands();
1383
1384 if (absl::c_any_of(operands, is_displayed) &&
1385 !absl::c_all_of(operands, is_displayed)) {
1386 // Mark nodes with some operands omitted appropriately.
1387 filter_result = kSomeOperandsOmitted;
1388 } else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) {
1389 // Mark nodes with *all* operands omitted appropriately.
1390 filter_result = kOmitNodeOperands;
1391 }
1392
1393 // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
1394 // users made it into the graph.
1395 if (filter_result == kSomeUsersOmitted &&
1396 absl::c_all_of(instr->users(), is_displayed)) {
1397 filter_result = kNormalNode;
1398 }
1399 }
1400
1401 // Highlight the root node.
1402 nodes[root] = kHighlightNode;
1403
1404 return NodeFilter([=](const HloInstruction* instr) {
1405 auto it = nodes.find(instr);
1406 if (it != nodes.end()) {
1407 return it->second;
1408 }
1409 // Show all nodes in subcomputations.
1410 if (instr->parent() != root->parent()) {
1411 return kNormalNode;
1412 }
1413 return kHideNode;
1414 });
1415 }
1416
1417 // Gets a node filter that includes nodes on all paths from `from` to `to`. If
1418 // the all-paths set contains more than max_nodes elements, includes the nodes
1419 // on the shortest paths and sets hit_limit to true.
MakeNodeFromToFilter(const HloInstruction * from,const HloInstruction * to,int64_t max_nodes,bool * hit_limit)1420 NodeFilter MakeNodeFromToFilter(const HloInstruction* from,
1421 const HloInstruction* to, int64_t max_nodes,
1422 bool* hit_limit) {
1423 *hit_limit = false;
1424
1425 // Elements in the queue are paths through the graph.
1426 std::deque<std::vector<const HloInstruction*>> queue;
1427 queue.push_front({from});
1428
1429 // Compute the set of nodes we want to show using a slightly-modified
1430 // Djikstra's algorithm. The only real difference is, rather than stopping
1431 // when we find a (shortest) path, we continue until we've found max_nodes
1432 // nodes on some path.
1433 std::unordered_set<const HloInstruction*> visited;
1434 std::unordered_set<const HloInstruction*> to_display = {from, to};
1435 while (!queue.empty() && to_display.size() < max_nodes) {
1436 std::vector<const HloInstruction*> path = std::move(queue.front());
1437 queue.pop_front();
1438 if (!visited.insert(path.back()).second) {
1439 continue;
1440 }
1441
1442 for (const auto* user : path.back()->users()) {
1443 if (user == to) {
1444 auto it = path.begin();
1445 for (; it != path.end() && to_display.size() < max_nodes; ++it) {
1446 to_display.insert(*it);
1447 }
1448 if (it != path.end()) {
1449 *hit_limit = true;
1450 }
1451 } else if (!visited.count(user)) {
1452 auto new_path = path;
1453 new_path.push_back(user);
1454 queue.push_back(std::move(new_path));
1455 }
1456 }
1457 }
1458
1459 return NodeFilter([=](const HloInstruction* instr) {
1460 if (instr == from || instr == to) {
1461 return kHighlightNode;
1462 }
1463 return to_display.count(instr) ? kNormalNode : kHideNode;
1464 });
1465 }
1466
WrapDotInHtml(absl::string_view dot)1467 string WrapDotInHtml(absl::string_view dot) {
1468 static const char html_prefix[] = R"html(
1469 <!DOCTYPE html>
1470 <html>
1471 <head>
1472 <meta charset="utf-8">
1473 <style type="text/css">
1474 body {
1475 height: 100vh;
1476 margin: 0;
1477 }
1478 </style>
1479 </head>
1480 <body>
1481 <!-- Integrity hash is generated by https://www.srihash.org/ -->
1482 <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/viz.js"
1483 integrity="sha384-aD1MJYb0WKIUT+CtwJp5LTuV3U4pLAS6B/nUxL7ECimC2pN9N8vjlMr/yQCAkzxE"
1484 crossorigin="anonymous"></script>
1485 <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/full.render.js"
1486 integrity="sha384-bAixY275aIpCj6Te19y0MILZ4V+VEC8CVFujFEH+Lf7W+4XYYeYLwW5IBI6yQmMT"
1487 crossorigin="anonymous"></script>
1488 <script src="https://cdn.jsdelivr.net/npm/svg-pan-zoom@3.6.0/dist/svg-pan-zoom.min.js"
1489 integrity="sha384-3008WpYB2pOBvE7lwkrKf+qTmbTPGGPYxA9C1YVhvbPukns4ZFj7E98QPLkNW9dS"
1490 crossorigin="anonymous"></script>
1491 <div id="container" style="height:95vh; border:1px solid black; "></div>
1492 <script>
1493 var data = `
1494 )html";
1495
1496 static const char html_suffix[] = R"html(
1497 `;
1498 var cssregex = new RegExp('stylesheet=<([^]*)\n>\n', 'gm');
1499 var results = cssregex.exec(data)
1500 // graphviz has problem dealing with large stylesheets.
1501 // https://github.com/tensorflow/tensorflow/issues/17220#issuecomment-369228492
1502 // In order to avoid the problem, remove the stylesheet from the dot and
1503 // insert it directly info the rendered SVG.
1504 var dot_data = data;
1505 var css_data = ''
1506 if (results !== null) {
1507 css_data = results[1].replace(/\s*data:.*\s*,/,''); // Strip content-type field.
1508 // CSS inside DOT is URL-escaped, so we must unescape it
1509 // before we can insert it into SVG.
1510 css_data = unescape(css_data);
1511 dot_data = data.replace(cssregex, ''); // Remove the stylesheet
1512 }
1513
1514 var render_start = performance.now()
1515 function add_controls(svg) {
1516 var htmlblob = new Blob([document.documentElement.innerHTML],
1517 {type: 'text/html'});
1518 var savehtml = document.createElement('a');
1519 savehtml.setAttribute('href', URL.createObjectURL(htmlblob));
1520 savehtml.setAttribute('download', 'graph.html');
1521 savehtml.innerHTML = " [Save HTML+SVG] ";
1522 document.body.append(savehtml);
1523 var svgblob = new Blob([svg.outerHTML], {type: 'image/svg'});
1524 var savesvg = document.createElement('a');
1525 savesvg.setAttribute('href', URL.createObjectURL(svgblob));
1526 savesvg.setAttribute('download', 'graph.svg');
1527 savesvg.innerHTML = " [Save SVG] ";
1528 document.body.append(savesvg);
1529 var dotblob = new Blob([data], {type: 'text/dot'});
1530 var savedot = document.createElement('a');
1531 savedot.setAttribute('href', URL.createObjectURL(dotblob));
1532 savedot.setAttribute('download', 'graph.dot');
1533 savedot.innerHTML = " [Save DOT] ";
1534 document.body.append(savedot);
1535 // Will get called after embed element was loaded
1536 var panzoom = svgPanZoom(svg, {
1537 zoomEnabled: true,
1538 controlIconsEnabled: true,
1539 });
1540 document.getElementsByTagName("BODY")[0].onresize = function() {
1541 panzoom.resize();
1542 panzoom.fit();
1543 panzoom.center();
1544 };
1545 var render_end = performance.now();
1546 var render_note = document.createElement('div')
1547 render_note.innerHTML = 'Rendering took '
1548 + (render_end - render_start).toFixed(2) + "ms."
1549 document.body.append(render_note);
1550 }
1551 var svg = document.getElementById('graph')
1552 if (svg == null) {
1553 // Need to render SVG first.
1554 var viz = new Viz();
1555 viz.renderSVGElement(dot_data)
1556 .then(function(svg){
1557 var container = document.getElementById('container')
1558 var style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
1559 var node = document.createTextNode(css_data);
1560 style.appendChild(node);
1561 svg.setAttribute('width', '100%');
1562 svg.setAttribute('height', '100%');
1563 svg.setAttribute('id', 'graph');
1564 svg.appendChild(style);
1565 container.appendChild(svg);
1566 add_controls(svg);
1567 })
1568 } else {
1569 // HTML already has rendered SVG embedded, so we just need to add
1570 // controls.
1571 add_controls(svg);
1572 }
1573 </script>
1574 </body>
1575 </html>
1576 )html";
1577
1578 return absl::StrCat(html_prefix, dot, html_suffix);
1579 }
1580
// Guards `url_renderer`.
tensorflow::mutex url_renderer_mu(tensorflow::LINKER_INITIALIZED);
// Callback that turns a dot graph into a URL (WrapFusionExplorer calls it with
// a dot string and uses the result as a URL). nullptr until it is set
// elsewhere -- presumably by a registration function outside this chunk; TODO
// confirm. Only accessed while holding url_renderer_mu.
std::function<StatusOr<string>(absl::string_view)>* url_renderer
    TF_GUARDED_BY(url_renderer_mu) = nullptr;
1584
// Storage for fusion visualization: (module_id, computation_id) -> sequence of
// dot dumps.
//
// The map is heap-allocated through a reference and never deleted (presumably
// to avoid static-destruction-order problems at shutdown -- TODO confirm), and
// is only accessed while holding fusion_visualizer_state_mu.
tensorflow::mutex fusion_visualizer_state_mu(tensorflow::LINKER_INITIALIZED);
static auto& fusion_visualizer_state TF_GUARDED_BY(fusion_visualizer_state_mu) =
    *new absl::flat_hash_map<std::pair<int64, int64>,
                             std::vector<std::string>>();
1591
1592 // Generates a key to the fusion visualizer state mapping.
1593 std::pair<int, int> FusionVisualizerStateKey(
1594 const HloComputation& computation) {
1595 return std::make_pair(computation.parent()->unique_id(),
1596 computation.unique_id());
1597 }
1598
1599 // Generates a fusion explorer for the given computation using the data in
1600 // fusion_visualizer_state and the URL renderer. Precondition: url_renderer !=
1601 // nullptr.
1602 StatusOr<std::string> WrapFusionExplorer(const HloComputation& computation)
1603 TF_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
1604 CHECK(url_renderer != nullptr);
1605 tensorflow::mutex_lock lock(fusion_visualizer_state_mu);
1606 const std::vector<std::string>& dot_graphs =
1607 fusion_visualizer_state[FusionVisualizerStateKey(computation)];
1608 std::vector<std::string> dot_urls;
1609 dot_urls.reserve(dot_graphs.size());
1610 for (const std::string& dot : dot_graphs) {
1611 TF_ASSIGN_OR_RETURN(std::string url, (*url_renderer)(dot));
1612 dot_urls.push_back(url);
1613 }
1614
1615 return absl::StrReplaceAll(
1616 R"(
1617 <!doctype html>
1618 <style>
1619 html, body {height: 100%; text-align: center;}
1620 #display {height: 80%; width: 80%;}
1621 </style>
1622 <title>Fusion Explorer: $TITLE</title>
1623 <iframe id='display' width=80% height=80%></iframe>
1624 <p id='description'></p>
1625 <p>
1626 <a id='prev' href='#'>Prev Step</a>
1627 <a id='next' href='#'>Next Step</a>
1628 </p>
1629 <p>
1630 Use j/k for keyboard navigation.
1631 </p>
1632 <script>
1633 var currId = -1;
1634 var urls = [$URLS];
1635
1636 var setIframe = function() {
1637 document.getElementById('display').src = urls[currId];
1638 };
1639
1640 var update = function(delta) {
1641 currId = (currId + delta + urls.length) % urls.length;
1642 document.getElementById('description').innerHTML = "Frame #"
1643 + (currId + 1) + " / " + urls.length;
1644 setIframe();
1645 };
1646
1647 document.getElementById('prev').onclick = function() {
1648 update(-1);
1649 return false;
1650 };
1651
1652 document.getElementById('next').onclick = function() {
1653 update(1);
1654 return false;
1655 };
1656
1657 window.addEventListener("keydown", function (event) {
1658 if (event.defaultPrevented) {
1659 return;
1660 }
1661 if (event.key == "j") {
1662 update(1);
1663 } else if (event.key == "k") {
1664 update(-1);
1665 } else {
1666 return;
1667 }
1668 event.preventDefault();
1669 }, true);
1670
1671 document.addEventListener("DOMContentLoaded", function() {
1672 update(1);
1673 });
1674
1675 </script>
1676 )",
1677 {{"$URLS", absl::StrJoin(dot_urls, ", ",
1678 [&](std::string* out, const std::string& url) {
1679 absl::StrAppend(out, "\"", url, "\"");
1680 })},
1681 {"$TITLE",
1682 absl::StrCat(computation.parent()->name(), "_", computation.name())}});
1683 }
1684
1685 // Precondition: (url_renderer != nullptr || (format != kUrl
1686 // && format != kFusionVisualization)).
1687 //
1688 // (We specify this as a precondition rather than checking it in here and
1689 // returning an error because we want to fail quickly when there's no URL
1690 // renderer available, and this function runs only after we've done all the work
1691 // of producing dot for the graph.)
WrapDotInFormat(const HloComputation & computation,absl::string_view dot,RenderedGraphFormat format)1692 StatusOr<string> WrapDotInFormat(const HloComputation& computation,
1693 absl::string_view dot,
1694 RenderedGraphFormat format)
1695 TF_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
1696 switch (format) {
1697 case RenderedGraphFormat::kUrl:
1698 CHECK(url_renderer != nullptr)
1699 << "Should have checked url_renderer != null before calling.";
1700 return (*url_renderer)(dot);
1701 case RenderedGraphFormat::kHtml:
1702 return WrapDotInHtml(dot);
1703 case RenderedGraphFormat::kDot:
1704 return string(dot);
1705 case RenderedGraphFormat::kFusionVisualization:
1706 return WrapFusionExplorer(computation);
1707 }
1708 }
1709
1710 } // namespace
1711
RegisterGraphToURLRenderer(std::function<StatusOr<string> (absl::string_view)> renderer)1712 void RegisterGraphToURLRenderer(
1713 std::function<StatusOr<string>(absl::string_view)> renderer) {
1714 tensorflow::mutex_lock lock(url_renderer_mu);
1715 if (url_renderer != nullptr) {
1716 LOG(WARNING) << "Multiple calls to RegisterGraphToURLRenderer. Last call "
1717 "wins, but because order of initialization in C++ is "
1718 "nondeterministic, this may not be what you want.";
1719 }
1720 delete url_renderer;
1721 url_renderer = new std::function<StatusOr<string>(absl::string_view)>(
1722 std::move(renderer));
1723 }
1724
RegisterFusionState(const HloComputation & computation,absl::string_view label)1725 Status RegisterFusionState(const HloComputation& computation,
1726 absl::string_view label) {
1727 tensorflow::mutex_lock lock(fusion_visualizer_state_mu);
1728 TF_ASSIGN_OR_RETURN(
1729 string dot_graph,
1730 RenderGraph(computation,
1731 absl::StrCat(computation.parent()->name(), ", ",
1732 computation.name(), ", ", label),
1733 /*debug_options=*/{}, xla::RenderedGraphFormat::kDot,
1734 /*hlo_execution_profile=*/nullptr,
1735 /*hlo_render_options=*/{}));
1736 std::vector<std::string>& fusion_states =
1737 fusion_visualizer_state[FusionVisualizerStateKey(computation)];
1738 if (fusion_states.empty() || fusion_states.back() != dot_graph) {
1739 fusion_states.push_back(dot_graph);
1740 }
1741 return Status::OK();
1742 }
1743
RenderGraph(const HloComputation & computation,absl::string_view label,const DebugOptions & debug_options,RenderedGraphFormat format,const HloExecutionProfile * hlo_execution_profile,HloRenderOptions hlo_render_options)1744 StatusOr<string> RenderGraph(const HloComputation& computation,
1745 absl::string_view label,
1746 const DebugOptions& debug_options,
1747 RenderedGraphFormat format,
1748 const HloExecutionProfile* hlo_execution_profile,
1749 HloRenderOptions hlo_render_options) {
1750 tensorflow::mutex_lock lock(url_renderer_mu);
1751 if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1752 return Unavailable("Can't render as URL; no URL renderer was registered.");
1753 }
1754
1755 string rendered_dot =
1756 HloDotDumper(&computation, label, debug_options, hlo_render_options,
1757 hlo_execution_profile, NodeFilter())
1758 .Dump();
1759 return WrapDotInFormat(computation, rendered_dot, format);
1760 }
1761
RenderNeighborhoodAround(const HloInstruction & node,int radius,RenderedGraphFormat format,HloRenderOptions hlo_render_options,const absl::flat_hash_set<const HloInstruction * > & boundary)1762 StatusOr<string> RenderNeighborhoodAround(
1763 const HloInstruction& node, int radius, RenderedGraphFormat format,
1764 HloRenderOptions hlo_render_options,
1765 const absl::flat_hash_set<const HloInstruction*>& boundary) {
1766 tensorflow::mutex_lock lock(url_renderer_mu);
1767 if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1768 return FailedPrecondition(
1769 "Can't render as URL; no URL renderer was registered.");
1770 }
1771
1772 string label =
1773 StrCat("Neighborhood of ", radius, " nodes around ", node.name());
1774 string rendered_dot =
1775 HloDotDumper(node.parent(), label,
1776 node.GetModule()->config().debug_options(),
1777 hlo_render_options, /*profile=*/nullptr,
1778 MakeNodeRadiusAroundFilter(&node, radius, boundary))
1779 .Dump();
1780 return WrapDotInFormat(*node.parent(), rendered_dot, format);
1781 }
1782
RenderAllPathsFromTo(const HloInstruction & from,const HloInstruction & to,int64_t max_nodes,RenderedGraphFormat format,HloRenderOptions hlo_render_options)1783 StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
1784 const HloInstruction& to,
1785 int64_t max_nodes,
1786 RenderedGraphFormat format,
1787 HloRenderOptions hlo_render_options) {
1788 tensorflow::mutex_lock lock(url_renderer_mu);
1789 if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1790 return FailedPrecondition(
1791 "Can't render as URL; no URL renderer was registered.");
1792 }
1793
1794 CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!";
1795 auto debug_options = from.GetModule()->config().debug_options();
1796
1797 bool hit_limit = false;
1798 NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit);
1799 string label;
1800 if (!hit_limit) {
1801 label = StrCat("All paths from ", from.name(), " to ", to.name());
1802 } else {
1803 label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(),
1804 " to ", to.name(),
1805 "<br/><br/>***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN "
1806 "NODES***<br/><br/>");
1807 }
1808 string rendered_dot =
1809 HloDotDumper(from.parent(), label, debug_options, hlo_render_options,
1810 /*profile=*/nullptr, filter)
1811 .Dump();
1812 return WrapDotInFormat(*from.parent(), rendered_dot, format);
1813 }
1814
1815 } // namespace xla
1816