• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
17 
18 #include <unistd.h>
19 
20 #include <algorithm>
21 #include <atomic>
22 #include <deque>
23 #include <map>
24 #include <memory>
25 #include <optional>
26 #include <queue>
27 #include <string>
28 #include <tuple>
29 #include <vector>
30 
31 #include "absl/container/flat_hash_map.h"
32 #include "absl/container/flat_hash_set.h"
33 #include "absl/strings/match.h"
34 #include "absl/strings/str_cat.h"
35 #include "absl/strings/str_format.h"
36 #include "absl/strings/str_join.h"
37 #include "absl/strings/str_replace.h"
38 #include "tensorflow/compiler/xla/layout_util.h"
39 #include "tensorflow/compiler/xla/literal.h"
40 #include "tensorflow/compiler/xla/primitive_util.h"
41 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
42 #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
43 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
44 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
45 #include "tensorflow/compiler/xla/service/hlo_module.h"
46 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
47 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
48 #include "tensorflow/compiler/xla/shape_util.h"
49 #include "tensorflow/compiler/xla/types.h"
50 #include "tensorflow/compiler/xla/util.h"
51 #include "tensorflow/compiler/xla/window_util.h"
52 #include "tensorflow/core/lib/core/status.h"
53 #include "tensorflow/core/lib/gtl/map_util.h"
54 #include "tensorflow/core/lib/io/zlib_compression_options.h"
55 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
56 #include "tensorflow/core/lib/strings/numbers.h"
57 #include "tensorflow/core/platform/base64.h"
58 #include "tensorflow/core/platform/env.h"
59 #include "tensorflow/core/platform/protobuf.h"
60 #include "tensorflow/core/platform/regexp.h"
61 #include "tensorflow/stream_executor/dnn.h"
62 
63 namespace xla {
64 namespace {
65 
66 using absl::StrAppend;
67 using absl::StrCat;
68 using absl::StrFormat;
69 using absl::StrJoin;
70 using std::nullopt;
71 using std::optional;
72 
// Classification that a NodeFilter assigns to an HloInstruction, telling the
// graph-drawing code how that instruction should be treated.
enum NodeFilterResult {
  // No special treatment.
  kNormalNode,
  // The node is not shown.
  kHideNode,
  // Make the node easy to find in the final graph.
  kHighlightNode,
  // "Gray out" the node to indicate that some of its operands have been
  // omitted.
  kSomeOperandsOmitted,
  // Styled like kSomeOperandsOmitted, but additionally the node is not
  // connected to its operands, even if they're present in the graph.
  kOmitNodeOperands,
  // Styled like kSomeOperandsOmitted, but indicates that some of the node's
  // *users* (rather than its operands) have been omitted.
  kSomeUsersOmitted,
};
90 
91 // NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
92 // It lets callers tell the graph-drawing routines which nodes they want to be
93 // shown, hidden, or highlighted.
94 class NodeFilter {
95  public:
__anon758a858c0202(const HloInstruction*) 96   NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}
97 
NodeFilter(std::function<NodeFilterResult (const HloInstruction * instr)> filter)98   explicit NodeFilter(
99       std::function<NodeFilterResult(const HloInstruction* instr)> filter)
100       : filter_(std::move(filter)) {}
101 
Show(const HloInstruction * instr) const102   bool Show(const HloInstruction* instr) const {
103     return filter_(instr) != kHideNode;
104   }
Highlight(const HloInstruction * instr) const105   bool Highlight(const HloInstruction* instr) const {
106     return filter_(instr) == kHighlightNode;
107   }
OmitOperands(const HloInstruction * instr) const108   bool OmitOperands(const HloInstruction* instr) const {
109     return filter_(instr) == kOmitNodeOperands;
110   }
SomeOrAllOperandsOmitted(const HloInstruction * instr) const111   bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const {
112     auto result = filter_(instr);
113     return result == kOmitNodeOperands || result == kSomeOperandsOmitted;
114   }
Deemphasized(const HloInstruction * instr) const115   bool Deemphasized(const HloInstruction* instr) const {
116     auto result = filter_(instr);
117     return result == kOmitNodeOperands || result == kSomeOperandsOmitted ||
118            result == kSomeUsersOmitted;
119   }
120 
121  private:
122   std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
123 };
124 
125 // We arbitrarily set this as the boundary between "large" and "small"
126 // instructions.
IsSmall(const HloInstruction * instr)127 bool IsSmall(const HloInstruction* instr) {
128   if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE_TYPE) ||
129       ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
130     return true;
131   }
132   return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
133 }
134 
// Named color schemes for graph nodes.  NodeColorsForScheme() maps each
// enumerator to concrete graphviz style/color values.
enum ColorScheme {
  kBlue,
  kBrown,
  kDarkBlue,
  kDarkGreen,
  kDarkOrange,
  kDarkRed,
  kGray,
  kGreen,
  kOrange,
  kPurple,
  kRed,
  kWhite,
  kYellow,

