• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
17 
18 #include <unistd.h>
19 
20 #include <algorithm>
21 #include <atomic>
22 #include <deque>
23 #include <map>
24 #include <memory>
25 #include <queue>
26 #include <string>
27 #include <tuple>
28 #include <vector>
29 
30 #include "absl/container/flat_hash_map.h"
31 #include "absl/strings/match.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_format.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/str_replace.h"
36 #include "absl/types/optional.h"
37 #include "tensorflow/compiler/xla/layout_util.h"
38 #include "tensorflow/compiler/xla/literal.h"
39 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
40 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
41 #include "tensorflow/compiler/xla/service/hlo_module.h"
42 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
43 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
44 #include "tensorflow/compiler/xla/shape_util.h"
45 #include "tensorflow/compiler/xla/types.h"
46 #include "tensorflow/compiler/xla/util.h"
47 #include "tensorflow/compiler/xla/window_util.h"
48 #include "tensorflow/core/lib/core/status.h"
49 #include "tensorflow/core/lib/gtl/map_util.h"
50 #include "tensorflow/core/lib/io/path.h"
51 #include "tensorflow/core/lib/strings/numbers.h"
52 #include "tensorflow/core/platform/env.h"
53 #include "tensorflow/core/platform/mutex.h"
54 #include "tensorflow/core/platform/protobuf.h"
55 #include "tensorflow/core/platform/regexp.h"
56 
57 namespace xla {
58 namespace {
59 
60 using absl::nullopt;
61 using absl::optional;
62 using absl::StrAppend;
63 using absl::StrCat;
64 using absl::StrFormat;
65 using absl::StrJoin;
66 
// Tells the graph-drawing routines how a given HloInstruction should be
// treated in the graph: drawn normally, hidden, highlighted, de-emphasized,
// and so on.
enum NodeFilterResult {
  // Draw the node with no special treatment.
  kNormalNode = 0,
  // Leave the node out of the graph.
  kHideNode = 1,
  // Make the node easy to find in the final graph.
  kHighlightNode = 2,
  // "Gray out" the node to indicate that some of its operands have been
  // omitted.
  kSomeOperandsOmitted = 3,
  // Style the node the same as kSomeOperandsOmitted, but also don't connect it
  // to its operands, even if they're present in the graph.
  kOmitNodeOperands = 4,
  // Same style as kSomeOperandsOmitted, but used to indicate that some of the
  // node's *users* have been omitted.
  kSomeUsersOmitted = 5,
};
84 
85 // NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
86 // It lets callers tell the graph-drawing routines which nodes they want to be
87 // shown, hidden, or highlighted.
88 class NodeFilter {
89  public:
__anon973ba4530202(const HloInstruction*) 90   NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}
91 
NodeFilter(std::function<NodeFilterResult (const HloInstruction * instr)> filter)92   explicit NodeFilter(
93       std::function<NodeFilterResult(const HloInstruction* instr)> filter)
94       : filter_(std::move(filter)) {}
95 
Show(const HloInstruction * instr) const96   bool Show(const HloInstruction* instr) const {
97     return filter_(instr) != kHideNode;
98   }
Highlight(const HloInstruction * instr) const99   bool Highlight(const HloInstruction* instr) const {
100     return filter_(instr) == kHighlightNode;
101   }
OmitOperands(const HloInstruction * instr) const102   bool OmitOperands(const HloInstruction* instr) const {
103     return filter_(instr) == kOmitNodeOperands;
104   }
SomeOrAllOperandsOmitted(const HloInstruction * instr) const105   bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const {
106     auto result = filter_(instr);
107     return result == kOmitNodeOperands || result == kSomeOperandsOmitted;
108   }
Deemphasized(const HloInstruction * instr) const109   bool Deemphasized(const HloInstruction* instr) const {
110     auto result = filter_(instr);
111     return result == kOmitNodeOperands || result == kSomeOperandsOmitted ||
112            result == kSomeUsersOmitted;
113   }
114 
115  private:
116   std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
117 };
118 
119 // We arbitrarily set this as the boundary between "large" and "small"
120 // instructions.
IsSmall(const HloInstruction * instr)121 bool IsSmall(const HloInstruction* instr) {
122   if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE_TYPE) ||
123       ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
124     return true;
125   }
126   return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
127 }
128 
// Node color schemes, translated into Graphviz attributes by
// NodeColorsForScheme / NodeColorAttributes.
enum ColorScheme {
  kBlue = 0,
  kBrown = 1,
  kDarkBlue = 2,
  kDarkGreen = 3,
  kDarkOrange = 4,
  kDarkRed = 5,
  kGray = 6,
  kGreen = 7,
  kOrange = 8,
  kPurple = 9,
  kRed = 10,
  kWhite = 11,
  kYellow = 12,