  // Dashed border with gray text on a white background, suggesting that this
  // is an "unimportant" node.
  kDashedBorder,
};
155 
// The graphviz attribute values that together make up one color scheme.
struct NodeColors {
  const char* style;         // value for the "style" attribute
  const char* fill_color;    // value for the "fillcolor" attribute
  const char* stroke_color;  // value for the "color" attribute
  const char* font_color;    // value for the "fontcolor" attribute
};
163 
NodeColorsForScheme(ColorScheme color)164 NodeColors NodeColorsForScheme(ColorScheme color) {
165   switch (color) {
166     case kBlue:
167       return NodeColors{"filled", "#bbdefb", "#8aacc8", "black"};
168     case kBrown:
169       return NodeColors{"filled", "#bcaaa4", "#8c7b75", "black"};
170     case kDarkBlue:
171       return NodeColors{"filled", "#1565c0", "#003c8f", "white"};
172     case kDarkGreen:
173       return NodeColors{"filled", "#2e7d32", "#005005", "white"};
174     case kDarkOrange:
175       // This is more of a "medium" orange, made to look close to kOrange;
176       // there's probably room for a darker weight if desired.
177       return NodeColors{"filled", "#ffb74d", "#c88719", "black"};
178     case kDarkRed:
179       return NodeColors{"filled", "#b71c1c", "#7f0000", "white"};
180     case kGray:
181       return NodeColors{"filled", "#cfd8dc", "#9ea7aa", "black"};
182     case kGreen:
183       return NodeColors{"filled", "#c8e6c9", "#97b498", "black"};
184     case kOrange:
185       return NodeColors{"filled", "#ffe0b2", "#cbae82", "black"};
186     case kPurple:
187       return NodeColors{"filled", "#e1bee7", "#af8eb5", "black"};
188     case kRed:
189       return NodeColors{"filled", "#ffcdd2", "#cb9ca1", "black"};
190     case kWhite:
191       return NodeColors{"filled", "white", "black", "black"};
192     case kYellow:
193       return NodeColors{"filled", "#fff9c4", "#cbc693", "black"};
194     case kDashedBorder:
195       // "filled,dashed" looks the same as "dashed", since we have a white
196       // background.  But we use "filled,dashed" so that when you hover over
197       // any part of the node (not just the text inside the node), our css
198       // :hover rule is triggered.
199       return NodeColors{"filled,dashed", "white", "#757575", "#757575"};
200   }
201 }
202 
203 // Given a ColorScheme, returns an attribute string for a node of that color.
204 // Sets the node's style and fill/stroke/text colors.
205 //
206 // Colors are from https://material.io/color.
NodeColorAttributes(ColorScheme color)207 std::string NodeColorAttributes(ColorScheme color) {
208   NodeColors node_colors = NodeColorsForScheme(color);
209 
210   return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")",
211                    node_colors.style, node_colors.font_color,
212                    node_colors.stroke_color, node_colors.fill_color);
213 }
214 
215 // Replaces <> with &lt;&gt;, so that this string is safe(er) for use in a
216 // graphviz HTML-like string.
HtmlLikeStringSanitize(absl::string_view s)217 std::string HtmlLikeStringSanitize(absl::string_view s) {
218   return absl::StrReplaceAll(s, {{"<", "&lt;"}, {">", "&gt;"}});
219 }
220 
IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction * instr)221 bool IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction* instr) {
222   namespace m = match;
223   return instr->parent()->IsFusionComputation() &&
224          Match(instr, m::Broadcast(m::ConstantEffectiveScalar()));
225 }
226 
227 // Tries to generates a human-readable one-word description of the given
228 // computation.
229 //
230 // Currently we support:
231 //
232 //   "return param0 + param1;"      --> "add"
233 //   "return param0 * param1;"      --> "multiply"
234 //   "return min(param0, param1);"  --> "min"
235 //   "return max(param0, param1);"  --> "max"
236 //   "return xor(param0, param1);"  --> "xor"
237 //   "return and(param0, param1);"  --> "and"
238 //   "return or(param0, param1);"   --> "or"
239 //   "return param0 <= param1;"     --> "less-or-equal"
240 //   "return param0 >= param1;"     --> "greater-or-equal"
241 //   "return param0 >  param1;"     --> "greater-than"
242 //   "return param0 <  param1;"     --> "less-than"
243 //   "return param0 == param1;"     --> "equal-to"
244 //   "return param0 != param1;"     --> "not-equal-to"
245 //
246 // where param0 and param1 are effective scalars.  For the ops that are
247 // commutative, we also support them with param0 and param1 swapped.
248 //
249 // This is useful primarily for reduce and map nodes.  These take a
250 // subcomputation which is almost always one of the above, and pattern matching
251 // it to a short string lets us tell the user what the subcomputation is without
252 // drawing it as a graph.
MatchTrivialComputation(const HloComputation * computation)253 optional<std::string> MatchTrivialComputation(
254     const HloComputation* computation) {
255   namespace m = match;
256 
257   if (computation->instruction_count() != 3) {
258     return nullopt;
259   }
260   HloInstruction* root = computation->root_instruction();
261   const HloInstruction *param0, *param1;
262   if (!Match(root, m::Op()
263                        .WithNumOperands(2)
264                        .WithShape(m::Shape().IsEffectiveScalar())
265                        .WithBinaryOperandsAnyOrder(
266                            m::Parameter(&param0, 0)
267                                .WithShape(m::Shape().IsEffectiveScalar()),
268                            m::Parameter(&param1, 1)
269                                .WithShape(m::Shape().IsEffectiveScalar())))) {
270     return nullopt;
271   }
272 
273   // If the params are reversed (i.e. operand0 is param1 and operand1 is
274   // param0), check that the operation being performed is commutative.
275   if (root->operand(0) == param1) {
276     CHECK_EQ(root->operand(1), param0);
277     if (root->opcode() == HloOpcode()) {
278       switch (root->comparison_direction()) {
279         case ComparisonDirection::kLe:
280         case ComparisonDirection::kGe:
281         case ComparisonDirection::kGt:
282         case ComparisonDirection::kLt:
283           return nullopt;
284         default:
285           break;
286       }
287     }
288   }
289 
290   // If we recognize the root's opcode, we've successfully pattern-matched!
291   switch (root->opcode()) {
292     case HloOpcode::kAdd:
293       return "add";
294     case HloOpcode::kMultiply:
295       return "multiply";
296     case HloOpcode::kMinimum:
297       return "min";
298     case HloOpcode::kMaximum:
299       return "max";
300     case HloOpcode::kXor:
301       return "xor";
302     case HloOpcode::kAnd:
303       return "and";
304     case HloOpcode::kOr:
305       return "or";
306     case HloOpcode::kCompare: {
307       switch (root->comparison_direction()) {
308         case ComparisonDirection::kLe:
309           return "less-or-equal";
310         case ComparisonDirection::kGe:
311           return "greater-or-equal";
312         case ComparisonDirection::kGt:
313           return "greater-than";
314         case ComparisonDirection::kLt:
315           return "less-than";
316         case ComparisonDirection::kEq:
317           return "equal-to";
318         case ComparisonDirection::kNe:
319           return "not-equal-to";
320       }
321     }
322     default:
323       return nullopt;
324   }
325 }
326 
327 // Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax).
328 class HloDotDumper {
329  public:
HloDotDumper(const HloComputation * computation,absl::string_view label,const DebugOptions & debug_options,HloRenderOptions hlo_render_options,const HloExecutionProfile * profile,NodeFilter filter)330   HloDotDumper(const HloComputation* computation, absl::string_view label,
331                const DebugOptions& debug_options,
332                HloRenderOptions hlo_render_options,
333                const HloExecutionProfile* profile, NodeFilter filter)
334       : computation_(computation),
335         label_(label),
336         debug_options_(debug_options),
337         hlo_render_options_(hlo_render_options),
338         profile_(profile),
339         filter_(std::move(filter)) {}
340 
341   std::string Dump();
342 
343   // Returns a CSS id assigned to the instruction, if that exists.
CssIdForInstruction(const HloInstruction & instr)344   std::optional<std::string> CssIdForInstruction(const HloInstruction& instr) {
345     if (instr.opcode() == HloOpcode::kFusion) {
346       // For fusion we render it as a subcomputation.
347       auto it = cluster_ids_.find(instr.called_computations()[0]);
348       if (it == cluster_ids_.end()) {
349         return std::nullopt;
350       }
351       return StrCat("#a_clust", it->second, " path");
352     }
353     auto it = node_ids_.find(&instr);
354     if (it == node_ids_.end()) {
355       return std::nullopt;
356     }
357     return StrCat("#node", it->second, " polygon");
358   }
359 
360  private:
361   // Returns the dot graph identifier for the given instruction.
InstructionId(const HloInstruction * instruction)362   std::string InstructionId(const HloInstruction* instruction) {
363     return StrCat(reinterpret_cast<uint64_t>(instruction));
364   }
365 
366   // Returns the dot graph identifier for the given computation.
SubcomputationId(const HloComputation * computation)367   std::string SubcomputationId(const HloComputation* computation) {
368     return StrCat("cluster_", reinterpret_cast<uint64_t>(computation));
369   }
370 
371   // Generates graph header/footer.  These should be called *after* dumping all
372   // of the instructions and subcomputations for the graph, as they both use
373   // data generated while dumping the graph.
374   std::string Header();
375   std::string Footer();
376 
377   bool ShouldShowSubcomputation(const HloComputation* subcomp);
378   bool ShouldShowFusionSubcomputation(const HloInstruction* instr);
379 
380   // We omit some nodes from the graph, instead drawing them inlined into the
381   // nodes that use them.
382   bool ShouldMergeIntoUsers(const HloInstruction* instr) const;
383 
384   std::string DumpSubcomputation(const HloComputation* subcomp,
385                                  const HloInstruction* parent_instr);
386   std::string DumpComputation(const HloComputation* comp);
387   std::string DumpRootTag();
388   std::string DumpInstruction(const HloInstruction* instr);
389   ColorScheme GetInstructionColor(const HloInstruction* instr);
390   std::string GetInstructionNodeShape(const HloInstruction* instr);
391   std::string GetInstructionNodeLabel(const HloInstruction* instr);
392   std::string GetInstructionNodeMetadata(const HloInstruction* instr);
393   std::string GetInstructionNodeBackendConfig(const HloInstruction* instr);
394   std::string GetInstructionNodeExtraInfo(const HloInstruction* instr);
395   std::string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
396   void AddInstructionIncomingEdges(const HloInstruction* instr);
397 
398   // For most instructions, GetNodeForEdge(instr) returns instr.
399   //
400   // The exception is fusion nodes.  For these, we walk up the chain of nested
401   // fusion nodes starting at instr until we reach a node that either (a) isn't
402   // a fusion node, or (b) is a fusion node for which
403   // ShouldShowFusionSubcomputation is false.
404   //
405   // We do this because fusion nodes are expanded inline -- if
406   // ShouldShowFusionSubcomputation is true, the fusion node won't be present in
407   // the graph.
408   //
409   // In general when you want to draw an edge from A to B, you should actually
410   // draw an edge from GetNodeForEdge(A).
411   const HloInstruction* GetNodeForEdge(const HloInstruction* instr);
412 
413   // If instr has just one computation and it's trivial (e.g. "return param0 +
414   // param1"), returns a string you can put into the node's body that names the
415   // subcomputation, e.g. "Subcomputation: <b>add</b>".
416   std::string GetInstructionTrivialComputationStr(const HloInstruction* instr);
417 
418   const HloComputation* computation_;  // never null
419   const std::string label_;            // overall name for the graph
420   const DebugOptions& debug_options_;
421   const HloRenderOptions hlo_render_options_;
422   const HloExecutionProfile* profile_;  // may be null
423   const NodeFilter filter_;
424 
425   // Each HloInstruction dumped gets a monotonically-increasing node ID.  This
426   // must start at 1, because that's where graphviz's accounting starts.
427   int64_t next_node_id_ = 1;
428   absl::flat_hash_map<const HloInstruction*, int64_t> node_ids_;
429 
430   // The "root" tag doesn't have an associated HloInstruction pointer, so we
431   // need to store it outside the map.
432   int64_t root_node_id_;
433 
434   // Each (from, to) edge gets a monotonically-increasing ID.  This is a
435   // multimap because it's possible for the same edge to appear multiple times
436   // in the graph (e.g. x^2 may be represented as mul(x, x)).
437   int64_t next_edge_id_ = 1;
438   std::unordered_multimap<
439       std::pair<const HloInstruction*, const HloInstruction*>, int64_t,
440       absl::Hash<std::pair<const HloInstruction*, const HloInstruction*>>>
441       edge_ids_;
442 
443   // Each HloComputation that's emitted gets a monotonically-increasing ID.
444   int64_t next_cluster_id_ = 1;
445   absl::flat_hash_map<const HloComputation*, int64_t> cluster_ids_;
446 
447   // Edges to print from Footer().  Edges come at the end because graphviz is
448   // unhappy if an edge from a subcomputation to a node in the outer computation
449   // appears before both the inner computation and the destination node are
450   // defined.
451   std::vector<std::string> edges_;
452 
453   // When coloring by sharding information, we track the sharding string
454   // representation to color association, by round-robin the color schemes.
455   absl::flat_hash_map<HloSharding, ColorScheme> sharding_colors_;
456   int64_t next_shard_color_ = 0;
457 };
458 
Dump()459 std::string HloDotDumper::Dump() {
460   std::string body;
461   StrAppend(&body, DumpComputation(computation_));
462   StrAppend(&body, DumpRootTag());
463 
464   // By contract, Header() and Footer() have to be called after we've dumped all
465   // our instructions, because they use state generated during that process.
466   std::string g = Header();
467   StrAppend(&g, body);
468   StrAppend(&g, Footer());
469   return g;
470 }
471 
Header()472 std::string HloDotDumper::Header() {
473   constexpr char fmt[] = R"(digraph G {
474 rankdir = TB;
475 compound = true;
476 label = <<b>%s</b>>;
477 labelloc = t;
478 // Disable the tooltip.  Interestingly, "" doesn't work!
479 tooltip = " ";
480 // DOT graphs accept a stylesheet as a URI.  So naturally, an inline
481 // stylesheet is a data URI!
482 stylesheet=<
483   data:text/css,
484   @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
485   svg text {
486     font-family: 'Roboto';
487     font-size: 12px;
488   }
489 
490   %s
491 >
492 
493 )";
494 
495   VLOG(3) << "Generating Header";
496 
497   std::string graph_label =
498       StrCat(label_, "<br/>Computation ", computation_->name());
499   if (computation_->IsFusionComputation()) {
500     StrAppend(&graph_label, " (in fusion instruction ",
501               computation_->FusionInstruction()->name(), ")");
502   }
503   if (profile_ != nullptr) {
504     auto cycles = profile_->total_cycles_executed(*computation_);
505     absl::StrAppendFormat(&graph_label, "<br/>total cycles = %d (%s)", cycles,
506                           tensorflow::strings::HumanReadableNum(cycles));
507   }
508 
509   // Create CSS rules that say, when you hover over the given node or cluster,
510   // turn the given edge the given color.
511   //
512   // We rely on a few properties of how graphviz generates SVGs:
513   //
514   //  - Nodes are named "nodeN", where N corresponds to the 1-based index of
515   //    the node in our DOT (i.e. the first node in the DOT is "node1", etc.).
516   //    Edges are similarly named "edgeN", and clusters are named "clustN".
517   //  - Nodes come before their in- and out-edges in the SVG.  We need this
518   //    because the "X ~ Y" CSS selector finds a sibling of X that *comes
519   //    after X in the DOM* and matches Y.
520   std::vector<std::string> edge_css_rules;
521   const char* kBlue = "#1976d2";
522   const char* kRed = "#d32f2f";
523   for (const auto& kv : edge_ids_) {
524     const HloInstruction* from_node = kv.first.first;
525     const HloInstruction* to_node = kv.first.second;
526     int64_t edge_id = kv.second;
527 
528     auto add_hover_css_rule = [&](std::string elem_type, int64_t elem_id,
529                                   const char* color) {
530       // One could imagine other ways of writing this CSS rule that involve
531       // less duplication, but this way seems to be relatively performant.
532       edge_css_rules.push_back(
533           StrFormat("  #%s%d:hover ~ #edge%d text { fill: %s; }\n"
534                     "  #%s%d:hover ~ #edge%d path { "
535                     "stroke: %s; stroke-width: .2em; }\n"
536                     "  #%s%d:hover ~ #edge%d polygon { "
537                     "fill: %s; stroke: %s; stroke-width: .2em; }\n",
538                     elem_type, elem_id, edge_id, color,  //
539                     elem_type, elem_id, edge_id, color,  //
540                     elem_type, elem_id, edge_id, color, color));
541     };
542 
543     // The "to_node" value may be a NULL, indicating that this points to the
544     // "root" tag rather than a normal node.
545     int64_t from_node_id =
546         tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1);
547     if (from_node_id == -1) {
548       LOG(FATAL) << from_node->name() << " was added to edges but not to nodes";
549     }
550     int64_t to_node_id =
551         to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1)
552                 : root_node_id_;
553     if (to_node != nullptr && to_node_id == -1) {
554       LOG(FATAL) << to_node->name() << " was added to edges but not to nodes";
555     }
556 
557     add_hover_css_rule("node", from_node_id, kBlue);
558     add_hover_css_rule("node", to_node_id, kRed);
559 
560     if (to_node) {
561       VLOG(3) << "Adding css for edge " << edge_id << " from node "
562               << from_node->name() << " to node " << to_node->name();
563     } else {
564       VLOG(3) << "Adding css for edge " << edge_id << " from node "
565               << from_node->name() << " to root tag";
566     }
567 
568     // If this edge crosses a fusion cluster boundary, highlight it when the
569     // cluster is hovered over.
570     if (to_node) {
571       if (from_node->IsFused() &&
572           from_node->parent()->root_instruction() == from_node) {
573         int64_t cluster_id = cluster_ids_.at(from_node->parent());
574         add_hover_css_rule("clust", cluster_id, kBlue);
575       }
576       if (to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter) {
577         int64_t cluster_id = cluster_ids_.at(to_node->parent());
578         add_hover_css_rule("clust", cluster_id, kRed);
579       }
580     }
581   }
582 
583   // Browsers require that we URI-encode the contents of our data URI.  (It
584   // seems this was a relatively recent change?) In practice, this means that we
585   // need to escape '#'.
586   return StrFormat(
587       fmt, graph_label,
588       absl::StrReplaceAll(StrJoin(edge_css_rules, "\n"), {{"#", "%23"}}));
589 }
590 
Footer()591 std::string HloDotDumper::Footer() {
592   return StrCat(StrJoin(edges_, "\n"), "\n}");
593 }
594 
ShouldShowFusionSubcomputation(const HloInstruction * instr)595 bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
596   CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
597   return ShouldShowSubcomputation(instr->fused_instructions_computation());
598 }
599 
ShouldShowSubcomputation(const HloComputation * subcomp)600 bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
601   if (subcomp->IsFusionComputation()) {
602     const HloInstruction* fusion = subcomp->FusionInstruction();
603     if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion) ||
604         !hlo_render_options_.show_fusion_subcomputations) {
605       return false;
606     }
607   }
608 
609   // Don't show trivial subcomputations on non-fusion nodes -- these are inlined
610   // into the graph.
611   if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) {
612     return false;
613   }
614 
615   // Show the subcomputation if we're showing any of its members.
616   return absl::c_any_of(
617       subcomp->instructions(),
618       [&](const HloInstruction* instr) { return filter_.Show(instr); });
619 }
620 
DumpSubcomputation(const HloComputation * subcomp,const HloInstruction * parent_instr)621 std::string HloDotDumper::DumpSubcomputation(
622     const HloComputation* subcomp, const HloInstruction* parent_instr) {
623   VLOG(2) << "Dumping subcomputation " << subcomp->name();
624   // Add an edge from the subcomputation to its parent node.  If subcomp
625   // belongs to a fusion node, it's drawn in place of the fusion instruction,
626   // so there's no need to link those.
627   if (parent_instr->opcode() != HloOpcode::kFusion) {
628     const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
629     VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
630             << " as " << next_edge_id_;
631     edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
632     constexpr char edge_fmt[] =
633         R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
634     edges_.push_back(StrFormat(
635         edge_fmt, InstructionId(from), InstructionId(parent_instr),
636         SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
637   }
638 
639   // Have we already dumped this subcomputation?  If so, generating the edge
640   // linking it and parent_instr is all we want to do in this function.
641   if (cluster_ids_.find(subcomp) != cluster_ids_.end()) {
642     return "";
643   }
644 
645   cluster_ids_[subcomp] = next_cluster_id_++;
646 
647   std::string id = SubcomputationId(subcomp);
648 
649   std::string subcomp_label, style;
650   if (parent_instr->opcode() == HloOpcode::kFusion) {
651     subcomp_label =
652         StrFormat("Fused expression for <b>%s</b><br/>%s",
653                   HtmlLikeStringSanitize(parent_instr->name()),
654                   HtmlLikeStringSanitize(parent_instr->ToCategory()));
655     std::string extra_info = GetInstructionNodeExtraInfo(parent_instr);
656     if (!extra_info.empty()) {
657       StrAppend(&subcomp_label, "<br/>", extra_info);
658     }
659     std::string node_backend_config =
660         GetInstructionNodeBackendConfig(parent_instr);
661     if (!node_backend_config.empty()) {
662       StrAppend(&subcomp_label, "<br/>", node_backend_config);
663     }
664 
665     bool highlight = filter_.Highlight(parent_instr);
666     const char* fillcolor;
667     const char* strokecolor;
668     if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) {
669       // Use the sharding color, if the node isn't highlighted.
670       NodeColors node_colors =
671           NodeColorsForScheme(GetInstructionColor(parent_instr));
672       fillcolor = node_colors.fill_color;
673       strokecolor = node_colors.stroke_color;
674     } else {
675       // Subcomputation's fill/stroke color is light/dark red/gray, depending on
676       // whether or not the subcomputation's fusion node is highlighted.
677       fillcolor = highlight ? "#ffcdd2" : "#f5f5f5";
678       strokecolor = highlight ? "#b71c1c" : "#c2c2c2";
679     }
680     style =
681         StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")",
682                   fillcolor, strokecolor);
683   } else {
684     subcomp_label = StrFormat("Subcomputation for <b>%s</b><br/>%s",
685                               HtmlLikeStringSanitize(parent_instr->name()),
686                               HtmlLikeStringSanitize(subcomp->name()));
687     style = "style=rounded; color=black;";
688   }
689 
690   std::string comp_body = DumpComputation(subcomp);
691 
692   constexpr char computation_fmt[] = R"(subgraph %s {
693 %s
694 label = <%s>;
695 labelloc = t;
696 tooltip = " ";
697 %s
698 }  // %s
699 
700 )";
701   return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id);
702 }
703 
DumpComputation(const HloComputation * comp)704 std::string HloDotDumper::DumpComputation(const HloComputation* comp) {
705   std::string g;
706   for (const auto* instr : comp->instructions()) {
707     if (!filter_.Show(instr)) {
708       continue;
709     }
710 
711     // Dump subcomputations within instr.
712     for (const HloComputation* subcomp : instr->called_computations()) {
713       if (ShouldShowSubcomputation(subcomp)) {
714         StrAppend(&g, DumpSubcomputation(subcomp, instr));
715       }
716     }
717 
718     StrAppend(&g, DumpInstruction(instr));
719   }
720   return g;
721 }
722 
DumpRootTag()723 std::string HloDotDumper::DumpRootTag() {
724   const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());
725 
726   // We didn't display constants or broadcasts of effective scalars within
727   // fusions as separate nodes; so if the root is a constant/broadcast of
728   // scalar, we don't add root tag or edge for it.
729   if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
730       IsFusedBroadcastOfConstantEffectiveScalar(from)) {
731     return "";
732   }
733 
734   auto from_id = InstructionId(from);
735 
736   // The ID of the root computation is otherwise unused, so it makes a good ID
737   // to use for the root-tag node.  However, the edge_ids_ map requires a
738   // HloInstruction* pointer for the 'to' value, so we use a NULL value there
739   // (rather than a pointer type-cast) to make it obvious if it is erroneously
740   // dereferenced.
741   HloInstruction* to = nullptr;
742   auto to_id = SubcomputationId(computation_);
743 
744   std::string node_body = "ROOT";
745   std::string node_shape = "circle";
746   ColorScheme color = kBrown;
747 
748   VLOG(2) << "Adding root tag as node " << next_node_id_;
749   root_node_id_ = next_node_id_++;
750 
751   VLOG(2) << "Adding edge from " << from->name() << " to root tag as "
752           << next_edge_id_;
753   edge_ids_.insert({{from, to}, next_edge_id_++});
754   edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id));
755 
756   return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
757                    "\n",
758                    to_id, node_body, node_shape, NodeColorAttributes(color));
759 }
760 
TryGetFusionParameterConstant(const HloInstruction * instr)761 static const HloConstantInstruction* TryGetFusionParameterConstant(
762     const HloInstruction* instr) {
763   if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
764     return nullptr;
765   }
766   const HloInstruction* fusion = instr->parent()->FusionInstruction();
767   const HloInstruction* operand = fusion->operand(instr->parameter_number());
768   return DynCast<HloConstantInstruction>(operand);
769 }
770 
ShouldMergeIntoUsers(const HloInstruction * instr) const771 bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
772   // If a node:
773   //
774   //  - is a get-tuple-element that isn't the root of the computation, or
775   //  - is a parameter of a fusion node which is bound to a constant, or
776   //  - all of:
777   //    - is a tuple-shaped parameter, and
778   //    - is not a parameter to a fusion node, and
779   //    - has at least kMinUsersToOmit users shown, and
780   //    - all of the shown users are get-tuple-elements,
781   //
782   // then we omit it from the graph, merging it with its users.
783   //
784   // This helps us handle the common case where a while loop body has one big
785   // tuple-shaped parameter.
786   if ((instr->opcode() == HloOpcode::kGetTupleElement &&
787        instr != instr->parent()->root_instruction()) ||
788       TryGetFusionParameterConstant(instr) != nullptr) {
789     return true;
790   }
791   const int kMinUsersToOmit = 3;
792   return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() &&
793          !instr->IsFused() &&
794          absl::c_count_if(instr->users(),
795                           [&](const HloInstruction* user) {
796                             return filter_.Show(user);
797                           }) > kMinUsersToOmit &&
798          absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
799            return !filter_.Show(user) ||
800                   user->opcode() == HloOpcode::kGetTupleElement;
801          });
802 }
803 
DumpInstruction(const HloInstruction * instr)804 std::string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
805   // We don't display constants or broadcasts of effective scalar constants
806   // within fusions as separate nodes; they're merged into their users.
807   if ((instr->opcode() == HloOpcode::kConstant ||
808        IsFusedBroadcastOfConstantEffectiveScalar(instr)) &&
809       instr != instr->parent()->root_instruction()) {
810     return "";
811   }
812   // Skip this node if it's merged into its users.
813   if (ShouldMergeIntoUsers(instr)) {
814     return "";
815   }
816   // Omit the fusion node if its subcomputation is drawn, since the
817   // subcomputation will be drawn inline.
818   if (instr->opcode() == HloOpcode::kFusion &&
819       ShouldShowFusionSubcomputation(instr)) {
820     return "";
821   }
822 
823   VLOG(2) << "Adding node " << instr->name() << " as " << next_node_id_;
824   node_ids_[instr] = next_node_id_++;
825 
826   ColorScheme color = GetInstructionColor(instr);
827   std::string node_shape = GetInstructionNodeShape(instr);
828   std::string node_label = GetInstructionNodeLabel(instr);
829   std::string node_metadata = GetInstructionNodeMetadata(instr);
830   std::string node_backend_config = GetInstructionNodeBackendConfig(instr);
831   std::string extra_info = GetInstructionNodeExtraInfo(instr);
832   std::string inlined_constants = GetInstructionNodeInlinedOperands(instr);
833   std::string trivial_subcomputation =
834       GetInstructionTrivialComputationStr(instr);
835   AddInstructionIncomingEdges(instr);
836 
837   if (!debug_options_.xla_hlo_graph_sharding_color()) {
838     // Override the node's styling if it should be (de-)emphasized.
839     if (filter_.Deemphasized(instr)) {
840       color = kDashedBorder;
841     }
842     if (filter_.Highlight(instr)) {
843       node_shape = "diamond";
844       color = kDarkRed;
845     }
846   }
847   // Build the text that will be displayed inside the node.
848   std::string node_body = node_label;
849   for (const std::string& s : {trivial_subcomputation, extra_info,
850                                inlined_constants, node_backend_config}) {
851     if (!s.empty()) {
852       StrAppend(&node_body, "<br/>", s);
853     }
854   }
855 
856   return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)"
857                    "\n",
858                    InstructionId(instr), node_body, node_shape, node_metadata,
859                    NodeColorAttributes(color));
860 }
861 
GetInstructionNodeInlinedOperands(const HloInstruction * instr)862 std::string HloDotDumper::GetInstructionNodeInlinedOperands(
863     const HloInstruction* instr) {
864   // The constant's shape is a parameter because, in the case of a broadcasted
865   // scalar constant, we want to show the broadcasted shape, not the constant's
866   // scalar shape.
867   auto stringify_constant = [](const HloConstantInstruction* constant,
868                                const Shape& shape) {
869     // If the shape has a dimension of size zero, print it as e.g.
870     // "{} (f32[42, 0, 10])".  The alternative, calling Literal::ToString(),
871     // enumerates all of its empty dimensions (e.g.  "{ { {}, {} }, ..."), which
872     // is just noise.
873     if (ShapeUtil::IsZeroElementArray(shape)) {
874       return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape()));
875     }
876 
877     // Print the literal value of constants with <= K elements.  Note that we
878     // use `constant->shape()` rather than `shape`, because if `constant` is a
879     // scalar that's broadcasted into `shape`, we want to print the constant.
880     optional<int64_t> elem_count;
881     if (shape.IsArray()) {
882       elem_count = ShapeUtil::ElementsIn(constant->shape());
883     }
884     // Allow HloDotDumper to print HloInstruction reconstructed from HloProto
885     // collected from profiling tools. Those constants may not have a valid
886     // literal.
887     if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
888       // In addition to our check that the constant doesn't have too many
889       // elements, also check that the stringified constant isn't too long.  For
890       // example, 8 small ints is okay, but 8 long floats takes up a lot of
891       // horizontal space and probably isn't interesting.
892       std::string literal_str = constant->literal().ToStringWithoutShape();
893       if (literal_str.size() <= 64) {
894         return StrFormat("%s %s", shape.ToString(), literal_str);
895       }
896     }
897 
898     // Otherwise, print e.g. "%constant.42 (s32[100])".
899     std::string constant_name;
900     if (absl::StartsWith(constant->name(), "constant")) {
901       constant_name = constant->name();
902     } else {
903       constant_name = StrCat("constant ", constant->name());
904     }
905     return StrFormat("%s %s", constant_name, ShapeUtil::HumanString(shape));
906   };
907 
908   std::vector<std::string> lines;
909   for (int64_t i = 0; i < instr->operand_count(); ++i) {
910     const HloInstruction* operand = instr->operand(i);
911     optional<std::string> operand_str;
912     if (const auto* constant_operand =
913             DynCast<HloConstantInstruction>(operand)) {
914       operand_str =
915           stringify_constant(constant_operand, constant_operand->shape());
916     } else if (IsFusedBroadcastOfConstantEffectiveScalar(operand)) {
917       operand_str = stringify_constant(
918           Cast<HloConstantInstruction>(operand->operand(0)), operand->shape());
919     } else if (ShouldMergeIntoUsers(operand)) {
920       // Special case: If the operand is a parameter to a fusion node and it
921       // always has a constant value, display it like a regular constant.
922       //
923       // For other parameters, use the parameter number rather than the proper
924       // name, because that's generally how people think of the node.
925       if (operand->opcode() == HloOpcode::kParameter) {
926         if (const HloConstantInstruction* constant =
927                 TryGetFusionParameterConstant(operand)) {
928           operand_str = stringify_constant(constant, constant->shape());
929         } else {
930           operand_str = StrFormat("Parameter %d", operand->parameter_number());
931         }
932       } else if (operand->opcode() == HloOpcode::kGetTupleElement) {
933         operand_str =
934             StrFormat("tuple-element %d of %s %s", operand->tuple_index(),
935                       operand->operand(0)->name(),
936                       ShapeUtil::HumanStringWithLayout(operand->shape()));
937       } else {
938         operand_str = operand->name();
939       }
940     }
941 
942     if (operand_str) {
943       if (instr->operand_count() > 1) {
944         lines.push_back(StrFormat("<b>operand %d</b> = %s", i, *operand_str));
945       } else {
946         lines.push_back(StrFormat("<b>operand</b> = %s", *operand_str));
947       }
948     }
949   }
950 
951   // Special case: fused parameter is fed from a get-tuple-element.  If
952   // so, name the tuple index.
953   if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
954     const HloInstruction* param_input =
955         instr->parent()->FusionInstruction()->operand(
956             instr->parameter_number());
957     if (param_input->opcode() == HloOpcode::kGetTupleElement) {
958       lines.push_back(
959           StrFormat("tuple-element %d of %s %s", param_input->tuple_index(),
960                     param_input->operand(0)->name(),
961                     ShapeUtil::HumanStringWithLayout(param_input->shape())));
962     }
963   }
964 
965   return StrJoin(lines, "<br/>");
966 }
967 
GetInstructionColor(const HloInstruction * instr)968 ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
969   if (debug_options_.xla_hlo_graph_sharding_color()) {
970     if (!instr->has_sharding()) {
971       return kDashedBorder;
972     }
973     auto it = sharding_colors_.find(instr->sharding());
974     if (it != sharding_colors_.end()) {
975       return it->second;
976     }
977     ColorScheme color = static_cast<ColorScheme>(
978         kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
979     sharding_colors_.emplace(instr->sharding(), color);
980     return color;
981   }
982 
983   // Choose different weights of orange for small vs large parameters.  This
984   // distinction is often important, especially in fusion nodes.
985   auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange;
986 
987   // Special case: If this instruction has a parameter merged into it, paint it
988   // the same color as a parameter.  Unless the merged-in parameter is a
989   // parameter to a fusion node that is bound to a constant -- these aren't
990   // "real" parameters from the user's perspective.
991   if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) {
992         return operand->opcode() == HloOpcode::kParameter &&
993                ShouldMergeIntoUsers(operand) &&
994                TryGetFusionParameterConstant(operand) == nullptr;
995       })) {
996     return parameter_color;
997   }
998 
999   // Pick different colors or shapes for instructions which are particularly
1000   // expensive (eg, dot) and those which are unusual in some way or unique
1001   // (eg, parameter).
1002   switch (instr->opcode()) {
1003     case HloOpcode::kAbs:
1004     case HloOpcode::kAdd:
1005     case HloOpcode::kAnd:
1006     case HloOpcode::kAtan2:
1007     case HloOpcode::kBitcastConvert:
1008     case HloOpcode::kCeil:
1009     case HloOpcode::kClamp:
1010     case HloOpcode::kClz:
1011     case HloOpcode::kCompare:
1012     case HloOpcode::kComplex:
1013     case HloOpcode::kConvert:
1014     case HloOpcode::kCos:
1015     case HloOpcode::kDivide:
1016     case HloOpcode::kExp:
1017     case HloOpcode::kExpm1:
1018     case HloOpcode::kFloor:
1019     case HloOpcode::kImag:
1020     case HloOpcode::kIota:
1021     case HloOpcode::kIsFinite:
1022     case HloOpcode::kLog:
1023     case HloOpcode::kLog1p:
1024     case HloOpcode::kMaximum:
1025     case HloOpcode::kMinimum:
1026     case HloOpcode::kMultiply:
1027     case HloOpcode::kNegate:
1028     case HloOpcode::kNot:
1029     case HloOpcode::kPopulationCount:
1030     case HloOpcode::kOr:
1031     case HloOpcode::kXor:
1032     case HloOpcode::kPower:
1033     case HloOpcode::kReal:
1034     case HloOpcode::kRemainder:
1035     case HloOpcode::kRng:
1036     case HloOpcode::kRngGetAndUpdateState:
1037     case HloOpcode::kRngBitGenerator:
1038     case HloOpcode::kRoundNearestAfz:
1039     case HloOpcode::kRoundNearestEven:
1040     case HloOpcode::kRsqrt:
1041     case HloOpcode::kSelect:
1042     case HloOpcode::kShiftLeft:
1043     case HloOpcode::kShiftRightArithmetic:
1044     case HloOpcode::kShiftRightLogical:
1045     case HloOpcode::kLogistic:
1046     case HloOpcode::kSign:
1047     case HloOpcode::kSin:
1048     case HloOpcode::kSlice:
1049     case HloOpcode::kSort:
1050     case HloOpcode::kSqrt:
1051     case HloOpcode::kCbrt:
1052     case HloOpcode::kSubtract:
1053     case HloOpcode::kTanh:
1054       // De-emphasize scalar-shaped elementwise ops -- they're generally
1055       // uninteresting.
1056       if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
1057         return kWhite;
1058       }
1059       return kYellow;
1060     case HloOpcode::kBitcast:
1061     case HloOpcode::kGetTupleElement:
1062     case HloOpcode::kAfterAll:
1063     case HloOpcode::kAddDependency:
1064     case HloOpcode::kTuple:
1065     case HloOpcode::kOptimizationBarrier:
1066       return kWhite;
1067     case HloOpcode::kConstant:
1068       // Constants aren't usually shown as their own nodes, but they'll be
1069       // present if e.g. they're the root of a computation.
1070       return kWhite;
1071     case HloOpcode::kBroadcast:
1072       // De-emphasize nodes which broadcast a scalar within a fusion node --
1073       // these are essentially free.
1074       if (instr->IsFused() &&
1075           ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) {
1076         return kWhite;
1077       }
1078       return kGreen;
1079     case HloOpcode::kConcatenate:
1080     case HloOpcode::kDynamicSlice:
1081     case HloOpcode::kGather:
1082     case HloOpcode::kPad:
1083     case HloOpcode::kReshape:
1084     case HloOpcode::kDynamicReshape:
1085     case HloOpcode::kReverse:
1086     case HloOpcode::kTranspose:
1087       // De-emphasize scalar-shaped data movement ops and all data movement ops
1088       // inside fusion nodes, both of which are essentially free.
1089       if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) {
1090         return kWhite;
1091       }
1092       return kGreen;
1093     case HloOpcode::kDynamicUpdateSlice:
1094       // Unlike the data-movement ops above, dynamic-update-slice is not ~free
1095       // inside of fusion nodes, so we de-emphasize it only if it's
1096       // scalar-shaped.
1097       if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
1098         return kWhite;
1099       }
1100       return kGreen;
1101     case HloOpcode::kCopy:
1102     case HloOpcode::kCopyStart:
1103     case HloOpcode::kCopyDone:
1104       // Emphasize copy nodes, which are either physical transposes (and thus
1105       // significant), or copies of read-only buffers (and thus dead weight).
1106       return kGreen;
1107     case HloOpcode::kAsyncStart:
1108     case HloOpcode::kAsyncUpdate:
1109     case HloOpcode::kAsyncDone:
1110       return GetInstructionColor(instr->async_wrapped_instruction());
1111     case HloOpcode::kConvolution:
1112     case HloOpcode::kDot:
1113     case HloOpcode::kFft:
1114     case HloOpcode::kTriangularSolve:
1115     case HloOpcode::kCholesky:
1116       return kDarkBlue;
1117     case HloOpcode::kReducePrecision:
1118       return kRed;
1119     case HloOpcode::kParameter:
1120       return parameter_color;
1121     case HloOpcode::kBatchNormGrad:
1122     case HloOpcode::kBatchNormInference:
1123     case HloOpcode::kBatchNormTraining:
1124     case HloOpcode::kReduce:
1125     case HloOpcode::kReduceWindow:
1126     case HloOpcode::kScatter:  // scatter is a kind of reduction
1127     case HloOpcode::kSelectAndScatter:
1128       return kPurple;
1129     case HloOpcode::kDomain:
1130     case HloOpcode::kFusion:
1131     case HloOpcode::kMap:
1132     case HloOpcode::kGetDimensionSize:
1133     case HloOpcode::kSetDimensionSize:
1134       return kGray;
1135     case HloOpcode::kAllGather:
1136     case HloOpcode::kAllGatherStart:
1137     case HloOpcode::kAllGatherDone:
1138     case HloOpcode::kAllReduce:
1139     case HloOpcode::kReduceScatter:
1140     case HloOpcode::kAllReduceStart:
1141     case HloOpcode::kAllReduceDone:
1142     case HloOpcode::kAllToAll:
1143     case HloOpcode::kCollectivePermute:
1144     case HloOpcode::kCollectivePermuteStart:
1145     case HloOpcode::kCollectivePermuteDone:
1146     case HloOpcode::kInfeed:
1147     case HloOpcode::kOutfeed:
1148     case HloOpcode::kPartitionId:
1149     case HloOpcode::kRecv:
1150     case HloOpcode::kRecvDone:
1151     case HloOpcode::kSend:
1152     case HloOpcode::kSendDone:
1153     case HloOpcode::kReplicaId:
1154       return kBrown;
1155     case HloOpcode::kCall:
1156     case HloOpcode::kConditional:
1157     case HloOpcode::kCustomCall:
1158     case HloOpcode::kWhile:
1159       return kDarkGreen;
1160   }
1161 }
1162 
GetInstructionNodeShape(const HloInstruction * instr)1163 std::string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) {
1164   // Give while loops a different shape so they're easier to pick out.
1165   switch (instr->opcode()) {
1166     case HloOpcode::kWhile:
1167       return "ellipse";
1168     default:
1169       return "rect";
1170   }
1171 }
1172 
GetInstructionNodeLabel(const HloInstruction * instr)1173 std::string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
1174   // If we have a parameter, put the param number in the name.
1175   if (instr->opcode() == HloOpcode::kParameter) {
1176     return StrFormat("<b>Parameter %d</b>", instr->parameter_number());
1177   }
1178 
1179   // The HLO instruction name contains usually the opcode, e.g. "%add.42" is
1180   // an add instruction.  In this case we render just the name.
1181   if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) {
1182     return StrFormat("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
1183   }
1184   std::string extended_opcode =
1185       StrCat(HloOpcodeString(instr->opcode()),
1186              instr->opcode() != HloOpcode::kFusion
1187                  ? ""
1188                  : StrCat(":", xla::ToString(instr->fusion_kind())));
1189   // If the name does not contain the opcode, render both.
1190   return StrFormat("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
1191                    HtmlLikeStringSanitize(instr->name()));
1192 }
1193 
GetInstructionNodeMetadata(const HloInstruction * instr)1194 std::string HloDotDumper::GetInstructionNodeMetadata(
1195     const HloInstruction* instr) {
1196   std::vector<std::string> lines;
1197   if (!instr->metadata().op_name().empty()) {
1198     lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name()));
1199   }
1200   if (!instr->metadata().op_type().empty()) {
1201     lines.push_back(StrFormat(
1202         "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type())));
1203   }
1204   if (!instr->metadata().source_file().empty() &&
1205       instr->metadata().source_line() != 0) {
1206     lines.push_back(StrFormat("source: %s:%d", instr->metadata().source_file(),
1207                               instr->metadata().source_line()));
1208   }
1209 
1210   return StrJoin(lines, "\n");
1211 }
1212 
1213 static std::vector<std::pair<std::string, std::string>>
ExtractCudnnConvBackendConfigProps(const gpu::CudnnConvBackendConfig & config)1214 ExtractCudnnConvBackendConfigProps(const gpu::CudnnConvBackendConfig& config) {
1215   std::vector<std::pair<std::string, std::string>> props;
1216   if (config.conv_result_scale() != 1) {
1217     props.emplace_back("conv_result_scale", StrCat(config.conv_result_scale()));
1218   }
1219   if (config.side_input_scale() != 0 && config.side_input_scale() != 1) {
1220     props.emplace_back("side_input_scale", StrCat(config.side_input_scale()));
1221   }
1222   props.emplace_back(
1223       "activation_mode",
1224       se::dnn::ActivationModeString(
1225           static_cast<se::dnn::ActivationMode>(config.activation_mode())));
1226 
1227   props.emplace_back("algo",
1228                      se::dnn::AlgorithmDesc(config.algorithm()).ToString());
1229 
1230   // Skip workspace size; it's already explicit in the graph in the output shape
1231   // of the conv.
1232   return props;
1233 }
1234 
1235 static std::vector<std::pair<std::string, std::string>>
ExtractGemmBackendConfigProps(const gpu::GemmBackendConfig & config,const HloInstruction * instr)1236 ExtractGemmBackendConfigProps(const gpu::GemmBackendConfig& config,
1237                               const HloInstruction* instr) {
1238   std::vector<std::pair<std::string, std::string>> props;
1239   if (primitive_util::IsComplexType(instr->shape().element_type())) {
1240     if (config.alpha_real() != 1 || config.alpha_imag() != 1) {
1241       props.emplace_back("alpha_real", StrCat(config.alpha_real()));
1242       props.emplace_back("alpha_imag", StrCat(config.alpha_real()));
1243     }
1244   } else {
1245     if (config.alpha_real() != 1) {
1246       props.emplace_back("alpha", StrCat(config.alpha_real()));
1247     }
1248   }
1249   if (config.beta() != 0 && config.beta() != 1) {
1250     props.emplace_back("beta", StrCat(config.beta()));
1251   }
1252   props.emplace_back(
1253       "", absl::StrReplaceAll(
1254               DotDimensionNumbersToString(config.dot_dimension_numbers()),
1255               {{", ", "<br/>"}}));
1256   if (config.algorithm_case() == gpu::GemmBackendConfig::kSelectedAlgorithm) {
1257     props.emplace_back("algorithm", StrCat(config.selected_algorithm()));
1258   }
1259   return props;
1260 }
1261 
GetInstructionNodeBackendConfig(const HloInstruction * instr)1262 std::string HloDotDumper::GetInstructionNodeBackendConfig(
1263     const HloInstruction* instr) {
1264   // custom-calls for convs and gemms have backend-configs with fields that are
1265   // semantically significant.  Print these configs unconditionally.
1266   //
1267   // (We could elide the semantically-insignificant fields when
1268   // !show_backend_config, but this is simpler, and it's not too noisy.)
1269   std::vector<std::pair<std::string, std::string>> props;
1270   if (gpu::IsCustomCallToDnnConvolution(*instr)) {
1271     StatusOr<gpu::CudnnConvBackendConfig> config =
1272         instr->backend_config<gpu::CudnnConvBackendConfig>();
1273     if (config.ok()) {
1274       props = ExtractCudnnConvBackendConfigProps(*config);
1275     }
1276   } else if (instr->IsCustomCall(gpu::kGemmCallTarget)) {
1277     StatusOr<gpu::GemmBackendConfig> config =
1278         instr->backend_config<gpu::GemmBackendConfig>();
1279     if (config.ok()) {
1280       // gemm strides are generally uninteresting (derived from the instruction
1281       // shape), so we hide them by default.
1282       props = ExtractGemmBackendConfigProps(*config, instr);
1283     }
1284   }
1285 
1286   if (!props.empty()) {
1287     // Put a linebreak before the backend-config properties if there's more than
1288     // one.  Makes it easier to see.
1289     return StrCat((props.size() > 1 ? "<br/>" : ""),
1290                   StrJoin(props, "<br/>",
1291                           [](std::string* out,
1292                              const std::pair<std::string, std::string>& kv) {
1293                             if (!kv.first.empty()) {
1294                               return StrAppend(out, kv.first, "=", kv.second);
1295                             }
1296                             StrAppend(out, kv.second);
1297                           }));
1298   }
1299 
1300   if (!hlo_render_options_.show_backend_config ||
1301       instr->raw_backend_config_string().empty()) {
1302     return "";
1303   }
1304 
1305   return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\"");
1306 }
1307 
GetInstructionNodeExtraInfo(const HloInstruction * instr)1308 std::string HloDotDumper::GetInstructionNodeExtraInfo(
1309     const HloInstruction* instr) {
1310   std::vector<std::string> lines;
1311 
1312   // Get the instruction's extra attributes excluding the names of its
1313   // subcomputations, since those are drawn explicitly in the graph.
1314   for (const auto& line : instr->ExtraAttributesToString(
1315            HloPrintOptions().set_print_subcomputation_mode(
1316                HloPrintOptions::PrintSubcomputationMode::kOff))) {
1317     // Some instructions have giant device identifier fields, so truncate their
1318     // length to 128.
1319     constexpr int kMaxDeviceIdFieldLen = 128;
1320     if ((absl::StartsWith(line, "replica_groups=") ||
1321          absl::StartsWith(line, "source_target_pairs=")) &&
1322         line.length() > kMaxDeviceIdFieldLen) {
1323       lines.push_back(HtmlLikeStringSanitize(
1324           StrCat(line.substr(0, kMaxDeviceIdFieldLen - 3), "...")));
1325     } else {
1326       lines.push_back(HtmlLikeStringSanitize(line));
1327     }
1328   }
1329 
1330   // Show the shape and layout of the instruction, unless it's an inlined fusion
1331   // node -- there the shape and layout is present in the output node.
1332   if (instr->opcode() != HloOpcode::kFusion ||
1333       !ShouldShowFusionSubcomputation(instr)) {
1334     // Show layout of instructions with more than one dimension.  Don't show
1335     // layout on tuples or tensors with just one dimension (which only have one
1336     // possible layout) to avoid visual noise.
1337     bool shape_is_multidim = false;
1338     ShapeUtil::ForEachSubshape(instr->shape(),
1339                                [&](const Shape& s, const ShapeIndex&) {
1340                                  shape_is_multidim |= s.dimensions_size() > 1;
1341                                });
1342     std::string instr_shape;
1343     if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) {
1344       instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape());
1345     } else {
1346       instr_shape = ShapeUtil::HumanString(instr->shape());
1347     }
1348 
1349     // Some instructions have giant tuples as their shapes, so truncate the
1350     // HLO's shape to kMaxShapeLen characters.
1351     constexpr int kMaxShapeLen = 64;
1352     if (instr_shape.length() > kMaxShapeLen) {
1353       instr_shape = StrCat(
1354           absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "...");
1355     }
1356     lines.push_back(HtmlLikeStringSanitize(instr_shape));
1357   }
1358   if (debug_options_.xla_hlo_graph_addresses()) {
1359     lines.push_back(StrFormat("[%p]", instr));
1360   }
1361   if (profile_ != nullptr) {
1362     double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr);
1363     double total_cycles_executed =
1364         profile_->total_cycles_executed(*instr->parent());
1365     if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
1366       lines.push_back(
1367           StrFormat("%% of cycles executed=%.2f",
1368                     100 * hlo_cycles_executed / total_cycles_executed));
1369     }
1370   }
1371   return StrJoin(lines, "<br/>");
1372 }
1373 
AddInstructionIncomingEdges(const HloInstruction * instr)1374 void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
1375   auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
1376                       int64_t operand_num, bool control_edge = false) {
1377     from = GetNodeForEdge(from);
1378 
1379     if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
1380         IsFusedBroadcastOfConstantEffectiveScalar(from) ||
1381         ShouldMergeIntoUsers(from)) {
1382       return;
1383     }
1384     VLOG(2) << "Adding edge from " << from->name() << " to " << to->name()
1385             << " as " << next_edge_id_;
1386     edge_ids_.insert({{from, to}, next_edge_id_++});
1387 
1388     std::string edge_label;
1389     if (control_edge) {
1390       edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"";
1391     } else if (instr->operand_count() > 1) {
1392       edge_label =
1393           StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num);
1394     }
1395 
1396     // We print "small" arrays using a hollow arrowhead and "large" arrays using
1397     // a filled arrowhead.
1398     constexpr char kEdgeFmt[] =
1399         R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
1400     edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to),
1401                                (IsSmall(from) ? "empty" : "normal"),
1402                                from->name(), to->name(), edge_label));
1403   };
1404 
1405   // Add edges from instr's operands to instr.  Parameters within fusion
1406   // expressions are handled specially -- we draw an edge from the corresponding
1407   // operand on the fusion node itself to the parameter.
1408   if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
1409     // Only add the edge if this is not the outermost computation; otherwise it
1410     // will lead from a node we're not drawing.
1411     if (instr->parent() != computation_) {
1412       const HloInstruction* fusion = instr->parent()->FusionInstruction();
1413       add_edge(fusion->operand(instr->parameter_number()), instr,
1414                /*operand_num=*/0);
1415     }
1416   } else {
1417     for (int64_t i = 0; i < instr->operand_count(); ++i) {
1418       add_edge(instr->operand(i), instr, i);
1419     }
1420     for (const HloInstruction* pred : instr->control_predecessors()) {
1421       add_edge(pred, instr, /*operand_num=*/0, /*control_edge=*/true);
1422     }
1423   }
1424 }
1425 
GetInstructionTrivialComputationStr(const HloInstruction * instr)1426 std::string HloDotDumper::GetInstructionTrivialComputationStr(
1427     const HloInstruction* instr) {
1428   // called_computations() on a fusion node "inherits" any called computations
1429   // of the fused root, which isn't what we want.  Just ignore fusion nodes
1430   // here; they're handled separately.
1431   if (instr->opcode() == HloOpcode::kFusion) {
1432     return "";
1433   }
1434 
1435   std::vector<std::string> lines;
1436   for (int64_t i = 0; i < instr->called_computations().size(); ++i) {
1437     optional<std::string> computation_type =
1438         MatchTrivialComputation(instr->called_computations()[i]);
1439     if (!computation_type) {
1440       continue;
1441     }
1442     if (instr->called_computations().size() == 1) {
1443       lines.push_back(StrFormat("Subcomputation: <b>%s</b>",
1444                                 HtmlLikeStringSanitize(*computation_type)));
1445     } else {
1446       lines.push_back(StrFormat("Subcomputation %d: <b>%s</b>", i,
1447                                 HtmlLikeStringSanitize(*computation_type)));
1448     }
1449   }
1450   return StrJoin(lines, "<br/>");
1451 }
1452 
// Maps `instr` to the instruction an edge should be attached to:
//  - a get-tuple-element is replaced by its operand 0 (one level only);
//  - a fusion for which ShouldShowFusionSubcomputation() is true is replaced
//    by its fused expression root, repeatedly, until the result is no longer
//    such a fusion.
const HloInstruction* HloDotDumper::GetNodeForEdge(
    const HloInstruction* instr) {
  const HloInstruction* node = instr;

  // Skip over a get-tuple-element node.
  if (node->opcode() == HloOpcode::kGetTupleElement) {
    node = node->operand(0);
  }

  // Descend through fusions whose bodies are drawn.
  for (;;) {
    const bool is_expanded_fusion = node->opcode() == HloOpcode::kFusion &&
                                    ShouldShowFusionSubcomputation(node);
    if (!is_expanded_fusion) {
      return node;
    }
    node = node->fused_expression_root();
  }
}
1465 
1466 // Gets a NodeFilter that includes roughly all instructions whose distance from
1467 // root is <= radius.
MakeNodeRadiusAroundFilter(const HloInstruction * root,int64_t radius,const absl::flat_hash_set<const HloInstruction * > & boundary)1468 NodeFilter MakeNodeRadiusAroundFilter(
1469     const HloInstruction* root, int64_t radius,
1470     const absl::flat_hash_set<const HloInstruction*>& boundary) {
1471   // First, find the neighborhood of nodes with distance from root <= radius.
1472   // These nodes are our initial set of "normal" nodes.
1473   absl::flat_hash_map<const HloInstruction*, NodeFilterResult> nodes;
1474   std::deque<std::pair<const HloInstruction*, /*depth*/ int64_t>> worklist;
1475   worklist.push_back({root, 0});
1476   while (!worklist.empty()) {
1477     const HloInstruction* instr;
1478     int64_t depth;
1479     std::tie(instr, depth) = worklist.front();
1480     worklist.pop_front();
1481 
1482     nodes[instr] = kNormalNode;
1483     if (depth == radius) {
1484       continue;
1485     }
1486     if (boundary.contains(instr)) {
1487       continue;
1488     }
1489 
1490     // Traverse into instr's operands.
1491     //
1492     // Don't traverse into tuples' operands unless the tuple is the root.
1493     // Usually a tuple is the bottommost node in the graph, and so its operands
1494     // are not interesting to the graph at hand.
1495     if (instr == root || instr->opcode() != HloOpcode::kTuple) {
1496       for (const HloInstruction* operand : instr->operands()) {
1497         if (!nodes.contains(operand)) {
1498           worklist.push_back({operand, depth + 1});
1499         }
1500       }
1501     }
1502 
1503     // Traverse into instr's nested computations.
1504     for (const HloComputation* computation : instr->called_computations()) {
1505       worklist.push_back({computation->root_instruction(), depth + 1});
1506     }
1507 
1508     // Traverse into instr's users, unless:
1509     //
1510     //  - there are a ton of them, in which case they're probably not
1511     //    interesting (and anyway, rendering them all would make the graph
1512     //    unreadable), or
1513     //  - instr is a constant, in which case its users are probably not
1514     //    interesting.
1515     if (instr->opcode() == HloOpcode::kConstant) {
1516       continue;
1517     }
1518     constexpr int kMaxUsersToRender = 16;
1519     if (instr->user_count() > kMaxUsersToRender) {
1520       // If we're going to skip this node's users, style it as such.
1521       nodes[instr] = kSomeUsersOmitted;
1522       continue;
1523     }
1524     for (const HloInstruction* user : instr->users()) {
1525       if (!nodes.contains(user)) {
1526         worklist.push_back({user, depth + 1});
1527       }
1528     }
1529   }
1530 
1531   auto is_displayed = [&](const HloInstruction* instr) {
1532     // Constants are displayed inline with their users; they're never omitted.
1533     // Nodes in subcomputations are always shown.
1534     return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant ||
1535            instr->parent() != root->parent();
1536   };
1537 
1538   // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we
1539   // know which nodes will be included in the graph.
1540   for (auto& kv : nodes) {
1541     const HloInstruction* instr = kv.first;
1542     NodeFilterResult& filter_result = kv.second;
1543     const auto& operands = instr->operands();
1544 
1545     if (absl::c_any_of(operands, is_displayed) &&
1546         !absl::c_all_of(operands, is_displayed)) {
1547       // Mark nodes with some operands omitted appropriately.
1548       filter_result = kSomeOperandsOmitted;
1549     } else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) {
1550       // Mark nodes with *all* operands omitted appropriately.
1551       filter_result = kOmitNodeOperands;
1552     }
1553 
1554     // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
1555     // users made it into the graph.
1556     if (filter_result == kSomeUsersOmitted &&
1557         absl::c_all_of(instr->users(), is_displayed)) {
1558       filter_result = kNormalNode;
1559     }
1560   }
1561 
1562   // Highlight the root node.
1563   nodes[root] = kHighlightNode;
1564 
1565   return NodeFilter([=](const HloInstruction* instr) {
1566     auto it = nodes.find(instr);
1567     if (it != nodes.end()) {
1568       return it->second;
1569     }
1570     // Show all nodes in subcomputations.
1571     if (instr->parent() != root->parent()) {
1572       return kNormalNode;
1573     }
1574     return kHideNode;
1575   });
1576 }
1577 
1578 // Gets a node filter that includes nodes on all paths from `from` to `to`.  If
1579 // the all-paths set contains more than max_nodes elements, includes the nodes
1580 // on the shortest paths and sets hit_limit to true.
MakeNodeFromToFilter(const HloInstruction * from,const HloInstruction * to,int64_t max_nodes,bool * hit_limit)1581 NodeFilter MakeNodeFromToFilter(const HloInstruction* from,
1582                                 const HloInstruction* to, int64_t max_nodes,
1583                                 bool* hit_limit) {
1584   *hit_limit = false;
1585 
1586   // Elements in the queue are paths through the graph.
1587   std::deque<std::vector<const HloInstruction*>> queue;
1588   queue.push_front({from});
1589 
1590   // Compute the set of nodes we want to show using a slightly-modified
1591   // Djikstra's algorithm.  The only real difference is, rather than stopping
1592   // when we find a (shortest) path, we continue until we've found max_nodes
1593   // nodes on some path.
1594   absl::flat_hash_set<const HloInstruction*> visited;
1595   absl::flat_hash_set<const HloInstruction*> to_display = {from, to};
1596   while (!queue.empty() && to_display.size() < max_nodes) {
1597     std::vector<const HloInstruction*> path = std::move(queue.front());
1598     queue.pop_front();
1599     if (!visited.insert(path.back()).second) {
1600       continue;
1601     }
1602 
1603     for (const auto* user : path.back()->users()) {
1604       if (user == to) {
1605         auto it = path.begin();
1606         for (; it != path.end() && to_display.size() < max_nodes; ++it) {
1607           to_display.insert(*it);
1608         }
1609         if (it != path.end()) {
1610           *hit_limit = true;
1611         }
1612       } else if (!visited.count(user)) {
1613         auto new_path = path;
1614         new_path.push_back(user);
1615         queue.push_back(std::move(new_path));
1616       }
1617     }
1618   }
1619 
1620   return NodeFilter([=](const HloInstruction* instr) {
1621     if (instr == from || instr == to) {
1622       return kHighlightNode;
1623     }
1624     return to_display.count(instr) ? kNormalNode : kHideNode;
1625   });
1626 }
1627 
// HTML snippet of <script> tags that WrapDotInHtml substitutes for its
// $JS_INCLUDE placeholder.  It loads viz.js 2.1.1 (viz.js and
// full.render.js), svg-pan-zoom 3.6.0 and @hpcc-js/wasm from the jsDelivr CDN.
// Every tag carries an `integrity` hash, so the browser only runs a script
// whose content matches that hash.
//
// NOTE(review): the @hpcc-js/wasm URL carries no version while its integrity
// hash is fixed; presumably the script stops loading once the CDN serves a
// newer build -- confirm.
static const char* kRenderDotJS = R"(
  <!-- Integrity hash is generated by https://www.srihash.org/ -->
  <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/viz.js"
     integrity="sha384-aD1MJYb0WKIUT+CtwJp5LTuV3U4pLAS6B/nUxL7ECimC2pN9N8vjlMr/yQCAkzxE"
     crossorigin="anonymous"></script>
  <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/full.render.js"
     integrity="sha384-bAixY275aIpCj6Te19y0MILZ4V+VEC8CVFujFEH+Lf7W+4XYYeYLwW5IBI6yQmMT"
     crossorigin="anonymous"></script>
  <script src="https://cdn.jsdelivr.net/npm/svg-pan-zoom@3.6.0/dist/svg-pan-zoom.min.js"
     integrity="sha384-3008WpYB2pOBvE7lwkrKf+qTmbTPGGPYxA9C1YVhvbPukns4ZFj7E98QPLkNW9dS"
     crossorigin="anonymous"></script>
  <script src="https://cdn.jsdelivr.net/npm/@hpcc-js/wasm/dist/index.min.js"
     integrity="sha384-X+8WXyWZ+W2gUHiSSj0aePAkE77Fl6eZ+QIByw+Ii8LzWEJ/W8bI8M4RkneDAJ4D"
     crossorigin="anonymous"></script>
)";
1643 
WrapDotInHtml(absl::string_view dot)1644 std::string WrapDotInHtml(absl::string_view dot) {
1645   std::string html_prefix =
1646       absl::StrReplaceAll(R"html(
1647 <!DOCTYPE html>
1648 <html>
1649 <head>
1650   <meta charset="utf-8">
1651   <style type="text/css">
1652     body {
1653       height: 100vh;
1654       margin: 0;
1655     }
1656   </style>
1657 </head>
1658 <body>
1659   $JS_INCLUDE
1660   <div id="container" style="height:95vh; border:1px solid black; "></div>
1661   <script>
1662     var data = `
1663 )html",
1664                           {{"$JS_INCLUDE", kRenderDotJS}});
1665 
1666   static const char html_suffix[] = R"html(
1667 `;
1668     var cssregex = new RegExp('stylesheet=<([^]*)\n>\n', 'gm');
1669     var results = cssregex.exec(data)
1670     // graphviz has problem dealing with large stylesheets.
1671     // https://github.com/tensorflow/tensorflow/issues/17220#issuecomment-369228492
1672     // In order to avoid the problem, remove the stylesheet from the dot and
1673     // insert it directly info the rendered SVG.
1674     var dot_data = data;
1675     var css_data = ''
1676     if (results !== null) {
1677         css_data = results[1].replace(/\s*data:.*\s*,/,''); // Strip content-type field.
1678         // CSS inside DOT is URL-escaped, so we must unescape it
1679         // before we can insert it into SVG.
1680         css_data = unescape(css_data);
1681         dot_data = data.replace(cssregex, ''); // Remove the stylesheet
1682     }
1683 
1684     var render_start = performance.now()
1685     function add_controls(svg) {
1686         var htmlblob = new Blob([document.documentElement.innerHTML],
1687                                 {type: 'text/html'});
1688         var savehtml = document.createElement('a');
1689         savehtml.setAttribute('href', URL.createObjectURL(htmlblob));
1690         savehtml.setAttribute('download', 'graph.html');
1691         savehtml.innerHTML = " [Save HTML+SVG] ";
1692         document.body.append(savehtml);
1693         var svgblob = new Blob([svg.outerHTML], {type: 'image/svg'});
1694         var savesvg = document.createElement('a');
1695         savesvg.setAttribute('href', URL.createObjectURL(svgblob));
1696         savesvg.setAttribute('download', 'graph.svg');
1697         savesvg.innerHTML = " [Save SVG] ";
1698         document.body.append(savesvg);
1699         var dotblob =  new Blob([data], {type: 'text/dot'});
1700         var savedot = document.createElement('a');
1701         savedot.setAttribute('href', URL.createObjectURL(dotblob));
1702         savedot.setAttribute('download', 'graph.dot');
1703         savedot.innerHTML = " [Save DOT] ";
1704         document.body.append(savedot);
1705         // Will get called after embed element was loaded
1706         var panzoom = svgPanZoom(svg, {
1707             zoomEnabled: true,
1708             controlIconsEnabled: true,
1709         });
1710         document.getElementsByTagName("BODY")[0].onresize = function() {
1711             panzoom.resize();
1712             panzoom.fit();
1713             panzoom.center();
1714         };
1715         var render_end = performance.now();
1716         var render_note = document.createElement('div')
1717         render_note.innerHTML = 'Rendering took '
1718                                 + (render_end - render_start).toFixed(2) + "ms."
1719         document.body.append(render_note);
1720     }
1721     var svg = document.getElementById('graph')
1722     if (svg == null) {
1723         // Need to render SVG first.
1724         var viz = new Viz();
1725         viz.renderSVGElement(dot_data)
1726             .then(function(svg){
1727                 var container = document.getElementById('container')
1728                 var style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
1729                 var node = document.createTextNode(css_data);
1730                 style.appendChild(node);
1731                 svg.setAttribute('width', '100%');
1732                 svg.setAttribute('height', '100%');
1733                 svg.setAttribute('id', 'graph');
1734                 svg.appendChild(style);
1735                 container.appendChild(svg);
1736                 add_controls(svg);
1737             })
1738     } else {
1739         // HTML already has rendered SVG embedded, so we just need to add
1740         // controls.
1741         add_controls(svg);
1742     }
1743   </script>
1744 </body>
1745 </html>
1746 )html";
1747 
1748   return absl::StrCat(html_prefix, dot, html_suffix);
1749 }
1750 
// Protects `url_renderer`.
absl::Mutex url_renderer_mu(absl::kConstInit);
// Callback that is invoked with the DOT text of a graph when a graph is
// rendered in RenderedGraphFormat::kUrl.  Null until
// RegisterGraphToURLRenderer() installs a heap-allocated callback.
std::function<StatusOr<std::string>(absl::string_view)>* url_renderer
    ABSL_GUARDED_BY(url_renderer_mu) = nullptr;

// Storage for fusion visualization: (module_id, computation_id) -> sequence of
// fusion states.
// This mutex protects that storage (`fusion_visualizer_states`).
absl::Mutex fusion_visualizer_state_mu(absl::kConstInit);
1758 namespace {
1759 
1760 // Fusion state: a sequence of rendered graphs in DOT formats with explanations.
1761 // Rendered graphs can be shared across frames, hence the storage indirection.
1762 struct FusionVisualizerProgress {
1763   // Creates a frame with a new rendered graph.
AddStatexla::__anon758a858c0111::__anon758a858c0f11::FusionVisualizerProgress1764   void AddState(absl::string_view dot, absl::string_view explanation,
1765                 std::optional<std::string> to_highlight) {
1766     if (dot_graphs.empty() || dot_graphs.back() != dot) {
1767       dot_graphs.push_back(std::string(dot));
1768     }
1769     frames.push_back({static_cast<int>(dot_graphs.size() - 1),
1770                       std::string(explanation), to_highlight.value_or("")});
1771   }
1772 
1773   std::vector<std::string> dot_graphs;
1774 
1775   struct FusionFrame {
1776     int dot_graph;
1777     std::string label;
1778     std::string to_highlight;
1779   };
1780 
1781   std::vector<FusionFrame> frames;
1782 };
1783 
1784 }  // namespace
1785 
1786 static auto& fusion_visualizer_states
1787     TF_GUARDED_BY(fusion_visualizer_state_mu) = *new absl::flat_hash_map<
1788         std::pair<int64_t, int64_t>, FusionVisualizerProgress>();
1789 
1790 // Generates a key to the fusion visualizer state mapping.
1791 static std::pair<int, int> FusionVisualizerStateKey(
1792     const HloComputation& computation) {
1793   return std::make_pair(computation.parent()->unique_id(),
1794                         computation.unique_id());
1795 }
1796 
1797 // Precondition: (url_renderer != nullptr || format != kUrl).
1798 //
1799 // (We specify this as a precondition rather than checking it in here and
1800 // returning an error because we want to fail quickly when there's no URL
1801 // renderer available, and this function runs only after we've done all the work
1802 // of producing dot for the graph.)
1803 StatusOr<std::string> WrapDotInFormat(const HloComputation& computation,
1804                                       absl::string_view dot,
1805                                       RenderedGraphFormat format)
1806     ABSL_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
1807   switch (format) {
1808     case RenderedGraphFormat::kUrl:
1809       CHECK(url_renderer != nullptr)
1810           << "Should have checked url_renderer != null before calling.";
1811       return (*url_renderer)(dot);
1812     case RenderedGraphFormat::kHtml:
1813       return WrapDotInHtml(dot);
1814     case RenderedGraphFormat::kDot:
1815       return std::string(dot);
1816   }
1817 }
1818 
1819 }  // namespace
1820 
1821 // Compress with zlib + b64 encode.
1822 static StatusOr<std::string> CompressAndEncode(absl::string_view input) {
1823   class WritableStringFile : public tensorflow::WritableFile {
1824    public:
1825     explicit WritableStringFile(std::string* data) : data_(data){};
1826     ~WritableStringFile() override = default;
1827 
1828     Status Append(absl::string_view data) override {
1829       absl::StrAppend(data_, data);
1830       return OkStatus();
1831     }
1832 
1833     Status Close() override { return OkStatus(); }
1834     Status Flush() override { return OkStatus(); }
1835     Status Sync() override { return OkStatus(); }
1836 
1837    private:
1838     std::string* data_;
1839   };
1840 
1841   std::string compressed;
1842   WritableStringFile f(&compressed);
1843 
1844   auto gz_opts = tensorflow::io::ZlibCompressionOptions::GZIP();
1845   tensorflow::io::ZlibOutputBuffer gz_file(&f, gz_opts.input_buffer_size,
1846                                            gz_opts.output_buffer_size, gz_opts);
1847   TF_RETURN_IF_ERROR(gz_file.Init());
1848   TF_RETURN_IF_ERROR(gz_file.Append(input));
1849   TF_RETURN_IF_ERROR(gz_file.Close());
1850 
1851   std::string encoded;
1852   TF_RETURN_IF_ERROR(tensorflow::Base64Encode(compressed, &encoded));
1853   return absl::StrReplaceAll(encoded, {{"_", "/"}, {"-", "+"}});
1854 }
1855 
1856 static std::string EscapeJSONString(absl::string_view raw) {
1857   return absl::StrCat(
1858       "\"",
1859       absl::StrReplaceAll(raw, {{"\n", "\\n"}, {"\"", "\\\""}, {"\\", "\\\\"}}),
1860       "\"");
1861 }
1862 
WrapFusionExplorer(const HloComputation & computation)1863 StatusOr<std::string> WrapFusionExplorer(const HloComputation& computation) {
1864   absl::MutexLock lock(&fusion_visualizer_state_mu);
1865   using absl::StrAppend;
1866   using absl::StrFormat;
1867   using absl::StrJoin;
1868   const FusionVisualizerProgress& visualizer_progress =
1869       fusion_visualizer_states[FusionVisualizerStateKey(computation)];
1870   if (visualizer_progress.frames.empty()) {
1871     return InternalError("Empty");
1872   }
1873 
1874   std::string dot_graphs =
1875       StrFormat("[%s]", StrJoin(visualizer_progress.dot_graphs, ", ",
1876                                 [&](std::string* out, const std::string& dot) {
1877                                   StrAppend(out, EscapeJSONString(dot));
1878                                 }));
1879 
1880   std::string frames = StrJoin(
1881       visualizer_progress.frames, ", ", [&](std::string* out, const auto& p) {
1882         StrAppend(out, StrFormat("[%d, %s, %s]", p.dot_graph,
1883                                  EscapeJSONString(p.label),
1884                                  EscapeJSONString(p.to_highlight)));
1885       });
1886 
1887   TF_ASSIGN_OR_RETURN(std::string dot_graphs_compressed,
1888                       CompressAndEncode(dot_graphs));
1889 
1890   return absl::StrReplaceAll(
1891       R"(
1892 <!DOCTYPE html>
1893 <html>
1894 <head>
1895   <meta charset="utf-8">
1896   <style>
1897     html, body {height: 100%; text-align: center;}
1898     #rendered {height: 70%; width: 80%; border:1px solid black; margin: auto; }
1899     #label {width: 80%; margin: auto;}
1900     #performance_note { font-size: small; color: gray; }
1901     #frames_list {
1902       list-style: none; text-align: left; height: 20%; overflow: scroll;
1903     }
1904     #frames_list   li { padding: 0.2em; margin: 0.2em; }
1905     .selected { background-color: #e0e0e0; }
1906     .selected a { color: black; text-decoration: none; }
1907     #rendered svg { height: 100% !important; width: 100% !important; }
1908   </style>
1909 </head>
1910 <body>
1911   <script src="https://www.gstatic.com/external_hosted/hpcc_js_wasm/index.min.js"
1912       integrity="sha384-LigJPbR3TOfU/Xbb+PjiN1dGJYPweLk7kiGnaMgmxnUmKWaCFKbb5tH6iLlyVhPZ"
1913       crossorigin="anonymous"></script>
1914   <script src="https://www.gstatic.com/external_hosted/svg_pan_zoom/svg-pan-zoom.js">
1915   </script>
1916 
1917   <title>Fusion Explorer: $TITLE</title>
1918   <div id='rendered'><center>Loading...</center></div>
1919   <ul id='frames_list'></ul>
1920   <p>Use j/k for keyboard navigation.</p>
1921   <p id='performance_note'>Loading data...</p>
1922   <script>
1923   <!--
1924   const renderCache = {};
1925 
1926   const cssregex = new RegExp('stylesheet=<([^]*)\n>\n', 'gm');
1927   const hpccWasm = window["@hpcc-js/wasm"];
1928 
1929   const getIdFromHash = () => {
1930     let hash = window.location.hash;
1931     if (hash.indexOf('frame') == -1) {
1932       return 0;
1933     }
1934     return parseInt(window.location.hash.substring('#frame'.length, window.location.hash.length));
1935   }
1936 
1937   const renderCurrentFrame = () => {
1938     if (!window.loaded) { return; }
1939     const frames_list = document.getElementById('frames_list');
1940     const currId = getIdFromHash();
1941 
1942     for (let selected of frames_list.getElementsByClassName('selected')) {
1943         selected.classList.remove('selected');
1944     }
1945 
1946     const selected = frames_list.children[currId];
1947     selected.classList.add('selected');
1948     selected.scrollIntoView();
1949 
1950     const frame = frames[currId];
1951     const dot_ptr = frame[0];
1952     let dot_txt = window.dots[dot_ptr];
1953     const label = frame[1];
1954     document.getElementById('performance_note').innerText = "Rendering...";
1955     const results = cssregex.exec(dot_txt)
1956     let css_data = ''
1957     if (results !== null) {
1958         css_data = results[1].replace(/\s*data:.*\s*,/,''); // Strip content-type field.
1959         // CSS inside DOT is URL-escaped, so we must unescape it
1960         // before we can insert it into SVG.
1961         css_data = unescape(css_data);
1962         dot_txt = dot_txt.replace(cssregex, ''); // Remove the stylesheet
1963     }
1964 
1965     let render_start = performance.now();
1966     const render_callback = svg => {
1967       renderCache[dot_ptr] = svg;
1968       var area = document.getElementById('rendered');
1969       area.innerHTML = `${svg}<style>${css_data}</style>`;
1970       var panzoom = svgPanZoom(area.children[0], {
1971           zoomEnabled: true, controlIconsEnabled: true, });
1972       var to_highlight = frame[2].length ?
1973         document.querySelector(`${frame[2]}`) : null;
1974       if (to_highlight) {
1975         to_highlight.style.setProperty('fill', 'red');
1976       }
1977       document.getElementById('performance_note').innerText =
1978         `Rendering took ${(performance.now() - render_start).toFixed(2)}ms`;
1979     };
1980     if (renderCache[dot_ptr]) {
1981       render_callback(renderCache[dot_ptr]);
1982     } else {
1983       hpccWasm.graphviz.layout(dot_txt, "svg", "dot").then(render_callback);
1984     }
1985   };
1986 
1987   const update = (delta) => {
1988     let currId = getIdFromHash();
1989     currId = (currId + delta + frames.length) % frames.length;
1990     window.location.hash = `#frame${currId}`
1991   };
1992 
1993   const renderFrameList = () => {
1994     const currId = getIdFromHash();
1995     const frames_list = document.getElementById('frames_list');
1996     for (let i=0; i<frames.length; i++) {
1997       const f = frames[i];
1998       let frame_descr = f[1];
1999       const rendered = document.createElement("li");
2000       if (frame_descr == "") {
2001         frame_descr = "Unnamed state";
2002       }
2003       rendered.innerHTML = `<a href="#frame${i}">${frame_descr}</a>`;
2004       if (i == currId) {
2005         rendered.classList.add('selected');
2006       }
2007       frames_list.appendChild(rendered);
2008     }
2009   };
2010 
2011   const decompress = async function(compressed) {
2012     const ds = new DecompressionStream('gzip');
2013     const in_fetch = await fetch(`data:application/octet-stream;base64,${compressed}`);
2014     const in_blob = await in_fetch.blob();
2015     const out_stream = in_blob.stream().pipeThrough(ds);
2016     const out_blob = await new Response(out_stream).blob();
2017     return await out_blob.text();
2018   }
2019 
2020   const dots_compressed = "$DOTS";
2021   const frames = [$FRAMES];
2022   let loaded = false;
2023 
2024   window.addEventListener('hashchange', () => {
2025     renderCurrentFrame();
2026   });
2027 
2028   window.addEventListener("keydown", (event) => {
2029     if (event.defaultPrevented) {
2030       return;
2031     }
2032     if (event.key == "j") {
2033       update(1);
2034     } else if (event.key == "k") {
2035       update(-1);
2036     } else {
2037       return;
2038     }
2039     event.preventDefault();
2040   }, true);
2041 
2042   document.addEventListener("DOMContentLoaded", () => {
2043     decompress(dots_compressed).then(text => {
2044       window.dots = JSON.parse(text);
2045       window.loaded = true;
2046       renderFrameList();
2047       renderCurrentFrame();
2048     });
2049   });
2050 
2051   //-->
2052   </script>
2053   </body>
2054 </html>
2055   )",
2056       {{"$DOTS", dot_graphs_compressed},
2057        {"$FRAMES", frames},
2058        {"$TITLE",
2059         absl::StrCat(computation.parent()->name(), "_", computation.name())}});
2060 }
2061 
2062 void RegisterGraphToURLRenderer(
2063     std::function<StatusOr<std::string>(absl::string_view)> renderer) {
2064   absl::MutexLock lock(&url_renderer_mu);
2065   if (url_renderer != nullptr) {
2066     LOG(WARNING) << "Multiple calls to RegisterGraphToURLRenderer.  Last call "
2067                     "wins, but because order of initialization in C++ is "
2068                     "nondeterministic, this may not be what you want.";
2069   }
2070   delete url_renderer;
2071   url_renderer = new std::function<StatusOr<std::string>(absl::string_view)>(
2072       std::move(renderer));
2073 }
2074 
2075 void RegisterFusionState(const HloComputation& computation,
2076                          absl::string_view label,
2077                          const HloInstruction& consumer,
2078                          const HloInstruction* producer) {
2079   absl::MutexLock lock(&fusion_visualizer_state_mu);
2080   FusionVisualizerProgress& fusion_progress =
2081       fusion_visualizer_states[FusionVisualizerStateKey(computation)];
2082 
2083   // Radius size in which to render.
2084   static constexpr int kRenderRadius = 4;
2085 
2086   absl::flat_hash_set<const HloInstruction*> render_boundary;
2087   for (const HloInstruction* user : consumer.users()) {
2088     render_boundary.insert(user);
2089   }
2090 
2091   HloDotDumper dumper(
2092       consumer.parent(),
2093       StrCat("Rendering of ", kRenderRadius, " nodes around fusion consumer"),
2094       consumer.GetModule()->config().debug_options(), {}, /*profile=*/nullptr,
2095       MakeNodeRadiusAroundFilter(&consumer, kRenderRadius, render_boundary));
2096   std::string dot_txt = dumper.Dump();
2097 
2098   std::optional<std::string> producer_to_highlight;
2099   if (producer) {
2100     producer_to_highlight = dumper.CssIdForInstruction(*producer);
2101   }
2102 
2103   fusion_progress.AddState(dot_txt, label, producer_to_highlight);
2104 }
2105 
2106 StatusOr<std::string> RenderGraph(
2107     const HloComputation& computation, absl::string_view label,
2108     const DebugOptions& debug_options, RenderedGraphFormat format,
2109     const HloExecutionProfile* hlo_execution_profile,
2110     HloRenderOptions hlo_render_options) {
2111   absl::MutexLock lock(&url_renderer_mu);
2112   if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
2113     return Unavailable("Can't render as URL; no URL renderer was registered.");
2114   }
2115 
2116   std::string rendered_dot =
2117       HloDotDumper(&computation, label, debug_options, hlo_render_options,
2118                    hlo_execution_profile, NodeFilter())
2119           .Dump();
2120   return WrapDotInFormat(computation, rendered_dot, format);
2121 }
2122 
2123 StatusOr<std::string> RenderNeighborhoodAround(
2124     const HloInstruction& node, int radius, RenderedGraphFormat format,
2125     HloRenderOptions hlo_render_options,
2126     const absl::flat_hash_set<const HloInstruction*>& boundary) {
2127   absl::MutexLock lock(&url_renderer_mu);
2128   if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
2129     return FailedPrecondition(
2130         "Can't render as URL; no URL renderer was registered.");
2131   }
2132 
2133   std::string label =
2134       StrCat("Neighborhood of ", radius, " nodes around ", node.name());
2135   std::string rendered_dot =
2136       HloDotDumper(node.parent(), label,
2137                    node.GetModule()->config().debug_options(),
2138                    hlo_render_options, /*profile=*/nullptr,
2139                    MakeNodeRadiusAroundFilter(&node, radius, boundary))
2140           .Dump();
2141   return WrapDotInFormat(*node.parent(), rendered_dot, format);
2142 }
2143 
2144 StatusOr<std::string> RenderAllPathsFromTo(
2145     const HloInstruction& from, const HloInstruction& to, int64_t max_nodes,
2146     RenderedGraphFormat format, HloRenderOptions hlo_render_options) {
2147   absl::MutexLock lock(&url_renderer_mu);
2148   if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
2149     return FailedPrecondition(
2150         "Can't render as URL; no URL renderer was registered.");
2151   }
2152 
2153   CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!";
2154   auto debug_options = from.GetModule()->config().debug_options();
2155 
2156   bool hit_limit = false;
2157   NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit);
2158   std::string label;
2159   if (!hit_limit) {
2160     label = StrCat("All paths from ", from.name(), " to ", to.name());
2161   } else {
2162     label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(),
2163                    " to ", to.name(),
2164                    "<br/><br/>***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN "
2165                    "NODES***<br/><br/>");
2166   }
2167   std::string rendered_dot =
2168       HloDotDumper(from.parent(), label, debug_options, hlo_render_options,
2169                    /*profile=*/nullptr, filter)
2170           .Dump();
2171   return WrapDotInFormat(*from.parent(), rendered_dot, format);
2172 }
2173 
2174 }  // namespace xla
2175