  // Causes the node's border to be a dashed line, and its content to be gray
  // text on a white background, suggesting that this is an "unimportant" node.
  kDashedBorder = 13,
};
149 
// The set of Graphviz attribute values that together make up one color
// scheme.  All members point at string literals owned by the caller
// (NodeColorsForScheme returns literals).
struct NodeColors {
  const char* style;         // Graphviz "style" attribute, e.g. "filled".
  const char* fill_color;    // Graphviz "fillcolor" attribute.
  const char* stroke_color;  // Graphviz "color" (border) attribute.
  const char* font_color;    // Graphviz "fontcolor" attribute.
};
157 
NodeColorsForScheme(ColorScheme color)158 NodeColors NodeColorsForScheme(ColorScheme color) {
159   switch (color) {
160     case kBlue:
161       return NodeColors{"filled", "#bbdefb", "#8aacc8", "black"};
162     case kBrown:
163       return NodeColors{"filled", "#bcaaa4", "#8c7b75", "black"};
164     case kDarkBlue:
165       return NodeColors{"filled", "#1565c0", "#003c8f", "white"};
166     case kDarkGreen:
167       return NodeColors{"filled", "#2e7d32", "#005005", "white"};
168     case kDarkOrange:
169       // This is more of a "medium" orange, made to look close to kOrange;
170       // there's probably room for a darker weight if desired.
171       return NodeColors{"filled", "#ffb74d", "#c88719", "black"};
172     case kDarkRed:
173       return NodeColors{"filled", "#b71c1c", "#7f0000", "white"};
174     case kGray:
175       return NodeColors{"filled", "#cfd8dc", "#9ea7aa", "black"};
176     case kGreen:
177       return NodeColors{"filled", "#c8e6c9", "#97b498", "black"};
178     case kOrange:
179       return NodeColors{"filled", "#ffe0b2", "#cbae82", "black"};
180     case kPurple:
181       return NodeColors{"filled", "#e1bee7", "#af8eb5", "black"};
182     case kRed:
183       return NodeColors{"filled", "#ffcdd2", "#cb9ca1", "black"};
184     case kWhite:
185       return NodeColors{"filled", "white", "black", "black"};
186     case kYellow:
187       return NodeColors{"filled", "#fff9c4", "#cbc693", "black"};
188     case kDashedBorder:
189       // "filled,dashed" looks the same as "dashed", since we have a white
190       // background.  But we use "filled,dashed" so that when you hover over
191       // any part of the node (not just the text inside the node), our css
192       // :hover rule is triggered.
193       return NodeColors{"filled,dashed", "white", "#757575", "#757575"};
194   }
195 }
196 
197 // Given a ColorScheme, returns an attribute string for a node of that color.
198 // Sets the node's style and fill/stroke/text colors.
199 //
200 // Colors are from https://material.io/color.
NodeColorAttributes(ColorScheme color)201 string NodeColorAttributes(ColorScheme color) {
202   NodeColors node_colors = NodeColorsForScheme(color);
203 
204   return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")",
205                    node_colors.style, node_colors.font_color,
206                    node_colors.stroke_color, node_colors.fill_color);
207 }
208 
209 // Replaces <> with &lt;&gt;, so that this string is safe(er) for use in a
210 // graphviz HTML-like string.
HtmlLikeStringSanitize(absl::string_view s)211 string HtmlLikeStringSanitize(absl::string_view s) {
212   return absl::StrReplaceAll(s, {{"<", "&lt;"}, {">", "&gt;"}});
213 }
214 
IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction * instr)215 bool IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction* instr) {
216   namespace m = match;
217   return instr->parent()->IsFusionComputation() &&
218          Match(instr, m::Broadcast(m::ConstantEffectiveScalar()));
219 }
220 
221 // Tries to generates a human-readable one-word description of the given
222 // computation.
223 //
224 // Currently we support:
225 //
226 //   "return param0 + param1;"      --> "add"
227 //   "return param0 * param1;"      --> "multiply"
228 //   "return min(param0, param1);"  --> "min"
229 //   "return max(param0, param1);"  --> "max"
230 //   "return param0 <= param1;"     --> "less-or-equal"
231 //   "return param0 >= param1;"     --> "greater-or-equal"
232 //   "return param0 >  param1;"     --> "greater-than"
233 //   "return param0 <  param1;"     --> "less-than"
234 //   "return param0 == param1;"     --> "equal-to"
235 //   "return param0 != param1;"     --> "not-equal-to"
236 //
237 // where param0 and param1 are effective scalars.  For the ops that are
238 // commutative, we also support them with param0 and param1 swapped.
239 //
240 // This is useful primarily for reduce and map nodes.  These take a
241 // subcomputation which is almost always one of the above, and pattern matching
242 // it to a short string lets us tell the user what the subcomputation is without
243 // drawing it as a graph.
MatchTrivialComputation(const HloComputation * computation)244 optional<string> MatchTrivialComputation(const HloComputation* computation) {
245   namespace m = match;
246 
247   if (computation->instruction_count() != 3) {
248     return nullopt;
249   }
250   HloInstruction* root = computation->root_instruction();
251   const HloInstruction *param0, *param1;
252   if (!Match(root, m::Op()
253                        .WithNumOperands(2)
254                        .WithShape(m::Shape().IsEffectiveScalar())
255                        .WithBinaryOperandsAnyOrder(
256                            m::Parameter(&param0, 0)
257                                .WithShape(m::Shape().IsEffectiveScalar()),
258                            m::Parameter(&param1, 1)
259                                .WithShape(m::Shape().IsEffectiveScalar())))) {
260     return nullopt;
261   }
262 
263   // If the params are reversed (i.e. operand0 is param1 and operand1 is
264   // param0), check that the operation being performed is commutative.
265   if (root->operand(0) == param1) {
266     CHECK_EQ(root->operand(1), param0);
267     if (root->opcode() == HloOpcode()) {
268       switch (root->comparison_direction()) {
269         case ComparisonDirection::kLe:
270         case ComparisonDirection::kGe:
271         case ComparisonDirection::kGt:
272         case ComparisonDirection::kLt:
273           return nullopt;
274         default:
275           break;
276       }
277     }
278   }
279 
280   // If we recognize the root's opcode, we've successfully pattern-matched!
281   switch (root->opcode()) {
282     case HloOpcode::kAdd:
283       return "add";
284     case HloOpcode::kMultiply:
285       return "multiply";
286     case HloOpcode::kMinimum:
287       return "min";
288     case HloOpcode::kMaximum:
289       return "max";
290     case HloOpcode::kCompare: {
291       switch (root->comparison_direction()) {
292         case ComparisonDirection::kLe:
293           return "less-or-equal";
294         case ComparisonDirection::kGe:
295           return "greater-or-equal";
296         case ComparisonDirection::kGt:
297           return "greater-than";
298         case ComparisonDirection::kLt:
299           return "less-than";
300         case ComparisonDirection::kEq:
301           return "equal-to";
302         case ComparisonDirection::kNe:
303           return "not-equal-to";
304       }
305     }
306     default:
307       return nullopt;
308   }
309 }
310 
311 // Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax).
312 class HloDotDumper {
313  public:
HloDotDumper(const HloComputation * computation,absl::string_view label,const DebugOptions & debug_options,HloRenderOptions hlo_render_options,const HloExecutionProfile * profile,NodeFilter filter)314   HloDotDumper(const HloComputation* computation, absl::string_view label,
315                const DebugOptions& debug_options,
316                HloRenderOptions hlo_render_options,
317                const HloExecutionProfile* profile, NodeFilter filter)
318       : computation_(computation),
319         label_(label),
320         debug_options_(debug_options),
321         hlo_render_options_(hlo_render_options),
322         profile_(profile),
323         filter_(std::move(filter)) {}
324 
325   string Dump();
326 
327  private:
328   // Returns the dot graph identifier for the given instruction.
InstructionId(const HloInstruction * instruction)329   string InstructionId(const HloInstruction* instruction) {
330     return StrCat(reinterpret_cast<uint64>(instruction));
331   }
332 
333   // Returns the dot graph identifier for the given computation.
SubcomputationId(const HloComputation * computation)334   string SubcomputationId(const HloComputation* computation) {
335     return StrCat("cluster_", reinterpret_cast<uint64>(computation));
336   }
337 
338   // Generates graph header/footer.  These should be called *after* dumping all
339   // of the instructions and subcomputations for the graph, as they both use
340   // data generated while dumping the graph.
341   string Header();
342   string Footer();
343 
344   bool ShouldShowSubcomputation(const HloComputation* subcomp);
345   bool ShouldShowFusionSubcomputation(const HloInstruction* instr);
346 
347   // We omit some nodes from the graph, instead drawing them inlined into the
348   // nodes that use them.
349   bool ShouldMergeIntoUsers(const HloInstruction* instr) const;
350 
351   string DumpSubcomputation(const HloComputation* subcomp,
352                             const HloInstruction* parent_instr);
353   string DumpComputation(const HloComputation* comp);
354   string DumpRootTag();
355   string DumpInstruction(const HloInstruction* instr);
356   ColorScheme GetInstructionColor(const HloInstruction* instr);
357   string GetInstructionNodeShape(const HloInstruction* instr);
358   string GetInstructionNodeLabel(const HloInstruction* instr);
359   string GetInstructionNodeMetadata(const HloInstruction* instr);
360   string GetInstructionNodeBackendConfig(const HloInstruction* instr);
361   string GetInstructionNodeExtraInfo(const HloInstruction* instr);
362   string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
363   void AddInstructionIncomingEdges(const HloInstruction* instr);
364 
365   // For most instructions, GetNodeForEdge(instr) returns instr.
366   //
367   // The exception is fusion nodes.  For these, we walk up the chain of nested
368   // fusion nodes starting at instr until we reach a node that either (a) isn't
369   // a fusion node, or (b) is a fusion node for which
370   // ShouldShowFusionSubcomputation is false.
371   //
372   // We do this because fusion nodes are expanded inline -- if
373   // ShouldShowFusionSubcomputation is true, the fusion node won't be present in
374   // the graph.
375   //
376   // In general when you want to draw an edge from A to B, you should actually
377   // draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B).
378   const HloInstruction* GetNodeForEdge(const HloInstruction* instr);
379 
380   // If instr has just one computation and it's trivial (e.g. "return param0 +
381   // param1"), returns a string you can put into the node's body that names the
382   // subcomputation, e.g. "Subcomputation: <b>add</b>".
383   string GetInstructionTrivialComputationStr(const HloInstruction* instr);
384 
385   const HloComputation* computation_;  // never null
386   const string label_;                 // overall name for the graph
387   const DebugOptions& debug_options_;
388   const HloRenderOptions hlo_render_options_;
389   const HloExecutionProfile* profile_;  // may be null
390   const NodeFilter filter_;
391 
392   // Each HloInstruction dumped gets a monotonically-increasing node ID.  This
393   // must start at 1, because that's where graphviz's accounting starts.
394   int64 next_node_id_ = 1;
395   absl::flat_hash_map<const HloInstruction*, int64> node_ids_;
396 
397   // The "root" tag doesn't have an associated HloInstruction pointer, so we
398   // need to store it outside the map.
399   int64 root_node_id_;
400 
401   // Each (from, to) edge gets a monotonically-increasing ID.  This is a
402   // multimap because it's possible for the same edge to appear multiple times
403   // in the graph (e.g. x^2 may be represented as mul(x, x)).
404   int64 next_edge_id_ = 1;
405   std::unordered_multimap<
406       std::pair<const HloInstruction*, const HloInstruction*>, int64,
407       tensorflow::hash<std::pair<const HloInstruction*, const HloInstruction*>>>
408       edge_ids_;
409 
410   // Each HloComputation that's emitted gets a monotonically-increasing ID.
411   int64 next_cluster_id_ = 1;
412   absl::flat_hash_map<const HloComputation*, int64> cluster_ids_;
413 
414   // Edges to print from Footer().  Edges come at the end because graphviz is
415   // unhappy if an edge from a subcomputation to a node in the outer computation
416   // appears before both the inner computation and the destination node are
417   // defined.
418   std::vector<string> edges_;
419 
420   // When coloring by sharding information, we track the sharding string
421   // representation to color association, by round-robin the color schemes.
422   absl::flat_hash_map<HloSharding, ColorScheme, HloSharding::Hasher>
423       sharding_colors_;
424   int64 next_shard_color_ = 0;
425 };
426 
Dump()427 string HloDotDumper::Dump() {
428   string body;
429   StrAppend(&body, DumpComputation(computation_));
430   StrAppend(&body, DumpRootTag());
431 
432   // By contract, Header() and Footer() have to be called after we've dumped all
433   // our instructions, because they use state generated during that process.
434   string g = Header();
435   StrAppend(&g, body);
436   StrAppend(&g, Footer());
437   return g;
438 }
439 
Header()440 string HloDotDumper::Header() {
441   constexpr char fmt[] = R"(digraph G {
442 rankdir = TB;
443 compound = true;
444 label = <<b>%s</b>>;
445 labelloc = t;
446 // Disable the tooltip.  Interestingly, "" doesn't work!
447 tooltip = " ";
448 // DOT graphs accept a stylesheet as a URI.  So naturally, an inline
449 // stylesheet is a data URI!
450 stylesheet=<
451   data:text/css,
452   @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
453   svg text {
454     font-family: 'Roboto';
455     font-size: 12px;
456   }
457 
458   %s
459 >
460 
461 )";
462 
463   VLOG(3) << "Generating Header";
464 
465   string graph_label =
466       StrCat(label_, "<br/>Computation ", computation_->name());
467   if (computation_->IsFusionComputation()) {
468     StrAppend(&graph_label, " (in fusion instruction ",
469               computation_->FusionInstruction()->name(), ")");
470   }
471   if (profile_ != nullptr) {
472     auto cycles = profile_->total_cycles_executed(*computation_);
473     absl::StrAppendFormat(&graph_label, "<br/>total cycles = %d (%s)", cycles,
474                           tensorflow::strings::HumanReadableNum(cycles));
475   }
476 
477   // Create CSS rules that say, when you hover over the given node or cluster,
478   // turn the given edge the given color.
479   //
480   // We rely on a few properties of how graphviz generates SVGs:
481   //
482   //  - Nodes are named "nodeN", where N corresponds to the 1-based index of
483   //    the node in our DOT (i.e. the first node in the DOT is "node1", etc.).
484   //    Edges are similarly named "edgeN", and clusters are named "clustN".
485   //  - Nodes come before their in- and out-edges in the SVG.  We need this
486   //    because the "X ~ Y" CSS selector finds a sibling of X that *comes
487   //    after X in the DOM* and matches Y.
488   std::vector<string> edge_css_rules;
489   const char* kBlue = "#1976d2";
490   const char* kRed = "#d32f2f";
491   for (const auto& kv : edge_ids_) {
492     const HloInstruction* from_node = kv.first.first;
493     const HloInstruction* to_node = kv.first.second;
494     int64_t edge_id = kv.second;
495 
496     auto add_hover_css_rule = [&](string elem_type, int64_t elem_id,
497                                   const char* color) {
498       // One could imagine other ways of writing this CSS rule that involve
499       // less duplication, but this way seems to be relatively performant.
500       edge_css_rules.push_back(
501           StrFormat("  #%s%d:hover ~ #edge%d text { fill: %s; }\n"
502                     "  #%s%d:hover ~ #edge%d path { "
503                     "stroke: %s; stroke-width: .2em; }\n"
504                     "  #%s%d:hover ~ #edge%d polygon { "
505                     "fill: %s; stroke: %s; stroke-width: .2em; }\n",
506                     elem_type, elem_id, edge_id, color,  //
507                     elem_type, elem_id, edge_id, color,  //
508                     elem_type, elem_id, edge_id, color, color));
509     };
510 
511     // The "to_node" value may be a NULL, indicating that this points to the
512     // "root" tag rather than a normal node.
513     int64_t from_node_id =
514         tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1);
515     if (from_node_id == -1) {
516       LOG(FATAL) << from_node->name() << " was added to edges but not to nodes";
517     }
518     int64_t to_node_id =
519         to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1)
520                 : root_node_id_;
521     if (to_node != nullptr && to_node_id == -1) {
522       LOG(FATAL) << to_node->name() << " was added to edges but not to nodes";
523     }
524 
525     add_hover_css_rule("node", from_node_id, kBlue);
526     add_hover_css_rule("node", to_node_id, kRed);
527 
528     if (to_node) {
529       VLOG(3) << "Adding css for edge " << edge_id << " from node "
530               << from_node->name() << " to node " << to_node->name();
531     } else {
532       VLOG(3) << "Adding css for edge " << edge_id << " from node "
533               << from_node->name() << " to root tag";
534     }
535 
536     // If this edge crosses a fusion cluster boundary, highlight it when the
537     // cluster is hovered over.
538     if (to_node) {
539       if (from_node->IsFused() &&
540           from_node->parent()->root_instruction() == from_node) {
541         int64_t cluster_id = cluster_ids_.at(from_node->parent());
542         add_hover_css_rule("clust", cluster_id, kBlue);
543       }
544       if (to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter) {
545         int64_t cluster_id = cluster_ids_.at(to_node->parent());
546         add_hover_css_rule("clust", cluster_id, kRed);
547       }
548     }
549   }
550 
551   // Browsers require that we URI-encode the contents of our data URI.  (It
552   // seems this was a relatively recent change?) In practice, this means that we
553   // need to escape '#'.
554   return StrFormat(
555       fmt, graph_label,
556       absl::StrReplaceAll(StrJoin(edge_css_rules, "\n"), {{"#", "%23"}}));
557 }
558 
Footer()559 string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); }
560 
ShouldShowFusionSubcomputation(const HloInstruction * instr)561 bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
562   CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
563   return ShouldShowSubcomputation(instr->fused_instructions_computation());
564 }
565 
ShouldShowSubcomputation(const HloComputation * subcomp)566 bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
567   if (subcomp->IsFusionComputation()) {
568     const HloInstruction* fusion = subcomp->FusionInstruction();
569     if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion) ||
570         !hlo_render_options_.show_fusion_subcomputations) {
571       return false;
572     }
573   }
574 
575   // Don't show trivial subcomputations on non-fusion nodes -- these are inlined
576   // into the graph.
577   if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) {
578     return false;
579   }
580 
581   // Show the subcomputation if we're showing any of its members.
582   return absl::c_any_of(
583       subcomp->instructions(),
584       [&](const HloInstruction* instr) { return filter_.Show(instr); });
585 }
586 
DumpSubcomputation(const HloComputation * subcomp,const HloInstruction * parent_instr)587 string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
588                                         const HloInstruction* parent_instr) {
589   VLOG(2) << "Dumping subcomputation " << subcomp->name();
590   // Add an edge from the subcomputation to its parent node.  If subcomp
591   // belongs to a fusion node, it's drawn in place of the fusion instruction,
592   // so there's no need to link those.
593   if (parent_instr->opcode() != HloOpcode::kFusion) {
594     const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
595     VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
596             << " as " << next_edge_id_;
597     edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
598     constexpr char edge_fmt[] =
599         R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
600     edges_.push_back(StrFormat(
601         edge_fmt, InstructionId(from), InstructionId(parent_instr),
602         SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
603   }
604 
605   // Have we already dumped this subcomputation?  If so, generating the edge
606   // linking it and parent_instr is all we want to do in this function.
607   if (cluster_ids_.find(subcomp) != cluster_ids_.end()) {
608     return "";
609   }
610 
611   cluster_ids_[subcomp] = next_cluster_id_++;
612 
613   string id = SubcomputationId(subcomp);
614 
615   string subcomp_label, style;
616   if (parent_instr->opcode() == HloOpcode::kFusion) {
617     subcomp_label =
618         StrFormat("Fused expression for <b>%s</b><br/>%s",
619                   HtmlLikeStringSanitize(parent_instr->name()),
620                   HtmlLikeStringSanitize(parent_instr->ToCategory()));
621     string extra_info = GetInstructionNodeExtraInfo(parent_instr);
622     if (!extra_info.empty()) {
623       StrAppend(&subcomp_label, "<br/>", extra_info);
624     }
625     string node_backend_config = GetInstructionNodeBackendConfig(parent_instr);
626     if (!node_backend_config.empty()) {
627       StrAppend(&subcomp_label, "<br/>", node_backend_config);
628     }
629 
630     bool highlight = filter_.Highlight(parent_instr);
631     const char* fillcolor;
632     const char* strokecolor;
633     if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) {
634       // Use the sharding color, if the node isn't highlighted.
635       NodeColors node_colors =
636           NodeColorsForScheme(GetInstructionColor(parent_instr));
637       fillcolor = node_colors.fill_color;
638       strokecolor = node_colors.stroke_color;
639     } else {
640       // Subcomputation's fill/stroke color is light/dark red/gray, depending on
641       // whether or not the subcomputation's fusion node is highlighted.
642       fillcolor = highlight ? "#ffcdd2" : "#f5f5f5";
643       strokecolor = highlight ? "#b71c1c" : "#c2c2c2";
644     }
645     style =
646         StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")",
647                   fillcolor, strokecolor);
648   } else {
649     subcomp_label = StrFormat("Subcomputation for <b>%s</b><br/>%s",
650                               HtmlLikeStringSanitize(parent_instr->name()),
651                               HtmlLikeStringSanitize(subcomp->name()));
652     style = "style=rounded; color=black;";
653   }
654 
655   string comp_body = DumpComputation(subcomp);
656 
657   constexpr char computation_fmt[] = R"(subgraph %s {
658 %s
659 label = <%s>;
660 labelloc = t;
661 tooltip = " ";
662 %s
663 }  // %s
664 
665 )";
666   return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id);
667 }
668 
DumpComputation(const HloComputation * comp)669 string HloDotDumper::DumpComputation(const HloComputation* comp) {
670   string g;
671   for (const auto* instr : comp->instructions()) {
672     if (!filter_.Show(instr)) {
673       continue;
674     }
675 
676     // Dump subcomputations within instr.
677     for (const HloComputation* subcomp : instr->called_computations()) {
678       if (ShouldShowSubcomputation(subcomp)) {
679         StrAppend(&g, DumpSubcomputation(subcomp, instr));
680       }
681     }
682 
683     StrAppend(&g, DumpInstruction(instr));
684   }
685   return g;
686 }
687 
DumpRootTag()688 string HloDotDumper::DumpRootTag() {
689   const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());
690 
691   // We didn't display constants or broadcasts of effective scalars within
692   // fusions as separate nodes; so if the root is a constant/broadcast of
693   // scalar, we don't add root tag or edge for it.
694   if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
695       IsFusedBroadcastOfConstantEffectiveScalar(from)) {
696     return "";
697   }
698 
699   auto from_id = InstructionId(from);
700 
701   // The ID of the root computation is otherwise unused, so it makes a good ID
702   // to use for the root-tag node.  However, the edge_ids_ map requires a
703   // HloInstruction* pointer for the 'to' value, so we use a NULL value there
704   // (rather than a pointer type-cast) to make it obvious if it is erroneously
705   // dereferenced.
706   HloInstruction* to = nullptr;
707   auto to_id = SubcomputationId(computation_);
708 
709   string node_body = "ROOT";
710   string node_shape = "circle";
711   ColorScheme color = kBrown;
712 
713   VLOG(2) << "Adding root tag as node " << next_node_id_;
714   root_node_id_ = next_node_id_++;
715 
716   VLOG(2) << "Adding edge from " << from->name() << " to root tag as "
717           << next_edge_id_;
718   edge_ids_.insert({{from, to}, next_edge_id_++});
719   edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id));
720 
721   return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
722                    "\n",
723                    to_id, node_body, node_shape, NodeColorAttributes(color));
724 }
725 
TryGetFusionParameterConstant(const HloInstruction * instr)726 static const HloConstantInstruction* TryGetFusionParameterConstant(
727     const HloInstruction* instr) {
728   if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
729     return nullptr;
730   }
731   const HloInstruction* fusion = instr->parent()->FusionInstruction();
732   const HloInstruction* operand = fusion->operand(instr->parameter_number());
733   return DynCast<HloConstantInstruction>(operand);
734 }
735 
ShouldMergeIntoUsers(const HloInstruction * instr) const736 bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
737   // If a node:
738   //
739   //  - is a parameter of a fusion node which is bound to a constant,
740   //
741   // or
742   //
743   //  - is a tuple-shaped parameter, and
744   //  - is not a parameter to a fusion node, and
745   //  - has at least kMinUsersToOmit users shown, and
746   //  - all of the shown users are get-tuple-elements,
747   //
748   // then we omit it from the graph, merging it with its users.
749   //
750   // This helps us handle the common case where a while loop body has one big
751   // tuple-shaped parameter.
752   if (TryGetFusionParameterConstant(instr) != nullptr) {
753     return true;
754   }
755   const int kMinUsersToOmit = 3;
756   return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() &&
757          !instr->IsFused() &&
758          absl::c_count_if(instr->users(),
759                           [&](const HloInstruction* user) {
760                             return filter_.Show(user);
761                           }) > kMinUsersToOmit &&
762          absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
763            return !filter_.Show(user) ||
764                   user->opcode() == HloOpcode::kGetTupleElement;
765          });
766 }
767 
DumpInstruction(const HloInstruction * instr)768 string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
769   // We don't display constants or broadcasts of effective scalar constants
770   // within fusions as separate nodes; they're merged into their users.
771   if (instr->opcode() == HloOpcode::kConstant ||
772       IsFusedBroadcastOfConstantEffectiveScalar(instr)) {
773     return "";
774   }
775   // Skip this node if it's merged into its users.
776   if (ShouldMergeIntoUsers(instr)) {
777     return "";
778   }
779   // Omit the fusion node if its subcomputation is drawn, since the
780   // subcomputation will be drawn inline.
781   if (instr->opcode() == HloOpcode::kFusion &&
782       ShouldShowFusionSubcomputation(instr)) {
783     return "";
784   }
785 
786   VLOG(2) << "Adding node " << instr->name() << " as " << next_node_id_;
787   node_ids_[instr] = next_node_id_++;
788 
789   ColorScheme color = GetInstructionColor(instr);
790   string node_shape = GetInstructionNodeShape(instr);
791   string node_label = GetInstructionNodeLabel(instr);
792   string node_metadata = GetInstructionNodeMetadata(instr);
793   string node_backend_config = GetInstructionNodeBackendConfig(instr);
794   string extra_info = GetInstructionNodeExtraInfo(instr);
795   string inlined_constants = GetInstructionNodeInlinedOperands(instr);
796   string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
797   AddInstructionIncomingEdges(instr);
798 
799   if (!debug_options_.xla_hlo_graph_sharding_color()) {
800     // Override the node's styling if it should be (de-)emphasized.
801     if (filter_.Deemphasized(instr)) {
802       color = kDashedBorder;
803     }
804     if (filter_.Highlight(instr)) {
805       node_shape = "diamond";
806       color = kDarkRed;
807     }
808   }
809   // Build the text that will be displayed inside the node.
810   string node_body = node_label;
811   for (const string& s : {trivial_subcomputation, node_backend_config,
812                           extra_info, inlined_constants}) {
813     if (!s.empty()) {
814       StrAppend(&node_body, "<br/>", s);
815     }
816   }
817 
818   return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)"
819                    "\n",
820                    InstructionId(instr), node_body, node_shape, node_metadata,
821                    NodeColorAttributes(color));
822 }
823 
GetInstructionNodeInlinedOperands(const HloInstruction * instr)824 string HloDotDumper::GetInstructionNodeInlinedOperands(
825     const HloInstruction* instr) {
826   // The constant's shape is a parameter because, in the case of a broadcasted
827   // scalar constant, we want to show the broadcasted shape, not the constant's
828   // scalar shape.
829   auto stringify_constant = [](const HloConstantInstruction* constant,
830                                const Shape& shape) {
831     // If the shape has a dimension of size zero, print it as e.g.
832     // "{} (f32[42, 0, 10])".  The alternative, calling Literal::ToString(),
833     // enumerates all of its empty dimensions (e.g.  "{ { {}, {} }, ..."), which
834     // is just noise.
835     if (ShapeUtil::IsZeroElementArray(shape)) {
836       return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape()));
837     }
838 
839     // Print the literal value of constants with <= K elements.  Note that we
840     // use `constant->shape()` rather than `shape`, because if `constant` is a
841     // scalar that's broadcasted into `shape`, we want to print the constant.
842     optional<int64> elem_count;
843     if (shape.IsArray()) {
844       elem_count = ShapeUtil::ElementsIn(constant->shape());
845     }
846     // Allow HloDotDumper to print HloInstruction reconstructed from HloProto
847     // collected from profiling tools. Those constants may not have a valid
848     // literal.
849     if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
850       return StrFormat("%s %s", shape.ToString(),
851                        constant->literal().ToStringWithoutShape());
852     }
853 
854     // Otherwise, print e.g. "%constant.42 (s32[100])".
855     string constant_name;
856     if (absl::StartsWith(constant->name(), "constant")) {
857       constant_name = constant->name();
858     } else {
859       constant_name = StrCat("constant ", constant->name());
860     }
861     return StrFormat("%s %s", constant_name, ShapeUtil::HumanString(shape));
862   };
863 
864   std::vector<string> lines;
865   for (int64_t i = 0; i < instr->operand_count(); ++i) {
866     const HloInstruction* operand = instr->operand(i);
867     optional<string> operand_str;
868     if (const auto* constant_operand =
869             DynCast<HloConstantInstruction>(operand)) {
870       operand_str =
871           stringify_constant(constant_operand, constant_operand->shape());
872     } else if (IsFusedBroadcastOfConstantEffectiveScalar(operand)) {
873       operand_str = stringify_constant(
874           Cast<HloConstantInstruction>(operand->operand(0)), operand->shape());
875     } else if (ShouldMergeIntoUsers(operand)) {
876       // Special case: If the operand is a parameter to a fusion node and it
877       // always has a constant value, display it like a regular constant.
878       //
879       // For other parameters, use the parameter number rather than the proper
880       // name, because that's generally how people think of the node.
881       if (operand->opcode() == HloOpcode::kParameter) {
882         if (const HloConstantInstruction* constant =
883                 TryGetFusionParameterConstant(operand)) {
884           operand_str = stringify_constant(constant, constant->shape());
885         } else {
886           operand_str = StrFormat("Parameter %d", operand->parameter_number());
887         }
888       } else {
889         operand_str = operand->name();
890       }
891     }
892 
893     if (operand_str) {
894       if (instr->operand_count() > 1) {
895         lines.push_back(StrFormat("<b>operand %d</b> = %s", i, *operand_str));
896       } else {
897         lines.push_back(StrFormat("<b>operand</b> = %s", *operand_str));
898       }
899     }
900   }
901   return StrJoin(lines, "<br/>");
902 }
903 
GetInstructionColor(const HloInstruction * instr)904 ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
905   if (debug_options_.xla_hlo_graph_sharding_color()) {
906     if (!instr->has_sharding()) {
907       return kDashedBorder;
908     }
909     auto it = sharding_colors_.find(instr->sharding());
910     if (it != sharding_colors_.end()) {
911       return it->second;
912     }
913     ColorScheme color = static_cast<ColorScheme>(
914         kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
915     sharding_colors_.emplace(instr->sharding(), color);
916     return color;
917   }
918 
919   // Choose different weights of orange for small vs large parameters.  This
920   // distinction is often important, especially in fusion nodes.
921   auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange;
922 
923   // Special case: If this instruction has a parameter merged into it, paint it
924   // the same color as a parameter.  Unless the merged-in parameter is a
925   // parameter to a fusion node that is bound to a constant -- these aren't
926   // "real" parameters from the user's perspective.
927   if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) {
928         return operand->opcode() == HloOpcode::kParameter &&
929                ShouldMergeIntoUsers(operand) &&
930                TryGetFusionParameterConstant(operand) == nullptr;
931       })) {
932     return parameter_color;
933   }
934 
935   // Pick different colors or shapes for instructions which are particularly
936   // expensive (eg, dot) and those which are unusual in some way or unique
937   // (eg, parameter).
938   switch (instr->opcode()) {
939     case HloOpcode::kAbs:
940     case HloOpcode::kAdd:
941     case HloOpcode::kAnd:
942     case HloOpcode::kAtan2:
943     case HloOpcode::kBitcastConvert:
944     case HloOpcode::kCeil:
945     case HloOpcode::kClamp:
946     case HloOpcode::kClz:
947     case HloOpcode::kCompare:
948     case HloOpcode::kComplex:
949     case HloOpcode::kConvert:
950     case HloOpcode::kCos:
951     case HloOpcode::kDivide:
952     case HloOpcode::kExp:
953     case HloOpcode::kExpm1:
954     case HloOpcode::kFloor:
955     case HloOpcode::kImag:
956     case HloOpcode::kIota:
957     case HloOpcode::kIsFinite:
958     case HloOpcode::kLog:
959     case HloOpcode::kLog1p:
960     case HloOpcode::kMaximum:
961     case HloOpcode::kMinimum:
962     case HloOpcode::kMultiply:
963     case HloOpcode::kNegate:
964     case HloOpcode::kNot:
965     case HloOpcode::kPopulationCount:
966     case HloOpcode::kOr:
967     case HloOpcode::kXor:
968     case HloOpcode::kPower:
969     case HloOpcode::kReal:
970     case HloOpcode::kRemainder:
971     case HloOpcode::kRng:
972     case HloOpcode::kRngGetAndUpdateState:
973     case HloOpcode::kRngBitGenerator:
974     case HloOpcode::kRoundNearestAfz:
975     case HloOpcode::kRsqrt:
976     case HloOpcode::kSelect:
977     case HloOpcode::kShiftLeft:
978     case HloOpcode::kShiftRightArithmetic:
979     case HloOpcode::kShiftRightLogical:
980     case HloOpcode::kLogistic:
981     case HloOpcode::kSign:
982     case HloOpcode::kSin:
983     case HloOpcode::kSlice:
984     case HloOpcode::kSort:
985     case HloOpcode::kSqrt:
986     case HloOpcode::kCbrt:
987     case HloOpcode::kSubtract:
988     case HloOpcode::kTanh:
989       // De-emphasize scalar-shaped elementwise ops -- they're generally
990       // uninteresting.
991       if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
992         return kWhite;
993       }
994       return kYellow;
995     case HloOpcode::kBitcast:
996     case HloOpcode::kGetTupleElement:
997     case HloOpcode::kTrace:
998     case HloOpcode::kAfterAll:
999     case HloOpcode::kAddDependency:
1000     case HloOpcode::kTuple:
1001       return kWhite;
1002     case HloOpcode::kBroadcast:
1003       // De-emphasize nodes which broadcast a scalar within a fusion node --
1004       // these are essentially free.
1005       if (instr->IsFused() &&
1006           ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) {
1007         return kWhite;
1008       }
1009       return kGreen;
1010     case HloOpcode::kConcatenate:
1011     case HloOpcode::kDynamicSlice:
1012     case HloOpcode::kGather:
1013     case HloOpcode::kPad:
1014     case HloOpcode::kReshape:
1015     case HloOpcode::kDynamicReshape:
1016     case HloOpcode::kReverse:
1017     case HloOpcode::kTupleSelect:
1018     case HloOpcode::kTranspose:
1019       // De-emphasize scalar-shaped data movement ops and all data movement ops
1020       // inside fusion nodes, both of which are essentially free.
1021       if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) {
1022         return kWhite;
1023       }
1024       return kGreen;
1025     case HloOpcode::kDynamicUpdateSlice:
1026       // Unlike the data-movement ops above, dynamic-update-slice is not ~free
1027       // inside of fusion nodes, so we de-emphasize it only if it's
1028       // scalar-shaped.
1029       if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
1030         return kWhite;
1031       }
1032       return kGreen;
1033     case HloOpcode::kScatter:
1034       // Do not de-emphasize Scatter, since it involves significant work.
1035     case HloOpcode::kCopy:
1036     case HloOpcode::kCopyStart:
1037     case HloOpcode::kCopyDone:
1038       // Emphasize copy nodes, which are either physical transposes (and thus
1039       // significant), or copies of read-only buffers (and thus dead weight).
1040       return kGreen;
1041     case HloOpcode::kConvolution:
1042     case HloOpcode::kDot:
1043     case HloOpcode::kFft:
1044     case HloOpcode::kTriangularSolve:
1045     case HloOpcode::kCholesky:
1046       return kDarkBlue;
1047     case HloOpcode::kReducePrecision:
1048       return kRed;
1049     case HloOpcode::kParameter:
1050       return parameter_color;
1051     case HloOpcode::kBatchNormGrad:
1052     case HloOpcode::kBatchNormInference:
1053     case HloOpcode::kBatchNormTraining:
1054     case HloOpcode::kReduce:
1055     case HloOpcode::kReduceWindow:
1056     case HloOpcode::kSelectAndScatter:
1057       return kPurple;
1058     case HloOpcode::kDomain:
1059     case HloOpcode::kFusion:
1060     case HloOpcode::kMap:
1061     case HloOpcode::kGetDimensionSize:
1062     case HloOpcode::kSetDimensionSize:
1063       return kGray;
1064     case HloOpcode::kAllGather:
1065     case HloOpcode::kAllGatherStart:
1066     case HloOpcode::kAllGatherDone:
1067     case HloOpcode::kAllReduce:
1068     case HloOpcode::kReduceScatter:
1069     case HloOpcode::kAllReduceStart:
1070     case HloOpcode::kAllReduceDone:
1071     case HloOpcode::kAllToAll:
1072     case HloOpcode::kCollectivePermute:
1073     case HloOpcode::kCollectivePermuteStart:
1074     case HloOpcode::kCollectivePermuteDone:
1075     case HloOpcode::kInfeed:
1076     case HloOpcode::kOutfeed:
1077     case HloOpcode::kPartitionId:
1078     case HloOpcode::kRecv:
1079     case HloOpcode::kRecvDone:
1080     case HloOpcode::kSend:
1081     case HloOpcode::kSendDone:
1082     case HloOpcode::kReplicaId:
1083       return kBrown;
1084     case HloOpcode::kCall:
1085     case HloOpcode::kConditional:
1086     case HloOpcode::kCustomCall:
1087     case HloOpcode::kWhile:
1088       return kDarkGreen;
1089     case HloOpcode::kConstant:
1090       LOG(FATAL) << "Constants don't get their own nodes in the graph.";
1091   }
1092 }
1093 
GetInstructionNodeShape(const HloInstruction * instr)1094 string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) {
1095   // Give while loops a different shape so they're easier to pick out.
1096   switch (instr->opcode()) {
1097     case HloOpcode::kWhile:
1098       return "ellipse";
1099     default:
1100       return "rect";
1101   }
1102 }
1103 
GetInstructionNodeLabel(const HloInstruction * instr)1104 string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
1105   // If we have a parameter, put the param number in the name.
1106   if (instr->opcode() == HloOpcode::kParameter) {
1107     return StrFormat("<b>Parameter %d</b>", instr->parameter_number());
1108   }
1109 
1110   // The HLO instruction name contains usually the opcode, e.g. "%add.42" is
1111   // an add instruction.  In this case we render just the name.
1112   if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) {
1113     return StrFormat("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
1114   }
1115   string extended_opcode =
1116       StrCat(HloOpcodeString(instr->opcode()),
1117              instr->opcode() != HloOpcode::kFusion
1118                  ? ""
1119                  : StrCat(":", xla::ToString(instr->fusion_kind())));
1120   // If the name does not contain the opcode, render both.
1121   return StrFormat("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
1122                    HtmlLikeStringSanitize(instr->name()));
1123 }
1124 
GetInstructionNodeMetadata(const HloInstruction * instr)1125 string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
1126   std::vector<string> lines;
1127   if (!instr->metadata().op_name().empty()) {
1128     lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name()));
1129   }
1130   if (!instr->metadata().op_type().empty()) {
1131     lines.push_back(StrFormat(
1132         "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type())));
1133   }
1134   if (!instr->metadata().source_file().empty() &&
1135       instr->metadata().source_line() != 0) {
1136     lines.push_back(StrFormat("source: %s:%d", instr->metadata().source_file(),
1137                               instr->metadata().source_line()));
1138   }
1139 
1140   return StrJoin(lines, "\n");
1141 }
1142 
GetInstructionNodeBackendConfig(const HloInstruction * instr)1143 string HloDotDumper::GetInstructionNodeBackendConfig(
1144     const HloInstruction* instr) {
1145   if (!hlo_render_options_.show_backend_config ||
1146       instr->raw_backend_config_string().empty()) {
1147     return "";
1148   }
1149 
1150   return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\"");
1151 }
1152 
GetInstructionNodeExtraInfo(const HloInstruction * instr)1153 string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
1154   std::vector<string> lines;
1155 
1156   // Get the instruction's extra attributes excluding the names of its
1157   // subcomputations, since those are drawn explicitly in the graph.
1158   for (const auto& line : instr->ExtraAttributesToString(
1159            HloPrintOptions().set_print_subcomputation_mode(
1160                HloPrintOptions::PrintSubcomputationMode::kOff))) {
1161     // Some instructions have giant replica group fields, so truncate the
1162     // replica group line length to 128.
1163     constexpr int kMaxReplicaGroupLen = 128;
1164     if (absl::StartsWith(line, "replica_groups=") &&
1165         line.length() > kMaxReplicaGroupLen) {
1166       lines.push_back(HtmlLikeStringSanitize(
1167           StrCat(line.substr(0, kMaxReplicaGroupLen - 3), "...")));
1168     } else {
1169       lines.push_back(HtmlLikeStringSanitize(line));
1170     }
1171   }
1172 
1173   // Show the shape and layout of the instruction, unless it's an inlined fusion
1174   // node -- there the shape and layout is present in the output node.
1175   if (instr->opcode() != HloOpcode::kFusion ||
1176       !ShouldShowFusionSubcomputation(instr)) {
1177     // Show layout of instructions with more than one dimension.  Don't show
1178     // layout on tuples or tensors with just one dimension (which only have one
1179     // possible layout) to avoid visual noise.
1180     bool shape_is_multidim = false;
1181     ShapeUtil::ForEachSubshape(instr->shape(),
1182                                [&](const Shape& s, const ShapeIndex&) {
1183                                  shape_is_multidim |= s.dimensions_size() > 1;
1184                                });
1185     string instr_shape;
1186     if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) {
1187       instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape());
1188     } else {
1189       instr_shape = ShapeUtil::HumanString(instr->shape());
1190     }
1191 
1192     // Some instructions have giant tuples as their shapes, so truncate the
1193     // HLO's shape to kMaxShapeLen characters.
1194     constexpr int kMaxShapeLen = 64;
1195     if (instr_shape.length() > kMaxShapeLen) {
1196       instr_shape = StrCat(
1197           absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "...");
1198     }
1199     lines.push_back(HtmlLikeStringSanitize(instr_shape));
1200   }
1201   if (debug_options_.xla_hlo_graph_addresses()) {
1202     lines.push_back(StrFormat("[%p]", instr));
1203   }
1204   if (profile_ != nullptr) {
1205     double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr);
1206     double total_cycles_executed =
1207         profile_->total_cycles_executed(*instr->parent());
1208     if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
1209       lines.push_back(
1210           StrFormat("%% of cycles executed=%.2f",
1211                     100 * hlo_cycles_executed / total_cycles_executed));
1212     }
1213   }
1214   return StrJoin(lines, "<br/>");
1215 }
1216 
AddInstructionIncomingEdges(const HloInstruction * instr)1217 void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
1218   auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
1219                       int64_t operand_num, bool control_edge = false) {
1220     from = GetNodeForEdge(from);
1221 
1222     if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
1223         IsFusedBroadcastOfConstantEffectiveScalar(from) ||
1224         ShouldMergeIntoUsers(from)) {
1225       return;
1226     }
1227     VLOG(2) << "Adding edge from " << from->name() << " to " << to->name()
1228             << " as " << next_edge_id_;
1229     edge_ids_.insert({{from, to}, next_edge_id_++});
1230 
1231     string edge_label;
1232     if (instr->operand_count() > 1 && !control_edge) {
1233       edge_label =
1234           StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num);
1235     } else if (control_edge) {
1236       edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"";
1237     }
1238 
1239     // We print "small" arrays using a hollow arrowhead and "large" arrays using
1240     // a filled arrowhead.
1241     constexpr char kEdgeFmt[] =
1242         R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
1243     edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to),
1244                                (IsSmall(from) ? "empty" : "normal"),
1245                                from->name(), to->name(), edge_label));
1246   };
1247 
1248   // Add edges from instr's operands to instr.  Parameters within fusion
1249   // expressions are handled specially -- we draw an edge from the corresponding
1250   // operand on the fusion node itself to the parameter.
1251   if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
1252     // Only add the edge if this is not the outermost computation; otherwise it
1253     // will lead from a node we're not drawing.
1254     if (instr->parent() != computation_) {
1255       const HloInstruction* fusion = instr->parent()->FusionInstruction();
1256       add_edge(fusion->operand(instr->parameter_number()), instr,
1257                /*operand_num=*/0);
1258     }
1259   } else {
1260     for (int64_t i = 0; i < instr->operand_count(); ++i) {
1261       add_edge(instr->operand(i), instr, i);
1262     }
1263     for (const HloInstruction* pred : instr->control_predecessors()) {
1264       add_edge(pred, instr, /*operand_num=*/0, /*control_edge=*/true);
1265     }
1266   }
1267 }
1268 
GetInstructionTrivialComputationStr(const HloInstruction * instr)1269 string HloDotDumper::GetInstructionTrivialComputationStr(
1270     const HloInstruction* instr) {
1271   // called_computations() on a fusion node "inherits" any called computations
1272   // of the fused root, which isn't what we want.  Just ignore fusion nodes
1273   // here; they're handled separately.
1274   if (instr->opcode() == HloOpcode::kFusion) {
1275     return "";
1276   }
1277 
1278   std::vector<string> lines;
1279   for (int64_t i = 0; i < instr->called_computations().size(); ++i) {
1280     optional<string> computation_type =
1281         MatchTrivialComputation(instr->called_computations()[i]);
1282     if (!computation_type) {
1283       continue;
1284     }
1285     if (instr->called_computations().size() == 1) {
1286       lines.push_back(StrFormat("Subcomputation: <b>%s</b>",
1287                                 HtmlLikeStringSanitize(*computation_type)));
1288     } else {
1289       lines.push_back(StrFormat("Subcomputation %d: <b>%s</b>", i,
1290                                 HtmlLikeStringSanitize(*computation_type)));
1291     }
1292   }
1293   return StrJoin(lines, "<br/>");
1294 }
1295 
// Returns the instruction an edge touching `instr` should attach to. A fusion
// whose fused computation is drawn inline has no node of its own, so edges go
// to its fused expression root instead. This repeats while the result is
// itself such a fusion.
const HloInstruction* HloDotDumper::GetNodeForEdge(
    const HloInstruction* instr) {
  const HloInstruction* node = instr;
  for (;;) {
    const bool is_inlined_fusion = node->opcode() == HloOpcode::kFusion &&
                                   ShouldShowFusionSubcomputation(node);
    if (!is_inlined_fusion) {
      return node;
    }
    node = node->fused_expression_root();
  }
}
1304 
1305 // Gets a NodeFilter that includes roughly all instructions whose distance from
1306 // root is <= radius.
MakeNodeRadiusAroundFilter(const HloInstruction * root,int64_t radius,const absl::flat_hash_set<const HloInstruction * > & boundary)1307 NodeFilter MakeNodeRadiusAroundFilter(
1308     const HloInstruction* root, int64_t radius,
1309     const absl::flat_hash_set<const HloInstruction*>& boundary) {
1310   // First, find the neighborhood of nodes with distance from root <= radius.
1311   // These nodes are our initial set of "normal" nodes.
1312   absl::flat_hash_map<const HloInstruction*, NodeFilterResult> nodes;
1313   std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
1314   worklist.push_back({root, 0});
1315   while (!worklist.empty()) {
1316     const HloInstruction* instr;
1317     int64_t depth;
1318     std::tie(instr, depth) = worklist.front();
1319     worklist.pop_front();
1320 
1321     nodes[instr] = kNormalNode;
1322     if (depth == radius) {
1323       continue;
1324     }
1325     if (boundary.contains(instr)) {
1326       continue;
1327     }
1328 
1329     // Traverse into instr's operands.
1330     //
1331     // Don't traverse into tuples' operands unless the tuple is the root.
1332     // Usually a tuple is the bottommost node in the graph, and so its operands
1333     // are not interesting to the graph at hand.
1334     if (instr == root || instr->opcode() != HloOpcode::kTuple) {
1335       for (const HloInstruction* operand : instr->operands()) {
1336         if (!nodes.contains(operand)) {
1337           worklist.push_back({operand, depth + 1});
1338         }
1339       }
1340     }
1341 
1342     // Traverse into instr's nested computations.
1343     for (const HloComputation* computation : instr->called_computations()) {
1344       worklist.push_back({computation->root_instruction(), depth + 1});
1345     }
1346 
1347     // Traverse into instr's users, unless:
1348     //
1349     //  - there are a ton of them, in which case they're probably not
1350     //    interesting (and anyway, rendering them all would make the graph
1351     //    unreadable), or
1352     //  - instr is a constant, in which case its users are probably not
1353     //    interesting.
1354     if (instr->opcode() == HloOpcode::kConstant) {
1355       continue;
1356     }
1357     constexpr int kMaxUsersToRender = 16;
1358     if (instr->user_count() > kMaxUsersToRender) {
1359       // If we're going to skip this node's users, style it as such.
1360       nodes[instr] = kSomeUsersOmitted;
1361       continue;
1362     }
1363     for (const HloInstruction* user : instr->users()) {
1364       if (!nodes.contains(user)) {
1365         worklist.push_back({user, depth + 1});
1366       }
1367     }
1368   }
1369 
1370   auto is_displayed = [&](const HloInstruction* instr) {
1371     // Constants are displayed inline with their users; they're never omitted.
1372     // Nodes in subcomputations are always shown.
1373     return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant ||
1374            instr->parent() != root->parent();
1375   };
1376 
1377   // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we
1378   // know which nodes will be included in the graph.
1379   for (auto& kv : nodes) {
1380     const HloInstruction* instr = kv.first;
1381     NodeFilterResult& filter_result = kv.second;
1382     const auto& operands = instr->operands();
1383 
1384     if (absl::c_any_of(operands, is_displayed) &&
1385         !absl::c_all_of(operands, is_displayed)) {
1386       // Mark nodes with some operands omitted appropriately.
1387       filter_result = kSomeOperandsOmitted;
1388     } else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) {
1389       // Mark nodes with *all* operands omitted appropriately.
1390       filter_result = kOmitNodeOperands;
1391     }
1392 
1393     // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
1394     // users made it into the graph.
1395     if (filter_result == kSomeUsersOmitted &&
1396         absl::c_all_of(instr->users(), is_displayed)) {
1397       filter_result = kNormalNode;
1398     }
1399   }
1400 
1401   // Highlight the root node.
1402   nodes[root] = kHighlightNode;
1403 
1404   return NodeFilter([=](const HloInstruction* instr) {
1405     auto it = nodes.find(instr);
1406     if (it != nodes.end()) {
1407       return it->second;
1408     }
1409     // Show all nodes in subcomputations.
1410     if (instr->parent() != root->parent()) {
1411       return kNormalNode;
1412     }
1413     return kHideNode;
1414   });
1415 }
1416 
1417 // Gets a node filter that includes nodes on all paths from `from` to `to`.  If
1418 // the all-paths set contains more than max_nodes elements, includes the nodes
1419 // on the shortest paths and sets hit_limit to true.
MakeNodeFromToFilter(const HloInstruction * from,const HloInstruction * to,int64_t max_nodes,bool * hit_limit)1420 NodeFilter MakeNodeFromToFilter(const HloInstruction* from,
1421                                 const HloInstruction* to, int64_t max_nodes,
1422                                 bool* hit_limit) {
1423   *hit_limit = false;
1424 
1425   // Elements in the queue are paths through the graph.
1426   std::deque<std::vector<const HloInstruction*>> queue;
1427   queue.push_front({from});
1428 
1429   // Compute the set of nodes we want to show using a slightly-modified
1430   // Djikstra's algorithm.  The only real difference is, rather than stopping
1431   // when we find a (shortest) path, we continue until we've found max_nodes
1432   // nodes on some path.
1433   std::unordered_set<const HloInstruction*> visited;
1434   std::unordered_set<const HloInstruction*> to_display = {from, to};
1435   while (!queue.empty() && to_display.size() < max_nodes) {
1436     std::vector<const HloInstruction*> path = std::move(queue.front());
1437     queue.pop_front();
1438     if (!visited.insert(path.back()).second) {
1439       continue;
1440     }
1441 
1442     for (const auto* user : path.back()->users()) {
1443       if (user == to) {
1444         auto it = path.begin();
1445         for (; it != path.end() && to_display.size() < max_nodes; ++it) {
1446           to_display.insert(*it);
1447         }
1448         if (it != path.end()) {
1449           *hit_limit = true;
1450         }
1451       } else if (!visited.count(user)) {
1452         auto new_path = path;
1453         new_path.push_back(user);
1454         queue.push_back(std::move(new_path));
1455       }
1456     }
1457   }
1458 
1459   return NodeFilter([=](const HloInstruction* instr) {
1460     if (instr == from || instr == to) {
1461       return kHighlightNode;
1462     }
1463     return to_display.count(instr) ? kNormalNode : kHideNode;
1464   });
1465 }
1466 
WrapDotInHtml(absl::string_view dot)1467 string WrapDotInHtml(absl::string_view dot) {
1468   static const char html_prefix[] = R"html(
1469 <!DOCTYPE html>
1470 <html>
1471 <head>
1472   <meta charset="utf-8">
1473   <style type="text/css">
1474     body {
1475       height: 100vh;
1476       margin: 0;
1477     }
1478   </style>
1479 </head>
1480 <body>
1481   <!-- Integrity hash is generated by https://www.srihash.org/ -->
1482   <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/viz.js"
1483      integrity="sha384-aD1MJYb0WKIUT+CtwJp5LTuV3U4pLAS6B/nUxL7ECimC2pN9N8vjlMr/yQCAkzxE"
1484      crossorigin="anonymous"></script>
1485   <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/full.render.js"
1486      integrity="sha384-bAixY275aIpCj6Te19y0MILZ4V+VEC8CVFujFEH+Lf7W+4XYYeYLwW5IBI6yQmMT"
1487      crossorigin="anonymous"></script>
1488   <script src="https://cdn.jsdelivr.net/npm/svg-pan-zoom@3.6.0/dist/svg-pan-zoom.min.js"
1489      integrity="sha384-3008WpYB2pOBvE7lwkrKf+qTmbTPGGPYxA9C1YVhvbPukns4ZFj7E98QPLkNW9dS"
1490      crossorigin="anonymous"></script>
1491   <div id="container" style="height:95vh; border:1px solid black; "></div>
1492   <script>
1493     var data = `
1494 )html";
1495 
1496   static const char html_suffix[] = R"html(
1497 `;
1498     var cssregex = new RegExp('stylesheet=<([^]*)\n>\n', 'gm');
1499     var results = cssregex.exec(data)
1500     // graphviz has problem dealing with large stylesheets.
1501     // https://github.com/tensorflow/tensorflow/issues/17220#issuecomment-369228492
1502     // In order to avoid the problem, remove the stylesheet from the dot and
1503     // insert it directly info the rendered SVG.
1504     var dot_data = data;
1505     var css_data = ''
1506     if (results !== null) {
1507         css_data = results[1].replace(/\s*data:.*\s*,/,''); // Strip content-type field.
1508         // CSS inside DOT is URL-escaped, so we must unescape it
1509         // before we can insert it into SVG.
1510         css_data = unescape(css_data);
1511         dot_data = data.replace(cssregex, ''); // Remove the stylesheet
1512     }
1513 
1514     var render_start = performance.now()
1515     function add_controls(svg) {
1516         var htmlblob = new Blob([document.documentElement.innerHTML],
1517                                 {type: 'text/html'});
1518         var savehtml = document.createElement('a');
1519         savehtml.setAttribute('href', URL.createObjectURL(htmlblob));
1520         savehtml.setAttribute('download', 'graph.html');
1521         savehtml.innerHTML = " [Save HTML+SVG] ";
1522         document.body.append(savehtml);
1523         var svgblob = new Blob([svg.outerHTML], {type: 'image/svg'});
1524         var savesvg = document.createElement('a');
1525         savesvg.setAttribute('href', URL.createObjectURL(svgblob));
1526         savesvg.setAttribute('download', 'graph.svg');
1527         savesvg.innerHTML = " [Save SVG] ";
1528         document.body.append(savesvg);
1529         var dotblob =  new Blob([data], {type: 'text/dot'});
1530         var savedot = document.createElement('a');
1531         savedot.setAttribute('href', URL.createObjectURL(dotblob));
1532         savedot.setAttribute('download', 'graph.dot');
1533         savedot.innerHTML = " [Save DOT] ";
1534         document.body.append(savedot);
1535         // Will get called after embed element was loaded
1536         var panzoom = svgPanZoom(svg, {
1537             zoomEnabled: true,
1538             controlIconsEnabled: true,
1539         });
1540         document.getElementsByTagName("BODY")[0].onresize = function() {
1541             panzoom.resize();
1542             panzoom.fit();
1543             panzoom.center();
1544         };
1545         var render_end = performance.now();
1546         var render_note = document.createElement('div')
1547         render_note.innerHTML = 'Rendering took '
1548                                 + (render_end - render_start).toFixed(2) + "ms."
1549         document.body.append(render_note);
1550     }
1551     var svg = document.getElementById('graph')
1552     if (svg == null) {
1553         // Need to render SVG first.
1554         var viz = new Viz();
1555         viz.renderSVGElement(dot_data)
1556             .then(function(svg){
1557                 var container = document.getElementById('container')
1558                 var style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
1559                 var node = document.createTextNode(css_data);
1560                 style.appendChild(node);
1561                 svg.setAttribute('width', '100%');
1562                 svg.setAttribute('height', '100%');
1563                 svg.setAttribute('id', 'graph');
1564                 svg.appendChild(style);
1565                 container.appendChild(svg);
1566                 add_controls(svg);
1567             })
1568     } else {
1569         // HTML already has rendered SVG embedded, so we just need to add
1570         // controls.
1571         add_controls(svg);
1572     }
1573   </script>
1574 </body>
1575 </html>
1576 )html";
1577 
1578   return absl::StrCat(html_prefix, dot, html_suffix);
1579 }
1580 
1581 tensorflow::mutex url_renderer_mu(tensorflow::LINKER_INITIALIZED);
1582 std::function<StatusOr<string>(absl::string_view)>* url_renderer
1583     TF_GUARDED_BY(url_renderer_mu) = nullptr;
1584 
1585 // Storage for fusion visualization: (module_id, computation_id) -> sequence of
1586 // dot dumps.
1587 tensorflow::mutex fusion_visualizer_state_mu(tensorflow::LINKER_INITIALIZED);
1588 static auto& fusion_visualizer_state TF_GUARDED_BY(fusion_visualizer_state_mu) =
1589     *new absl::flat_hash_map<std::pair<int64, int64>,
1590                              std::vector<std::string>>();
1591 
1592 // Generates a key to the fusion visualizer state mapping.
1593 std::pair<int, int> FusionVisualizerStateKey(
1594     const HloComputation& computation) {
1595   return std::make_pair(computation.parent()->unique_id(),
1596                         computation.unique_id());
1597 }
1598 
1599 // Generates a fusion explorer for the given computation using the data in
1600 // fusion_visualizer_state and the URL renderer. Precondition: url_renderer !=
1601 // nullptr.
1602 StatusOr<std::string> WrapFusionExplorer(const HloComputation& computation)
1603     TF_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
1604   CHECK(url_renderer != nullptr);
1605   tensorflow::mutex_lock lock(fusion_visualizer_state_mu);
1606   const std::vector<std::string>& dot_graphs =
1607       fusion_visualizer_state[FusionVisualizerStateKey(computation)];
1608   std::vector<std::string> dot_urls;
1609   dot_urls.reserve(dot_graphs.size());
1610   for (const std::string& dot : dot_graphs) {
1611     TF_ASSIGN_OR_RETURN(std::string url, (*url_renderer)(dot));
1612     dot_urls.push_back(url);
1613   }
1614 
1615   return absl::StrReplaceAll(
1616       R"(
1617   <!doctype html>
1618   <style>
1619     html, body {height: 100%; text-align: center;}
1620     #display {height: 80%; width: 80%;}
1621   </style>
1622   <title>Fusion Explorer: $TITLE</title>
1623   <iframe id='display' width=80% height=80%></iframe>
1624   <p id='description'></p>
1625   <p>
1626     <a id='prev' href='#'>Prev Step</a>
1627     <a id='next' href='#'>Next Step</a>
1628   </p>
1629   <p>
1630     Use j/k for keyboard navigation.
1631   </p>
1632   <script>
1633   var currId = -1;
1634   var urls = [$URLS];
1635 
1636   var setIframe = function() {
1637     document.getElementById('display').src = urls[currId];
1638   };
1639 
1640   var update = function(delta)  {
1641     currId = (currId + delta + urls.length) % urls.length;
1642     document.getElementById('description').innerHTML = "Frame #"
1643       + (currId + 1) + " / " + urls.length;
1644     setIframe();
1645   };
1646 
1647   document.getElementById('prev').onclick = function() {
1648     update(-1);
1649     return false;
1650   };
1651 
1652   document.getElementById('next').onclick = function() {
1653     update(1);
1654     return false;
1655   };
1656 
1657   window.addEventListener("keydown", function (event) {
1658     if (event.defaultPrevented) {
1659       return;
1660     }
1661     if (event.key == "j") {
1662       update(1);
1663     } else if (event.key == "k") {
1664       update(-1);
1665     } else {
1666       return;
1667     }
1668     event.preventDefault();
1669   }, true);
1670 
1671   document.addEventListener("DOMContentLoaded", function() {
1672     update(1);
1673   });
1674 
1675   </script>
1676   )",
1677       {{"$URLS", absl::StrJoin(dot_urls, ", ",
1678                                [&](std::string* out, const std::string& url) {
1679                                  absl::StrAppend(out, "\"", url, "\"");
1680                                })},
1681        {"$TITLE",
1682         absl::StrCat(computation.parent()->name(), "_", computation.name())}});
1683 }
1684 
1685 // Precondition: (url_renderer != nullptr || (format != kUrl
1686 //   && format != kFusionVisualization)).
1687 //
1688 // (We specify this as a precondition rather than checking it in here and
1689 // returning an error because we want to fail quickly when there's no URL
1690 // renderer available, and this function runs only after we've done all the work
1691 // of producing dot for the graph.)
WrapDotInFormat(const HloComputation & computation,absl::string_view dot,RenderedGraphFormat format)1692 StatusOr<string> WrapDotInFormat(const HloComputation& computation,
1693                                  absl::string_view dot,
1694                                  RenderedGraphFormat format)
1695     TF_EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
1696   switch (format) {
1697     case RenderedGraphFormat::kUrl:
1698       CHECK(url_renderer != nullptr)
1699           << "Should have checked url_renderer != null before calling.";
1700       return (*url_renderer)(dot);
1701     case RenderedGraphFormat::kHtml:
1702       return WrapDotInHtml(dot);
1703     case RenderedGraphFormat::kDot:
1704       return string(dot);
1705     case RenderedGraphFormat::kFusionVisualization:
1706       return WrapFusionExplorer(computation);
1707   }
1708 }
1709 
1710 }  // namespace
1711 
RegisterGraphToURLRenderer(std::function<StatusOr<string> (absl::string_view)> renderer)1712 void RegisterGraphToURLRenderer(
1713     std::function<StatusOr<string>(absl::string_view)> renderer) {
1714   tensorflow::mutex_lock lock(url_renderer_mu);
1715   if (url_renderer != nullptr) {
1716     LOG(WARNING) << "Multiple calls to RegisterGraphToURLRenderer.  Last call "
1717                     "wins, but because order of initialization in C++ is "
1718                     "nondeterministic, this may not be what you want.";
1719   }
1720   delete url_renderer;
1721   url_renderer = new std::function<StatusOr<string>(absl::string_view)>(
1722       std::move(renderer));
1723 }
1724 
RegisterFusionState(const HloComputation & computation,absl::string_view label)1725 Status RegisterFusionState(const HloComputation& computation,
1726                            absl::string_view label) {
1727   tensorflow::mutex_lock lock(fusion_visualizer_state_mu);
1728   TF_ASSIGN_OR_RETURN(
1729       string dot_graph,
1730       RenderGraph(computation,
1731                   absl::StrCat(computation.parent()->name(), ", ",
1732                                computation.name(), ", ", label),
1733                   /*debug_options=*/{}, xla::RenderedGraphFormat::kDot,
1734                   /*hlo_execution_profile=*/nullptr,
1735                   /*hlo_render_options=*/{}));
1736   std::vector<std::string>& fusion_states =
1737       fusion_visualizer_state[FusionVisualizerStateKey(computation)];
1738   if (fusion_states.empty() || fusion_states.back() != dot_graph) {
1739     fusion_states.push_back(dot_graph);
1740   }
1741   return Status::OK();
1742 }
1743 
RenderGraph(const HloComputation & computation,absl::string_view label,const DebugOptions & debug_options,RenderedGraphFormat format,const HloExecutionProfile * hlo_execution_profile,HloRenderOptions hlo_render_options)1744 StatusOr<string> RenderGraph(const HloComputation& computation,
1745                              absl::string_view label,
1746                              const DebugOptions& debug_options,
1747                              RenderedGraphFormat format,
1748                              const HloExecutionProfile* hlo_execution_profile,
1749                              HloRenderOptions hlo_render_options) {
1750   tensorflow::mutex_lock lock(url_renderer_mu);
1751   if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1752     return Unavailable("Can't render as URL; no URL renderer was registered.");
1753   }
1754 
1755   string rendered_dot =
1756       HloDotDumper(&computation, label, debug_options, hlo_render_options,
1757                    hlo_execution_profile, NodeFilter())
1758           .Dump();
1759   return WrapDotInFormat(computation, rendered_dot, format);
1760 }
1761 
RenderNeighborhoodAround(const HloInstruction & node,int radius,RenderedGraphFormat format,HloRenderOptions hlo_render_options,const absl::flat_hash_set<const HloInstruction * > & boundary)1762 StatusOr<string> RenderNeighborhoodAround(
1763     const HloInstruction& node, int radius, RenderedGraphFormat format,
1764     HloRenderOptions hlo_render_options,
1765     const absl::flat_hash_set<const HloInstruction*>& boundary) {
1766   tensorflow::mutex_lock lock(url_renderer_mu);
1767   if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1768     return FailedPrecondition(
1769         "Can't render as URL; no URL renderer was registered.");
1770   }
1771 
1772   string label =
1773       StrCat("Neighborhood of ", radius, " nodes around ", node.name());
1774   string rendered_dot =
1775       HloDotDumper(node.parent(), label,
1776                    node.GetModule()->config().debug_options(),
1777                    hlo_render_options, /*profile=*/nullptr,
1778                    MakeNodeRadiusAroundFilter(&node, radius, boundary))
1779           .Dump();
1780   return WrapDotInFormat(*node.parent(), rendered_dot, format);
1781 }
1782 
RenderAllPathsFromTo(const HloInstruction & from,const HloInstruction & to,int64_t max_nodes,RenderedGraphFormat format,HloRenderOptions hlo_render_options)1783 StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
1784                                       const HloInstruction& to,
1785                                       int64_t max_nodes,
1786                                       RenderedGraphFormat format,
1787                                       HloRenderOptions hlo_render_options) {
1788   tensorflow::mutex_lock lock(url_renderer_mu);
1789   if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1790     return FailedPrecondition(
1791         "Can't render as URL; no URL renderer was registered.");
1792   }
1793 
1794   CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!";
1795   auto debug_options = from.GetModule()->config().debug_options();
1796 
1797   bool hit_limit = false;
1798   NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit);
1799   string label;
1800   if (!hit_limit) {
1801     label = StrCat("All paths from ", from.name(), " to ", to.name());
1802   } else {
1803     label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(),
1804                    " to ", to.name(),
1805                    "<br/><br/>***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN "
1806                    "NODES***<br/><br/>");
1807   }
1808   string rendered_dot =
1809       HloDotDumper(from.parent(), label, debug_options, hlo_render_options,
1810                    /*profile=*/nullptr, filter)
1811           .Dump();
1812   return WrapDotInFormat(*from.parent(), rendered_dot, format);
1813 }
1814 
1815 }  // namespace xla
1816