• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
17 
18 #include <unistd.h>
19 
20 #include <algorithm>
21 #include <atomic>
22 #include <deque>
23 #include <map>
24 #include <memory>
25 #include <queue>
26 #include <string>
27 #include <tuple>
28 #include <vector>
29 
30 #include "absl/container/flat_hash_map.h"
31 #include "absl/strings/match.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_format.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/str_replace.h"
36 #include "absl/types/optional.h"
37 #include "tensorflow/compiler/xla/layout_util.h"
38 #include "tensorflow/compiler/xla/literal.h"
39 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
40 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
41 #include "tensorflow/compiler/xla/service/hlo_module.h"
42 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
43 #include "tensorflow/compiler/xla/shape_util.h"
44 #include "tensorflow/compiler/xla/types.h"
45 #include "tensorflow/compiler/xla/util.h"
46 #include "tensorflow/compiler/xla/window_util.h"
47 #include "tensorflow/core/lib/core/status.h"
48 #include "tensorflow/core/lib/gtl/map_util.h"
49 #include "tensorflow/core/lib/io/path.h"
50 #include "tensorflow/core/lib/strings/numbers.h"
51 #include "tensorflow/core/platform/env.h"
52 #include "tensorflow/core/platform/mutex.h"
53 #include "tensorflow/core/platform/protobuf.h"
54 #include "tensorflow/core/platform/regexp.h"
55 
56 namespace xla {
57 namespace {
58 
59 using absl::nullopt;
60 using absl::optional;
61 using absl::StrAppend;
62 using absl::StrCat;
63 using absl::StrFormat;
64 using absl::StrJoin;
65 
// Classification returned by a NodeFilter for a single HloInstruction: tells
// the graph dumper whether to draw the node normally, hide it, highlight it,
// or de-emphasize it in one of several ways.
enum NodeFilterResult {
  // Draw the node with its default styling.
  kNormalNode,
  // Leave the node out of the graph.
  kHideNode,
  // Make the node easy to find in the final graph.
  kHighlightNode,
  // "Gray out" the node to indicate that some of its operands have been
  // omitted.
  kSomeOperandsOmitted,
  // Style the node the same as kSomeOperandsOmitted, but also don't connect it
  // to its operands, even if they're present in the graph.
  kOmitNodeOperands,
  // Same style as kSomeOperandsOmitted, but used to indicate that some of the
  // node's *users* have been omitted.
  kSomeUsersOmitted,
};
83 
84 // NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
85 // It lets callers tell the graph-drawing routines which nodes they want to be
86 // shown, hidden, or highlighted.
87 class NodeFilter {
88  public:
__anon627048d20202(const HloInstruction*) 89   NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}
90 
NodeFilter(std::function<NodeFilterResult (const HloInstruction * instr)> filter)91   explicit NodeFilter(
92       std::function<NodeFilterResult(const HloInstruction* instr)> filter)
93       : filter_(std::move(filter)) {}
94 
Show(const HloInstruction * instr) const95   bool Show(const HloInstruction* instr) const {
96     return filter_(instr) != kHideNode;
97   }
Highlight(const HloInstruction * instr) const98   bool Highlight(const HloInstruction* instr) const {
99     return filter_(instr) == kHighlightNode;
100   }
OmitOperands(const HloInstruction * instr) const101   bool OmitOperands(const HloInstruction* instr) const {
102     return filter_(instr) == kOmitNodeOperands;
103   }
SomeOrAllOperandsOmitted(const HloInstruction * instr) const104   bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const {
105     auto result = filter_(instr);
106     return result == kOmitNodeOperands || result == kSomeOperandsOmitted;
107   }
Deemphasized(const HloInstruction * instr) const108   bool Deemphasized(const HloInstruction* instr) const {
109     auto result = filter_(instr);
110     return result == kOmitNodeOperands || result == kSomeOperandsOmitted ||
111            result == kSomeUsersOmitted;
112   }
113 
114  private:
115   std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
116 };
117 
118 // We arbitrarily set this as the boundary between "large" and "small"
119 // instructions.
IsSmall(const HloInstruction * instr)120 bool IsSmall(const HloInstruction* instr) {
121   if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE_TYPE) ||
122       ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
123     return true;
124   }
125   return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
126 }
127 
// Color schemes a graph node can be drawn with.  NodeColorsForScheme maps each
// one to concrete graphviz style/fill/stroke/font values.
enum ColorScheme {
  kBlue,
  kBrown,
  kDarkBlue,
  kDarkGreen,
  kDarkOrange,
  kDarkRed,
  kGray,
  kGreen,
  kOrange,
  kPurple,
  kRed,
  kWhite,
  kYellow,

  // Dashed border with gray text on a white background, for nodes that should
  // read as "unimportant".
  kDashedBorder,
};
148 
// The graphviz attribute values that together make up one ColorScheme.  Field
// order matters: NodeColorsForScheme builds these with aggregate
// initialization.
struct NodeColors {
  const char* style;         // graphviz `style`, e.g. "filled"
  const char* fill_color;    // graphviz `fillcolor`
  const char* stroke_color;  // graphviz `color` (the border)
  const char* font_color;    // graphviz `fontcolor`
};
156 
NodeColorsForScheme(ColorScheme color)157 NodeColors NodeColorsForScheme(ColorScheme color) {
158   switch (color) {
159     case kBlue:
160       return NodeColors{"filled", "#bbdefb", "#8aacc8", "black"};
161     case kBrown:
162       return NodeColors{"filled", "#bcaaa4", "#8c7b75", "black"};
163     case kDarkBlue:
164       return NodeColors{"filled", "#1565c0", "#003c8f", "white"};
165     case kDarkGreen:
166       return NodeColors{"filled", "#2e7d32", "#005005", "white"};
167     case kDarkOrange:
168       // This is more of a "medium" orange, made to look close to kOrange;
169       // there's probably room for a darker weight if desired.
170       return NodeColors{"filled", "#ffb74d", "#c88719", "black"};
171     case kDarkRed:
172       return NodeColors{"filled", "#b71c1c", "#7f0000", "white"};
173     case kGray:
174       return NodeColors{"filled", "#cfd8dc", "#9ea7aa", "black"};
175     case kGreen:
176       return NodeColors{"filled", "#c8e6c9", "#97b498", "black"};
177     case kOrange:
178       return NodeColors{"filled", "#ffe0b2", "#cbae82", "black"};
179     case kPurple:
180       return NodeColors{"filled", "#e1bee7", "#af8eb5", "black"};
181     case kRed:
182       return NodeColors{"filled", "#ffcdd2", "#cb9ca1", "black"};
183     case kWhite:
184       return NodeColors{"filled", "white", "black", "black"};
185     case kYellow:
186       return NodeColors{"filled", "#fff9c4", "#cbc693", "black"};
187     case kDashedBorder:
188       // "filled,dashed" looks the same as "dashed", since we have a white
189       // background.  But we use "filled,dashed" so that when you hover over
190       // any part of the node (not just the text inside the node), our css
191       // :hover rule is triggered.
192       return NodeColors{"filled,dashed", "white", "#757575", "#757575"};
193   }
194 }
195 
196 // Given a ColorScheme, returns an attribute string for a node of that color.
197 // Sets the node's style and fill/stroke/text colors.
198 //
199 // Colors are from https://material.io/color.
NodeColorAttributes(ColorScheme color)200 string NodeColorAttributes(ColorScheme color) {
201   NodeColors node_colors = NodeColorsForScheme(color);
202 
203   return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")",
204                    node_colors.style, node_colors.font_color,
205                    node_colors.stroke_color, node_colors.fill_color);
206 }
207 
208 // Replaces <> with &lt;&gt;, so that this string is safe(er) for use in a
209 // graphviz HTML-like string.
HtmlLikeStringSanitize(absl::string_view s)210 string HtmlLikeStringSanitize(absl::string_view s) {
211   return absl::StrReplaceAll(s, {{"<", "&lt;"}, {">", "&gt;"}});
212 }
213 
IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction * instr)214 bool IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction* instr) {
215   namespace m = match;
216   return instr->parent()->IsFusionComputation() &&
217          Match(instr, m::Broadcast(m::ConstantEffectiveScalar()));
218 }
219 
220 // Tries to generates a human-readable one-word description of the given
221 // computation.
222 //
223 // Currently we support:
224 //
225 //   "return param0 + param1;"      --> "add"
226 //   "return param0 * param1;"      --> "multiply"
227 //   "return min(param0, param1);"  --> "min"
228 //   "return max(param0, param1);"  --> "max"
229 //   "return param0 <= param1;"     --> "less-or-equal"
230 //   "return param0 >= param1;"     --> "greater-or-equal"
231 //   "return param0 >  param1;"     --> "greater-than"
232 //   "return param0 <  param1;"     --> "less-than"
233 //   "return param0 == param1;"     --> "equal-to"
234 //   "return param0 != param1;"     --> "not-equal-to"
235 //
236 // where param0 and param1 are effective scalars.  For the ops that are
237 // commutative, we also support them with param0 and param1 swapped.
238 //
239 // This is useful primarily for reduce and map nodes.  These take a
240 // subcomputation which is almost always one of the above, and pattern matching
241 // it to a short string lets us tell the user what the subcomputation is without
242 // drawing it as a graph.
MatchTrivialComputation(const HloComputation * computation)243 optional<string> MatchTrivialComputation(const HloComputation* computation) {
244   namespace m = match;
245 
246   if (computation->instruction_count() != 3) {
247     return nullopt;
248   }
249   HloInstruction* root = computation->root_instruction();
250   const HloInstruction *param0, *param1;
251   if (!Match(root, m::Op()
252                        .WithNumOperands(2)
253                        .WithShape(m::Shape().IsEffectiveScalar())
254                        .WithBinaryOperandsAnyOrder(
255                            m::Parameter(&param0, 0)
256                                .WithShape(m::Shape().IsEffectiveScalar()),
257                            m::Parameter(&param1, 1)
258                                .WithShape(m::Shape().IsEffectiveScalar())))) {
259     return nullopt;
260   }
261 
262   // If the params are reversed (i.e. operand0 is param1 and operand1 is
263   // param0), check that the operation being performed is commutative.
264   if (root->operand(0) == param1) {
265     CHECK_EQ(root->operand(1), param0);
266     if (root->opcode() == HloOpcode()) {
267       switch (root->comparison_direction()) {
268         case ComparisonDirection::kLe:
269         case ComparisonDirection::kGe:
270         case ComparisonDirection::kGt:
271         case ComparisonDirection::kLt:
272           return nullopt;
273         default:
274           break;
275       }
276     }
277   }
278 
279   // If we recognize the root's opcode, we've successfully pattern-matched!
280   switch (root->opcode()) {
281     case HloOpcode::kAdd:
282       return "add";
283     case HloOpcode::kMultiply:
284       return "multiply";
285     case HloOpcode::kMinimum:
286       return "min";
287     case HloOpcode::kMaximum:
288       return "max";
289     case HloOpcode::kCompare: {
290       switch (root->comparison_direction()) {
291         case ComparisonDirection::kLe:
292           return "less-or-equal";
293         case ComparisonDirection::kGe:
294           return "greater-or-equal";
295         case ComparisonDirection::kGt:
296           return "greater-than";
297         case ComparisonDirection::kLt:
298           return "less-than";
299         case ComparisonDirection::kEq:
300           return "equal-to";
301         case ComparisonDirection::kNe:
302           return "not-equal-to";
303       }
304     }
305     default:
306       return nullopt;
307   }
308 }
309 
310 // Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax).
311 class HloDotDumper {
312  public:
HloDotDumper(const HloComputation * computation,absl::string_view label,const DebugOptions & debug_options,bool show_backend_config,const HloExecutionProfile * profile,NodeFilter filter)313   HloDotDumper(const HloComputation* computation, absl::string_view label,
314                const DebugOptions& debug_options, bool show_backend_config,
315                const HloExecutionProfile* profile, NodeFilter filter)
316       : computation_(computation),
317         label_(label),
318         debug_options_(debug_options),
319         show_backend_config_(show_backend_config),
320         profile_(profile),
321         filter_(std::move(filter)) {}
322 
323   string Dump();
324 
325  private:
326   // Returns the dot graph identifier for the given instruction.
InstructionId(const HloInstruction * instruction)327   string InstructionId(const HloInstruction* instruction) {
328     return StrCat(reinterpret_cast<uint64>(instruction));
329   }
330 
331   // Returns the dot graph identifier for the given computation.
SubcomputationId(const HloComputation * computation)332   string SubcomputationId(const HloComputation* computation) {
333     return StrCat("cluster_", reinterpret_cast<uint64>(computation));
334   }
335 
336   // Generates graph header/footer.  These should be called *after* dumping all
337   // of the instructions and subcomputations for the graph, as they both use
338   // data generated while dumping the graph.
339   string Header();
340   string Footer();
341 
342   bool ShouldShowSubcomputation(const HloComputation* subcomp);
343   bool ShouldShowFusionSubcomputation(const HloInstruction* instr);
344 
345   // We omit some nodes from the graph, instead drawing them inlined into the
346   // nodes that use them.
347   bool ShouldMergeIntoUsers(const HloInstruction* instr) const;
348 
349   string DumpSubcomputation(const HloComputation* subcomp,
350                             const HloInstruction* parent_instr);
351   string DumpComputation(const HloComputation* comp);
352   string DumpRootTag();
353   string DumpInstruction(const HloInstruction* instr);
354   ColorScheme GetInstructionColor(const HloInstruction* instr);
355   string GetInstructionNodeShape(const HloInstruction* instr);
356   string GetInstructionNodeLabel(const HloInstruction* instr);
357   string GetInstructionNodeMetadata(const HloInstruction* instr);
358   string GetInstructionNodeBackendConfig(const HloInstruction* instr);
359   string GetInstructionNodeExtraInfo(const HloInstruction* instr);
360   string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
361   void AddInstructionIncomingEdges(const HloInstruction* instr);
362 
363   // For most instructions, GetNodeForEdge(instr) returns instr.
364   //
365   // The exception is fusion nodes.  For these, we walk up the chain of nested
366   // fusion nodes starting at instr until we reach a node that either (a) isn't
367   // a fusion node, or (b) is a fusion node for which
368   // ShouldShowFusionSubcomputation is false.
369   //
370   // We do this because fusion nodes are expanded inline -- if
371   // ShouldShowFusionSubcomputation is true, the fusion node won't be present in
372   // the graph.
373   //
374   // In general when you want to draw an edge from A to B, you should actually
375   // draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B).
376   const HloInstruction* GetNodeForEdge(const HloInstruction* instr);
377 
378   // If instr has just one computation and it's trivial (e.g. "return param0 +
379   // param1"), returns a string you can put into the node's body that names the
380   // subcomputation, e.g. "Subcomputation: <b>add</b>".
381   string GetInstructionTrivialComputationStr(const HloInstruction* instr);
382 
383   const HloComputation* computation_;  // never null
384   const string label_;                 // overall name for the graph
385   const DebugOptions& debug_options_;
386   const bool show_backend_config_;
387   const HloExecutionProfile* profile_;  // may be null
388   const NodeFilter filter_;
389 
390   // Each HloInstruction dumped gets a monotonically-increasing node ID.  This
391   // must start at 1, because that's where graphviz's accounting starts.
392   int64 next_node_id_ = 1;
393   absl::flat_hash_map<const HloInstruction*, int64> node_ids_;
394 
395   // The "root" tag doesn't have an associated HloInstruction pointer, so we
396   // need to store it outside the map.
397   int64 root_node_id_;
398 
399   // Each (from, to) edge gets a monotonically-increasing ID.  This is a
400   // multimap because it's possible for the same edge to appear multiple times
401   // in the graph (e.g. x^2 may be represented as mul(x, x)).
402   int64 next_edge_id_ = 1;
403   std::unordered_multimap<
404       std::pair<const HloInstruction*, const HloInstruction*>, int64,
405       tensorflow::hash<std::pair<const HloInstruction*, const HloInstruction*>>>
406       edge_ids_;
407 
408   // Each HloComputation that's emitted gets a monotonically-increasing ID.
409   int64 next_cluster_id_ = 1;
410   absl::flat_hash_map<const HloComputation*, int64> cluster_ids_;
411 
412   // Edges to print from Footer().  Edges come at the end because graphviz is
413   // unhappy if an edge from a subcomputation to a node in the outer computation
414   // appears before both the inner computation and the destination node are
415   // defined.
416   std::vector<string> edges_;
417 
418   // When coloring by sharding information, we track the sharding string
419   // representation to color association, by round-robin the color schemes.
420   absl::flat_hash_map<HloSharding, ColorScheme, HloSharding::Hasher>
421       sharding_colors_;
422   int64 next_shard_color_ = 0;
423 };
424 
Dump()425 string HloDotDumper::Dump() {
426   string body;
427   StrAppend(&body, DumpComputation(computation_));
428   StrAppend(&body, DumpRootTag());
429 
430   // By contract, Header() and Footer() have to be called after we've dumped all
431   // our instructions, because they use state generated during that process.
432   string g = Header();
433   StrAppend(&g, body);
434   StrAppend(&g, Footer());
435   return g;
436 }
437 
Header()438 string HloDotDumper::Header() {
439   constexpr char fmt[] = R"(digraph G {
440 rankdir = TB;
441 compound = true;
442 label = <<b>%s</b>>;
443 labelloc = t;
444 // Disable the tooltip.  Interestingly, "" doesn't work!
445 tooltip = " ";
446 // DOT graphs accept a stylesheet as a URI.  So naturally, an inline
447 // stylesheet is a data URI!
448 stylesheet=<
449   data:text/css,
450   @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
451   svg text {
452     font-family: 'Roboto';
453     font-size: 12px;
454   }
455 
456   %s
457 >
458 
459 )";
460 
461   VLOG(3) << "Generating Header";
462 
463   string graph_label =
464       StrCat(label_, "<br/>Computation ", computation_->name());
465   if (computation_->IsFusionComputation()) {
466     StrAppend(&graph_label, " (in fusion instruction ",
467               computation_->FusionInstruction()->name(), ")");
468   }
469   if (profile_ != nullptr) {
470     auto cycles = profile_->total_cycles_executed(*computation_);
471     absl::StrAppendFormat(&graph_label, "<br/>total cycles = %d (%s)", cycles,
472                           tensorflow::strings::HumanReadableNum(cycles));
473   }
474 
475   // Create CSS rules that say, when you hover over the given node or cluster,
476   // turn the given edge the given color.
477   //
478   // We rely on a few properties of how graphviz generates SVGs:
479   //
480   //  - Nodes are named "nodeN", where N corresponds to the 1-based index of
481   //    the node in our DOT (i.e. the first node in the DOT is "node1", etc.).
482   //    Edges are similarly named "edgeN", and clusters are named "clustN".
483   //  - Nodes come before their in- and out-edges in the SVG.  We need this
484   //    because the "X ~ Y" CSS selector finds a sibling of X that *comes
485   //    after X in the DOM* and matches Y.
486   std::vector<string> edge_css_rules;
487   const char* kBlue = "#1976d2";
488   const char* kRed = "#d32f2f";
489   for (const auto& kv : edge_ids_) {
490     const HloInstruction* from_node = kv.first.first;
491     const HloInstruction* to_node = kv.first.second;
492     int64 edge_id = kv.second;
493 
494     auto add_hover_css_rule = [&](string elem_type, int64 elem_id,
495                                   const char* color) {
496       // One could imagine other ways of writing this CSS rule that involve
497       // less duplication, but this way seems to be relatively performant.
498       edge_css_rules.push_back(
499           StrFormat("  #%s%d:hover ~ #edge%d text { fill: %s; }\n"
500                     "  #%s%d:hover ~ #edge%d path { "
501                     "stroke: %s; stroke-width: .2em; }\n"
502                     "  #%s%d:hover ~ #edge%d polygon { "
503                     "fill: %s; stroke: %s; stroke-width: .2em; }\n",
504                     elem_type, elem_id, edge_id, color,  //
505                     elem_type, elem_id, edge_id, color,  //
506                     elem_type, elem_id, edge_id, color, color));
507     };
508 
509     // The "to_node" value may be a NULL, indicating that this points to the
510     // "root" tag rather than a normal node.
511     int64 from_node_id =
512         tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1);
513     if (from_node_id == -1) {
514       LOG(FATAL) << from_node->name() << " was added to edges but not to nodes";
515     }
516     int64 to_node_id =
517         to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1)
518                 : root_node_id_;
519     if (to_node != nullptr && to_node_id == -1) {
520       LOG(FATAL) << to_node->name() << " was added to edges but not to nodes";
521     }
522 
523     add_hover_css_rule("node", from_node_id, kBlue);
524     add_hover_css_rule("node", to_node_id, kRed);
525 
526     if (to_node) {
527       VLOG(3) << "Adding css for edge " << edge_id << " from node "
528               << from_node->name() << " to node " << to_node->name();
529     } else {
530       VLOG(3) << "Adding css for edge " << edge_id << " from node "
531               << from_node->name() << " to root tag";
532     }
533 
534     // If this edge crosses a fusion cluster boundary, highlight it when the
535     // cluster is hovered over.
536     if (to_node) {
537       if (from_node->IsFused() &&
538           from_node->parent()->root_instruction() == from_node) {
539         int64 cluster_id = cluster_ids_.at(from_node->parent());
540         add_hover_css_rule("clust", cluster_id, kBlue);
541       }
542       if (to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter) {
543         int64 cluster_id = cluster_ids_.at(to_node->parent());
544         add_hover_css_rule("clust", cluster_id, kRed);
545       }
546     }
547   }
548 
549   // Browsers require that we URI-encode the contents of our data URI.  (It
550   // seems this was a relatively recent change?) In practice, this means that we
551   // need to escape '#'.
552   return StrFormat(
553       fmt, graph_label,
554       absl::StrReplaceAll(StrJoin(edge_css_rules, "\n"), {{"#", "%23"}}));
555 }
556 
Footer()557 string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); }
558 
ShouldShowFusionSubcomputation(const HloInstruction * instr)559 bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
560   CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
561   return ShouldShowSubcomputation(instr->fused_instructions_computation());
562 }
563 
ShouldShowSubcomputation(const HloComputation * subcomp)564 bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
565   if (subcomp->IsFusionComputation()) {
566     const HloInstruction* fusion = subcomp->FusionInstruction();
567     if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) {
568       return false;
569     }
570   }
571 
572   // Don't show trivial subcomputations on non-fusion nodes -- these are inlined
573   // into the graph.
574   if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) {
575     return false;
576   }
577 
578   // Show the subcomputation if we're showing any of its members.
579   return absl::c_any_of(
580       subcomp->instructions(),
581       [&](const HloInstruction* instr) { return filter_.Show(instr); });
582 }
583 
DumpSubcomputation(const HloComputation * subcomp,const HloInstruction * parent_instr)584 string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
585                                         const HloInstruction* parent_instr) {
586   VLOG(2) << "Dumping subcomputation " << subcomp->name();
587   // Add an edge from the subcomputation to its parent node.  If subcomp
588   // belongs to a fusion node, it's drawn in place of the fusion instruction,
589   // so there's no need to link those.
590   if (parent_instr->opcode() != HloOpcode::kFusion) {
591     const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
592     VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
593             << " as " << next_edge_id_;
594     edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
595     constexpr char edge_fmt[] =
596         R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
597     edges_.push_back(StrFormat(
598         edge_fmt, InstructionId(from), InstructionId(parent_instr),
599         SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
600   }
601 
602   // Have we already dumped this subcomputation?  If so, generating the edge
603   // linking it and parent_instr is all we want to do in this function.
604   if (cluster_ids_.find(subcomp) != cluster_ids_.end()) {
605     return "";
606   }
607 
608   cluster_ids_[subcomp] = next_cluster_id_++;
609 
610   string id = SubcomputationId(subcomp);
611 
612   string subcomp_label, style;
613   if (parent_instr->opcode() == HloOpcode::kFusion) {
614     subcomp_label =
615         StrFormat("Fused expression for <b>%s</b><br/>%s",
616                   HtmlLikeStringSanitize(parent_instr->name()),
617                   HtmlLikeStringSanitize(parent_instr->ToCategory()));
618     string extra_info = GetInstructionNodeExtraInfo(parent_instr);
619     if (!extra_info.empty()) {
620       StrAppend(&subcomp_label, "<br/>", extra_info);
621     }
622     string node_backend_config = GetInstructionNodeBackendConfig(parent_instr);
623     if (!node_backend_config.empty()) {
624       StrAppend(&subcomp_label, "<br/>", node_backend_config);
625     }
626 
627     bool highlight = filter_.Highlight(parent_instr);
628     const char* fillcolor;
629     const char* strokecolor;
630     if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) {
631       // Use the sharding color, if the node isn't highlighted.
632       NodeColors node_colors =
633           NodeColorsForScheme(GetInstructionColor(parent_instr));
634       fillcolor = node_colors.fill_color;
635       strokecolor = node_colors.stroke_color;
636     } else {
637       // Subcomputation's fill/stroke color is light/dark red/gray, depending on
638       // whether or not the subcomputation's fusion node is highlighted.
639       fillcolor = highlight ? "#ffcdd2" : "#f5f5f5";
640       strokecolor = highlight ? "#b71c1c" : "#c2c2c2";
641     }
642     style =
643         StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")",
644                   fillcolor, strokecolor);
645   } else {
646     subcomp_label = StrFormat("Subcomputation for <b>%s</b><br/>%s",
647                               HtmlLikeStringSanitize(parent_instr->name()),
648                               HtmlLikeStringSanitize(subcomp->name()));
649     style = "style=rounded; color=black;";
650   }
651 
652   string comp_body = DumpComputation(subcomp);
653 
654   constexpr char computation_fmt[] = R"(subgraph %s {
655 %s
656 label = <%s>;
657 labelloc = t;
658 tooltip = " ";
659 %s
660 }  // %s
661 
662 )";
663   return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id);
664 }
665 
DumpComputation(const HloComputation * comp)666 string HloDotDumper::DumpComputation(const HloComputation* comp) {
667   string g;
668   for (const auto* instr : comp->instructions()) {
669     if (!filter_.Show(instr)) {
670       continue;
671     }
672 
673     // Dump subcomputations within instr.
674     for (const HloComputation* subcomp : instr->called_computations()) {
675       if (ShouldShowSubcomputation(subcomp)) {
676         StrAppend(&g, DumpSubcomputation(subcomp, instr));
677       }
678     }
679 
680     StrAppend(&g, DumpInstruction(instr));
681   }
682   return g;
683 }
684 
DumpRootTag()685 string HloDotDumper::DumpRootTag() {
686   const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());
687 
688   // We didn't display constants or broadcasts of effective scalars within
689   // fusions as separate nodes; so if the root is a constant/broadcast of
690   // scalar, we don't add root tag or edge for it.
691   if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
692       IsFusedBroadcastOfConstantEffectiveScalar(from)) {
693     return "";
694   }
695 
696   auto from_id = InstructionId(from);
697 
698   // The ID of the root computation is otherwise unused, so it makes a good ID
699   // to use for the root-tag node.  However, the edge_ids_ map requires a
700   // HloInstruction* pointer for the 'to' value, so we use a NULL value there
701   // (rather than a pointer type-cast) to make it obvious if it is erroneously
702   // dereferenced.
703   HloInstruction* to = nullptr;
704   auto to_id = SubcomputationId(computation_);
705 
706   string node_body = "ROOT";
707   string node_shape = "circle";
708   ColorScheme color = kBrown;
709 
710   VLOG(2) << "Adding root tag as node " << next_node_id_;
711   root_node_id_ = next_node_id_++;
712 
713   VLOG(2) << "Adding edge from " << from->name() << " to root tag as "
714           << next_edge_id_;
715   edge_ids_.insert({{from, to}, next_edge_id_++});
716   edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id));
717 
718   return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
719                    "\n",
720                    to_id, node_body, node_shape, NodeColorAttributes(color));
721 }
722 
TryGetFusionParameterConstant(const HloInstruction * instr)723 static const HloConstantInstruction* TryGetFusionParameterConstant(
724     const HloInstruction* instr) {
725   if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
726     return nullptr;
727   }
728   const HloInstruction* fusion = instr->parent()->FusionInstruction();
729   const HloInstruction* operand = fusion->operand(instr->parameter_number());
730   return DynCast<HloConstantInstruction>(operand);
731 }
732 
ShouldMergeIntoUsers(const HloInstruction * instr) const733 bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
734   // If a node:
735   //
736   //  - is a parameter of a fusion node which is bound to a constant,
737   //
738   // or
739   //
740   //  - is a tuple-shaped parameter, and
741   //  - is not a parameter to a fusion node, and
742   //  - has at least kMinUsersToOmit users shown, and
743   //  - all of the shown users are get-tuple-elements,
744   //
745   // then we omit it from the graph, merging it with its users.
746   //
747   // This helps us handle the common case where a while loop body has one big
748   // tuple-shaped parameter.
749   if (TryGetFusionParameterConstant(instr) != nullptr) {
750     return true;
751   }
752   const int kMinUsersToOmit = 3;
753   return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() &&
754          !instr->IsFused() &&
755          absl::c_count_if(instr->users(),
756                           [&](const HloInstruction* user) {
757                             return filter_.Show(user);
758                           }) > kMinUsersToOmit &&
759          absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
760            return !filter_.Show(user) ||
761                   user->opcode() == HloOpcode::kGetTupleElement;
762          });
763 }
764 
DumpInstruction(const HloInstruction * instr)765 string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
766   // We don't display constants or broadcasts of effective scalar constants
767   // within fusions as separate nodes; they're merged into their users.
768   if (instr->opcode() == HloOpcode::kConstant ||
769       IsFusedBroadcastOfConstantEffectiveScalar(instr)) {
770     return "";
771   }
772   // Skip this node if it's merged into its users.
773   if (ShouldMergeIntoUsers(instr)) {
774     return "";
775   }
776   // Omit the fusion node if its subcomputation is drawn, since the
777   // subcomputation will be drawn inline.
778   if (instr->opcode() == HloOpcode::kFusion &&
779       ShouldShowFusionSubcomputation(instr)) {
780     return "";
781   }
782 
783   VLOG(2) << "Adding node " << instr->name() << " as " << next_node_id_;
784   node_ids_[instr] = next_node_id_++;
785 
786   ColorScheme color = GetInstructionColor(instr);
787   string node_shape = GetInstructionNodeShape(instr);
788   string node_label = GetInstructionNodeLabel(instr);
789   string node_metadata = GetInstructionNodeMetadata(instr);
790   string node_backend_config = GetInstructionNodeBackendConfig(instr);
791   string extra_info = GetInstructionNodeExtraInfo(instr);
792   string inlined_constants = GetInstructionNodeInlinedOperands(instr);
793   string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
794   AddInstructionIncomingEdges(instr);
795 
796   if (!debug_options_.xla_hlo_graph_sharding_color()) {
797     // Override the node's styling if it should be (de-)emphasized.
798     if (filter_.Deemphasized(instr)) {
799       color = kDashedBorder;
800     }
801     if (filter_.Highlight(instr)) {
802       node_shape = "diamond";
803       color = kDarkRed;
804     }
805   }
806   // Build the text that will be displayed inside the node.
807   string node_body = node_label;
808   for (const string& s : {trivial_subcomputation, node_backend_config,
809                           extra_info, inlined_constants}) {
810     if (!s.empty()) {
811       StrAppend(&node_body, "<br/>", s);
812     }
813   }
814 
815   return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)"
816                    "\n",
817                    InstructionId(instr), node_body, node_shape, node_metadata,
818                    NodeColorAttributes(color));
819 }
820 
GetInstructionNodeInlinedOperands(const HloInstruction * instr)821 string HloDotDumper::GetInstructionNodeInlinedOperands(
822     const HloInstruction* instr) {
823   // The constant's shape is a parameter because, in the case of a broadcasted
824   // scalar constant, we want to show the broadcasted shape, not the constant's
825   // scalar shape.
826   auto stringify_constant = [](const HloConstantInstruction* constant,
827                                const Shape& shape) {
828     // If the shape has a dimension of size zero, print it as e.g.
829     // "{} (f32[42, 0, 10])".  The alternative, calling Literal::ToString(),
830     // enumerates all of its empty dimensions (e.g.  "{ { {}, {} }, ..."), which
831     // is just noise.
832     if (ShapeUtil::IsZeroElementArray(shape)) {
833       return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape()));
834     }
835 
836     // Print the literal value of constants with <= K elements.  Note that we
837     // use `constant->shape()` rather than `shape`, because if `constant` is a
838     // scalar that's broadcasted into `shape`, we want to print the constant.
839     optional<int64> elem_count;
840     if (shape.IsArray()) {
841       elem_count = ShapeUtil::ElementsIn(constant->shape());
842     }
843     // Allow HloDotDumper to print HloInstruction reconstructed from HloProto
844     // collected from profiling tools. Those constants may not have a valid
845     // literal.
846     if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
847       return StrFormat("%s %s", shape.ToString(),
848                        constant->literal().ToStringWithoutShape());
849     }
850 
851     // Otherwise, print e.g. "%constant.42 (s32[100])".
852     string constant_name;
853     if (absl::StartsWith(constant->name(), "constant")) {
854       constant_name = constant->name();
855     } else {
856       constant_name = StrCat("constant ", constant->name());
857     }
858     return StrFormat("%s %s", constant_name, ShapeUtil::HumanString(shape));
859   };
860 
861   std::vector<string> lines;
862   for (int64 i = 0; i < instr->operand_count(); ++i) {
863     const HloInstruction* operand = instr->operand(i);
864     optional<string> operand_str;
865     if (const auto* constant_operand =
866             DynCast<HloConstantInstruction>(operand)) {
867       operand_str =
868           stringify_constant(constant_operand, constant_operand->shape());
869     } else if (IsFusedBroadcastOfConstantEffectiveScalar(operand)) {
870       operand_str = stringify_constant(
871           Cast<HloConstantInstruction>(operand->operand(0)), operand->shape());
872     } else if (ShouldMergeIntoUsers(operand)) {
873       // Special case: If the operand is a parameter to a fusion node and it
874       // always has a constant value, display it like a regular constant.
875       //
876       // For other parameters, use the parameter number rather than the proper
877       // name, because that's generally how people think of the node.
878       if (operand->opcode() == HloOpcode::kParameter) {
879         if (const HloConstantInstruction* constant =
880                 TryGetFusionParameterConstant(operand)) {
881           operand_str = stringify_constant(constant, constant->shape());
882         } else {
883           operand_str = StrFormat("Parameter %d", operand->parameter_number());
884         }
885       } else {
886         operand_str = operand->name();
887       }
888     }
889 
890     if (operand_str) {
891       if (instr->operand_count() > 1) {
892         lines.push_back(StrFormat("<b>operand %d</b> = %s", i, *operand_str));
893       } else {
894         lines.push_back(StrFormat("<b>operand</b> = %s", *operand_str));
895       }
896     }
897   }
898   return StrJoin(lines, "<br/>");
899 }
900 
GetInstructionColor(const HloInstruction * instr)901 ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
902   if (debug_options_.xla_hlo_graph_sharding_color()) {
903     if (!instr->has_sharding()) {
904       return kDashedBorder;
905     }
906     auto it = sharding_colors_.find(instr->sharding());
907     if (it != sharding_colors_.end()) {
908       return it->second;
909     }
910     ColorScheme color = static_cast<ColorScheme>(
911         kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
912     sharding_colors_.emplace(instr->sharding(), color);
913     return color;
914   }
915 
916   // Choose different weights of orange for small vs large parameters.  This
917   // distinction is often important, especially in fusion nodes.
918   auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange;
919 
920   // Special case: If this instruction has a parameter merged into it, paint it
921   // the same color as a parameter.  Unless the merged-in parameter is a
922   // parameter to a fusion node that is bound to a constant -- these aren't
923   // "real" parameters from the user's perspective.
924   if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) {
925         return operand->opcode() == HloOpcode::kParameter &&
926                ShouldMergeIntoUsers(operand) &&
927                TryGetFusionParameterConstant(operand) == nullptr;
928       })) {
929     return parameter_color;
930   }
931 
932   // Pick different colors or shapes for instructions which are particularly
933   // expensive (eg, dot) and those which are unusual in some way or unique
934   // (eg, parameter).
935   switch (instr->opcode()) {
936     case HloOpcode::kAbs:
937     case HloOpcode::kAdd:
938     case HloOpcode::kAnd:
939     case HloOpcode::kAtan2:
940     case HloOpcode::kBitcastConvert:
941     case HloOpcode::kCeil:
942     case HloOpcode::kClamp:
943     case HloOpcode::kClz:
944     case HloOpcode::kCompare:
945     case HloOpcode::kComplex:
946     case HloOpcode::kConvert:
947     case HloOpcode::kCos:
948     case HloOpcode::kDivide:
949     case HloOpcode::kExp:
950     case HloOpcode::kExpm1:
951     case HloOpcode::kFloor:
952     case HloOpcode::kImag:
953     case HloOpcode::kIota:
954     case HloOpcode::kIsFinite:
955     case HloOpcode::kLog:
956     case HloOpcode::kLog1p:
957     case HloOpcode::kMaximum:
958     case HloOpcode::kMinimum:
959     case HloOpcode::kMultiply:
960     case HloOpcode::kNegate:
961     case HloOpcode::kNot:
962     case HloOpcode::kPopulationCount:
963     case HloOpcode::kOr:
964     case HloOpcode::kXor:
965     case HloOpcode::kPower:
966     case HloOpcode::kReal:
967     case HloOpcode::kRemainder:
968     case HloOpcode::kRng:
969     case HloOpcode::kRngGetAndUpdateState:
970     case HloOpcode::kRoundNearestAfz:
971     case HloOpcode::kRsqrt:
972     case HloOpcode::kSelect:
973     case HloOpcode::kShiftLeft:
974     case HloOpcode::kShiftRightArithmetic:
975     case HloOpcode::kShiftRightLogical:
976     case HloOpcode::kSign:
977     case HloOpcode::kSin:
978     case HloOpcode::kSlice:
979     case HloOpcode::kSort:
980     case HloOpcode::kSqrt:
981     case HloOpcode::kSubtract:
982     case HloOpcode::kTanh:
983       // De-emphasize scalar-shaped elementwise ops -- they're generally
984       // uninteresting.
985       if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
986         return kWhite;
987       }
988       return kYellow;
989     case HloOpcode::kBitcast:
990     case HloOpcode::kGetTupleElement:
991     case HloOpcode::kTrace:
992     case HloOpcode::kAfterAll:
993     case HloOpcode::kAddDependency:
994     case HloOpcode::kTuple:
995       return kWhite;
996     case HloOpcode::kBroadcast:
997       // De-emphasize nodes which broadcast a scalar within a fusion node --
998       // these are essentially free.
999       if (instr->IsFused() &&
1000           ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) {
1001         return kWhite;
1002       }
1003       return kGreen;
1004     case HloOpcode::kConcatenate:
1005     case HloOpcode::kDynamicSlice:
1006     case HloOpcode::kGather:
1007     case HloOpcode::kPad:
1008     case HloOpcode::kReshape:
1009     case HloOpcode::kReverse:
1010     case HloOpcode::kTupleSelect:
1011     case HloOpcode::kTranspose:
1012       // De-emphasize scalar-shaped data movement ops and all data movement ops
1013       // inside fusion nodes, both of which are essentially free.
1014       if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) {
1015         return kWhite;
1016       }
1017       return kGreen;
1018     case HloOpcode::kDynamicUpdateSlice:
1019       // Unlike the data-movement ops above, dynamic-update-slice is not ~free
1020       // inside of fusion nodes, so we de-emphasize it only if it's
1021       // scalar-shaped.
1022       if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
1023         return kWhite;
1024       }
1025       return kGreen;
1026     case HloOpcode::kScatter:
1027       // Do not de-emphasize Scatter, since it involves significant work.
1028     case HloOpcode::kCopy:
1029     case HloOpcode::kCopyStart:
1030     case HloOpcode::kCopyDone:
1031       // Emphasize copy nodes, which are either physical transposes (and thus
1032       // significant), or copies of read-only buffers (and thus dead weight).
1033       return kGreen;
1034     case HloOpcode::kConvolution:
1035     case HloOpcode::kDot:
1036     case HloOpcode::kFft:
1037     case HloOpcode::kTriangularSolve:
1038     case HloOpcode::kCholesky:
1039       return kDarkBlue;
1040     case HloOpcode::kReducePrecision:
1041       return kRed;
1042     case HloOpcode::kParameter:
1043       return parameter_color;
1044     case HloOpcode::kBatchNormGrad:
1045     case HloOpcode::kBatchNormInference:
1046     case HloOpcode::kBatchNormTraining:
1047     case HloOpcode::kReduce:
1048     case HloOpcode::kReduceWindow:
1049     case HloOpcode::kSelectAndScatter:
1050       return kPurple;
1051     case HloOpcode::kDomain:
1052     case HloOpcode::kFusion:
1053     case HloOpcode::kMap:
1054     case HloOpcode::kGetDimensionSize:
1055     case HloOpcode::kSetDimensionSize:
1056       return kGray;
1057     case HloOpcode::kAllReduce:
1058     case HloOpcode::kAllToAll:
1059     case HloOpcode::kCollectivePermute:
1060     case HloOpcode::kInfeed:
1061     case HloOpcode::kOutfeed:
1062     case HloOpcode::kPartitionId:
1063     case HloOpcode::kRecv:
1064     case HloOpcode::kRecvDone:
1065     case HloOpcode::kSend:
1066     case HloOpcode::kSendDone:
1067     case HloOpcode::kReplicaId:
1068       return kBrown;
1069     case HloOpcode::kCall:
1070     case HloOpcode::kConditional:
1071     case HloOpcode::kCustomCall:
1072     case HloOpcode::kWhile:
1073       return kDarkGreen;
1074     case HloOpcode::kConstant:
1075       LOG(FATAL) << "Constants don't get their own nodes in the graph.";
1076   }
1077 }
1078 
GetInstructionNodeShape(const HloInstruction * instr)1079 string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) {
1080   // Give while loops a different shape so they're easier to pick out.
1081   switch (instr->opcode()) {
1082     case HloOpcode::kWhile:
1083       return "ellipse";
1084     default:
1085       return "rect";
1086   }
1087 }
1088 
GetInstructionNodeLabel(const HloInstruction * instr)1089 string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
1090   // If we have a parameter, put the param number in the name.
1091   if (instr->opcode() == HloOpcode::kParameter) {
1092     return StrFormat("<b>Parameter %d</b>", instr->parameter_number());
1093   }
1094 
1095   // The HLO instruction name contains usually the opcode, e.g. "%add.42" is
1096   // an add instruction.  In this case we render just the name.
1097   if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) {
1098     return StrFormat("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
1099   }
1100   string extended_opcode =
1101       StrCat(HloOpcodeString(instr->opcode()),
1102              instr->opcode() != HloOpcode::kFusion
1103                  ? ""
1104                  : StrCat(":", xla::ToString(instr->fusion_kind())));
1105   // If the name does not contain the opcode, render both.
1106   return StrFormat("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
1107                    HtmlLikeStringSanitize(instr->name()));
1108 }
1109 
GetInstructionNodeMetadata(const HloInstruction * instr)1110 string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
1111   std::vector<string> lines;
1112   if (!instr->metadata().op_name().empty()) {
1113     lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name()));
1114   }
1115   if (!instr->metadata().op_type().empty()) {
1116     lines.push_back(StrFormat(
1117         "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type())));
1118   }
1119   if (!instr->metadata().source_file().empty() &&
1120       instr->metadata().source_line() != 0) {
1121     lines.push_back(StrFormat("op_type: %s:%d", instr->metadata().source_file(),
1122                               instr->metadata().source_line()));
1123   }
1124 
1125   return StrJoin(lines, "\n");
1126 }
1127 
GetInstructionNodeBackendConfig(const HloInstruction * instr)1128 string HloDotDumper::GetInstructionNodeBackendConfig(
1129     const HloInstruction* instr) {
1130   if (!show_backend_config_ || instr->raw_backend_config_string().empty()) {
1131     return "";
1132   }
1133 
1134   return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\"");
1135 }
1136 
GetInstructionNodeExtraInfo(const HloInstruction * instr)1137 string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
1138   std::vector<string> lines;
1139 
1140   // Get the instruction's extra attributes excluding the names of its
1141   // subcomputations, since those are drawn explicitly in the graph.
1142   for (const auto& line : instr->ExtraAttributesToString(
1143            HloPrintOptions().set_print_subcomputation_mode(
1144                HloPrintOptions::PrintSubcomputationMode::kOff))) {
1145     lines.push_back(HtmlLikeStringSanitize(line));
1146   }
1147 
1148   // Show the shape and layout of the instruction, unless it's an inlined fusion
1149   // node -- there the shape and layout is present in the output node.
1150   if (instr->opcode() != HloOpcode::kFusion ||
1151       !ShouldShowFusionSubcomputation(instr)) {
1152     // Show layout of instructions with more than one dimension.  Don't show
1153     // layout on tuples or tensors with just one dimension (which only have one
1154     // possible layout) to avoid visual noise.
1155     bool shape_is_multidim = false;
1156     ShapeUtil::ForEachSubshape(instr->shape(),
1157                                [&](const Shape& s, const ShapeIndex&) {
1158                                  shape_is_multidim |= s.dimensions_size() > 1;
1159                                });
1160     string instr_shape;
1161     if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) {
1162       instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape());
1163     } else {
1164       instr_shape = ShapeUtil::HumanString(instr->shape());
1165     }
1166 
1167     // Some instructions have giant tuples as their shapes, so truncate the
1168     // HLO's shape to kMaxShapeLen characters.
1169     constexpr int kMaxShapeLen = 64;
1170     if (instr_shape.length() > kMaxShapeLen) {
1171       instr_shape = StrCat(
1172           absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "...");
1173     }
1174     lines.push_back(instr_shape);
1175   }
1176   if (debug_options_.xla_hlo_graph_addresses()) {
1177     lines.push_back(StrFormat("[%p]", instr));
1178   }
1179   if (profile_ != nullptr) {
1180     double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr);
1181     double total_cycles_executed =
1182         profile_->total_cycles_executed(*instr->parent());
1183     if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
1184       lines.push_back(
1185           StrFormat("%% of cycles executed=%.2f",
1186                     100 * hlo_cycles_executed / total_cycles_executed));
1187     }
1188   }
1189   return StrJoin(lines, "<br/>");
1190 }
1191 
AddInstructionIncomingEdges(const HloInstruction * instr)1192 void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
1193   auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
1194                       int64 operand_num, bool control_edge = false) {
1195     from = GetNodeForEdge(from);
1196 
1197     if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
1198         IsFusedBroadcastOfConstantEffectiveScalar(from) ||
1199         ShouldMergeIntoUsers(from)) {
1200       return;
1201     }
1202     VLOG(2) << "Adding edge from " << from->name() << " to " << to->name()
1203             << " as " << next_edge_id_;
1204     edge_ids_.insert({{from, to}, next_edge_id_++});
1205 
1206     string edge_label;
1207     if (instr->operand_count() > 1 && !control_edge) {
1208       edge_label =
1209           StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num);
1210     } else if (control_edge) {
1211       edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"";
1212     }
1213 
1214     // We print "small" arrays using a hollow arrowhead and "large" arrays using
1215     // a filled arrowhead.
1216     constexpr char kEdgeFmt[] =
1217         R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
1218     edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to),
1219                                (IsSmall(from) ? "empty" : "normal"),
1220                                from->name(), to->name(), edge_label));
1221   };
1222 
1223   // Add edges from instr's operands to instr.  Parameters within fusion
1224   // expressions are handled specially -- we draw an edge from the corresponding
1225   // operand on the fusion node itself to the parameter.
1226   if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
1227     // Only add the edge if this is not the outermost computation; otherwise it
1228     // will lead from a node we're not drawing.
1229     if (instr->parent() != computation_) {
1230       const HloInstruction* fusion = instr->parent()->FusionInstruction();
1231       add_edge(fusion->operand(instr->parameter_number()), instr,
1232                /*operand_num=*/0);
1233     }
1234   } else {
1235     for (int64 i = 0; i < instr->operand_count(); ++i) {
1236       add_edge(instr->operand(i), instr, i);
1237     }
1238     for (const HloInstruction* pred : instr->control_predecessors()) {
1239       add_edge(pred, instr, /*operand_num=*/0, /*control_edge=*/true);
1240     }
1241   }
1242 }
1243 
GetInstructionTrivialComputationStr(const HloInstruction * instr)1244 string HloDotDumper::GetInstructionTrivialComputationStr(
1245     const HloInstruction* instr) {
1246   // called_computations() on a fusion node "inherits" any called computations
1247   // of the fused root, which isn't what we want.  Just ignore fusion nodes
1248   // here; they're handled separately.
1249   if (instr->opcode() == HloOpcode::kFusion) {
1250     return "";
1251   }
1252 
1253   std::vector<string> lines;
1254   for (int64 i = 0; i < instr->called_computations().size(); ++i) {
1255     optional<string> computation_type =
1256         MatchTrivialComputation(instr->called_computations()[i]);
1257     if (!computation_type) {
1258       continue;
1259     }
1260     if (instr->called_computations().size() == 1) {
1261       lines.push_back(StrFormat("Subcomputation: <b>%s</b>",
1262                                 HtmlLikeStringSanitize(*computation_type)));
1263     } else {
1264       lines.push_back(StrFormat("Subcomputation %d: <b>%s</b>", i,
1265                                 HtmlLikeStringSanitize(*computation_type)));
1266     }
1267   }
1268   return StrJoin(lines, "<br/>");
1269 }
1270 
// Returns the node an edge touching `instr` should actually attach to.  A
// fusion node whose subcomputation is drawn inline has no node of its own, so
// edges attach to its fused root instead (following nested inline fusions).
const HloInstruction* HloDotDumper::GetNodeForEdge(
    const HloInstruction* instr) {
  const bool drawn_inline = instr->opcode() == HloOpcode::kFusion &&
                            ShouldShowFusionSubcomputation(instr);
  if (!drawn_inline) {
    return instr;
  }
  return GetNodeForEdge(instr->fused_expression_root());
}
1279 
1280 // Gets a NodeFilter that includes roughly all instructions whose distance from
1281 // root is <= radius.
MakeNodeRadiusAroundFilter(const HloInstruction * root,int64 radius,const absl::flat_hash_set<const HloInstruction * > & boundary)1282 NodeFilter MakeNodeRadiusAroundFilter(
1283     const HloInstruction* root, int64 radius,
1284     const absl::flat_hash_set<const HloInstruction*>& boundary) {
1285   // First, find the neighborhood of nodes with distance from root <= radius.
1286   // These nodes are our initial set of "normal" nodes.
1287   absl::flat_hash_map<const HloInstruction*, NodeFilterResult> nodes;
1288   std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
1289   worklist.push_back({root, 0});
1290   while (!worklist.empty()) {
1291     const HloInstruction* instr;
1292     int64 depth;
1293     std::tie(instr, depth) = worklist.front();
1294     worklist.pop_front();
1295 
1296     nodes[instr] = kNormalNode;
1297     if (depth == radius) {
1298       continue;
1299     }
1300     if (boundary.contains(instr)) {
1301       continue;
1302     }
1303 
1304     // Traverse into instr's operands.
1305     //
1306     // Don't traverse into tuples' operands unless the tuple is the root.
1307     // Usually a tuple is the bottommost node in the graph, and so its operands
1308     // are not interesting to the graph at hand.
1309     if (instr == root || instr->opcode() != HloOpcode::kTuple) {
1310       for (const HloInstruction* operand : instr->operands()) {
1311         if (!nodes.contains(operand)) {
1312           worklist.push_back({operand, depth + 1});
1313         }
1314       }
1315     }
1316 
1317     // Traverse into instr's nested computations.
1318     for (const HloComputation* computation : instr->called_computations()) {
1319       worklist.push_back({computation->root_instruction(), depth + 1});
1320     }
1321 
1322     // Traverse into instr's users, unless:
1323     //
1324     //  - there are a ton of them, in which case they're probably not
1325     //    interesting (and anyway, rendering them all would make the graph
1326     //    unreadable), or
1327     //  - instr is a constant, in which case its users are probably not
1328     //    interesting.
1329     if (instr->opcode() == HloOpcode::kConstant) {
1330       continue;
1331     }
1332     constexpr int kMaxUsersToRender = 16;
1333     if (instr->user_count() > kMaxUsersToRender) {
1334       // If we're going to skip this node's users, style it as such.
1335       nodes[instr] = kSomeUsersOmitted;
1336       continue;
1337     }
1338     for (const HloInstruction* user : instr->users()) {
1339       if (!nodes.contains(user)) {
1340         worklist.push_back({user, depth + 1});
1341       }
1342     }
1343   }
1344 
1345   auto is_displayed = [&](const HloInstruction* instr) {
1346     // Constants are displayed inline with their users; they're never omitted.
1347     // Nodes in subcomputations are always shown.
1348     return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant ||
1349            instr->parent() != root->parent();
1350   };
1351 
1352   // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we
1353   // know which nodes will be included in the graph.
1354   for (auto& kv : nodes) {
1355     const HloInstruction* instr = kv.first;
1356     NodeFilterResult& filter_result = kv.second;
1357     const auto& operands = instr->operands();
1358 
1359     if (absl::c_any_of(operands, is_displayed) &&
1360         !absl::c_all_of(operands, is_displayed)) {
1361       // Mark nodes with some operands omitted appropriately.
1362       filter_result = kSomeOperandsOmitted;
1363     } else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) {
1364       // Mark nodes with *all* operands omitted appropriately.
1365       filter_result = kOmitNodeOperands;
1366     }
1367 
1368     // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
1369     // users made it into the graph.
1370     if (filter_result == kSomeUsersOmitted &&
1371         absl::c_all_of(instr->users(), is_displayed)) {
1372       filter_result = kNormalNode;
1373     }
1374   }
1375 
1376   // Highlight the root node.
1377   nodes[root] = kHighlightNode;
1378 
1379   return NodeFilter([=](const HloInstruction* instr) {
1380     auto it = nodes.find(instr);
1381     if (it != nodes.end()) {
1382       return it->second;
1383     }
1384     // Show all nodes in subcomputations.
1385     if (instr->parent() != root->parent()) {
1386       return kNormalNode;
1387     }
1388     return kHideNode;
1389   });
1390 }
1391 
// Gets a node filter that includes nodes on paths from `from` to `to`.  The
// search expands each node only along the first (shortest) path that reaches
// it, so nodes that lie only on longer paths may be left out.  Nodes stop being
// added once max_nodes are included; if that cuts a path short, hit_limit is
// set to true.
MakeNodeFromToFilter(const HloInstruction * from,const HloInstruction * to,int64 max_nodes,bool * hit_limit)1395 NodeFilter MakeNodeFromToFilter(const HloInstruction* from,
1396                                 const HloInstruction* to, int64 max_nodes,
1397                                 bool* hit_limit) {
1398   *hit_limit = false;
1399 
1400   // Elements in the queue are paths through the graph.
1401   std::deque<std::vector<const HloInstruction*>> queue;
1402   queue.push_front({from});
1403 
1404   // Compute the set of nodes we want to show using a slightly-modified
1405   // Djikstra's algorithm.  The only real difference is, rather than stopping
1406   // when we find a (shortest) path, we continue until we've found max_nodes
1407   // nodes on some path.
1408   std::unordered_set<const HloInstruction*> visited;
1409   std::unordered_set<const HloInstruction*> to_display = {from, to};
1410   while (!queue.empty() && to_display.size() < max_nodes) {
1411     std::vector<const HloInstruction*> path = std::move(queue.front());
1412     queue.pop_front();
1413     if (!visited.insert(path.back()).second) {
1414       continue;
1415     }
1416 
1417     for (const auto* user : path.back()->users()) {
1418       if (user == to) {
1419         auto it = path.begin();
1420         for (; it != path.end() && to_display.size() < max_nodes; ++it) {
1421           to_display.insert(*it);
1422         }
1423         if (it != path.end()) {
1424           *hit_limit = true;
1425         }
1426       } else if (!visited.count(user)) {
1427         auto new_path = path;
1428         new_path.push_back(user);
1429         queue.push_back(std::move(new_path));
1430       }
1431     }
1432   }
1433 
1434   return NodeFilter([=](const HloInstruction* instr) {
1435     if (instr == from || instr == to) {
1436       return kHighlightNode;
1437     }
1438     return to_display.count(instr) ? kNormalNode : kHideNode;
1439   });
1440 }
1441 
WrapDotInHtml(absl::string_view dot)1442 string WrapDotInHtml(absl::string_view dot) {
1443   static const char html_prefix[] = R"html(
1444 <!DOCTYPE html>
1445 <html>
1446 <head>
1447   <meta charset="utf-8">
1448   <style type="text/css">
1449     body {
1450       height: 100vh;
1451       margin: 0;
1452     }
1453   </style>
1454 </head>
1455 <body>
1456   <!-- Integrity hash is generated by https://www.srihash.org/ -->
1457   <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/viz.js"
1458      integrity="sha384-aD1MJYb0WKIUT+CtwJp5LTuV3U4pLAS6B/nUxL7ECimC2pN9N8vjlMr/yQCAkzxE"
1459      crossorigin="anonymous"></script>
1460   <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/full.render.js"
1461      integrity="sha384-bAixY275aIpCj6Te19y0MILZ4V+VEC8CVFujFEH+Lf7W+4XYYeYLwW5IBI6yQmMT"
1462      crossorigin="anonymous"></script>
1463   <script src="https://cdn.jsdelivr.net/npm/svg-pan-zoom@3.6.0/dist/svg-pan-zoom.min.js"
1464      integrity="sha384-3008WpYB2pOBvE7lwkrKf+qTmbTPGGPYxA9C1YVhvbPukns4ZFj7E98QPLkNW9dS"
1465      crossorigin="anonymous"></script>
1466   <div id="container" style="height:95vh; border:1px solid black; "></div>
1467   <script>
1468     var data = `
1469 )html";
1470 
1471   static const char html_suffix[] = R"html(
1472 `;
1473     var cssregex = new RegExp('stylesheet=<([^]*)\n>\n', 'gm');
1474     var results = cssregex.exec(data)
1475     // graphviz has problem dealing with large stylesheets.
1476     // https://github.com/tensorflow/tensorflow/issues/17220#issuecomment-369228492
1477     // In order to avoid the problem, remove the stylesheet from the dot and
1478     // insert it directly info the rendered SVG.
1479     var dot_data = data;
1480     var css_data = ''
1481     if (results !== null) {
1482         css_data = results[1].replace(/\s*data:.*\s*,/,''); // Strip content-type field.
1483         // CSS inside DOT is URL-escaped, so we must unescape it
1484         // before we can insert it into SVG.
1485         css_data = unescape(css_data);
1486         dot_data = data.replace(cssregex, ''); // Remove the stylesheet
1487     }
1488 
1489     var render_start = performance.now()
1490     function add_controls(svg) {
1491         var htmlblob = new Blob([document.documentElement.innerHTML],
1492                                 {type: 'text/html'});
1493         var savehtml = document.createElement('a');
1494         savehtml.setAttribute('href', URL.createObjectURL(htmlblob));
1495         savehtml.setAttribute('download', 'graph.html');
1496         savehtml.innerHTML = " [Save HTML+SVG] ";
1497         document.body.append(savehtml);
1498         var svgblob = new Blob([svg.outerHTML], {type: 'image/svg'});
1499         var savesvg = document.createElement('a');
1500         savesvg.setAttribute('href', URL.createObjectURL(svgblob));
1501         savesvg.setAttribute('download', 'graph.svg');
1502         savesvg.innerHTML = " [Save SVG] ";
1503         document.body.append(savesvg);
1504         var dotblob =  new Blob([data], {type: 'text/dot'});
1505         var savedot = document.createElement('a');
1506         savedot.setAttribute('href', URL.createObjectURL(dotblob));
1507         savedot.setAttribute('download', 'graph.dot');
1508         savedot.innerHTML = " [Save DOT] ";
1509         document.body.append(savedot);
1510         // Will get called after embed element was loaded
1511         var panzoom = svgPanZoom(svg, {
1512             zoomEnabled: true,
1513             controlIconsEnabled: true,
1514         });
1515         document.getElementsByTagName("BODY")[0].onresize = function() {
1516             panzoom.resize();
1517             panzoom.fit();
1518             panzoom.center();
1519         };
1520         var render_end = performance.now();
1521         var render_note = document.createElement('div')
1522         render_note.innerHTML = 'Rendering took '
1523                                 + (render_end - render_start).toFixed(2) + "ms."
1524         document.body.append(render_note);
1525     }
1526     var svg = document.getElementById('graph')
1527     if (svg == null) {
1528         // Need to render SVG first.
1529         var viz = new Viz();
1530         viz.renderSVGElement(dot_data)
1531             .then(function(svg){
1532                 var container = document.getElementById('container')
1533                 var style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
1534                 var node = document.createTextNode(css_data);
1535                 style.appendChild(node);
1536                 svg.setAttribute('width', '100%');
1537                 svg.setAttribute('height', '100%');
1538                 svg.setAttribute('id', 'graph');
1539                 svg.appendChild(style);
1540                 container.appendChild(svg);
1541                 add_controls(svg);
1542             })
1543     } else {
1544         // HTML already has rendered SVG embedded, so we just need to add
1545         // controls.
1546         add_controls(svg);
1547     }
1548   </script>
1549 </body>
1550 </html>
1551 )html";
1552 
1553   return absl::StrCat(html_prefix, dot, html_suffix);
1554 }
1555 
1556 tensorflow::mutex url_renderer_mu(tensorflow::LINKER_INITIALIZED);
1557 std::function<StatusOr<string>(absl::string_view)>* url_renderer
1558     GUARDED_BY(url_renderer_mu) = nullptr;
1559 
1560 // Precondition: url_renderer != nullptr.
1561 //
1562 // (We specify this as a precondition rather than checking it in here and
1563 // returning an error because we want to fail quickly when there's no URL
1564 // renderer available, and this function runs only after we've done all the work
1565 // of producing dot for the graph.)
1566 StatusOr<string> WrapDotInFormat(absl::string_view dot,
1567                                  RenderedGraphFormat format)
1568     EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
1569   switch (format) {
1570     case RenderedGraphFormat::kUrl:
1571       CHECK(url_renderer != nullptr)
1572           << "Should have checked url_renderer != null before calling.";
1573       return (*url_renderer)(dot);
1574     case RenderedGraphFormat::kHtml:
1575       return WrapDotInHtml(dot);
1576     case RenderedGraphFormat::kDot:
1577       return string(dot);
1578   }
1579 }
1580 
1581 }  // namespace
1582 
1583 void RegisterGraphToURLRenderer(
1584     std::function<StatusOr<string>(absl::string_view)> renderer) {
1585   tensorflow::mutex_lock lock(url_renderer_mu);
1586   if (url_renderer != nullptr) {
1587     LOG(WARNING) << "Multiple calls to RegisterGraphToURLRenderer.  Last call "
1588                     "wins, but because order of initialization in C++ is "
1589                     "nondeterministic, this may not be what you want.";
1590   }
1591   delete url_renderer;
1592   url_renderer = new std::function<StatusOr<string>(absl::string_view)>(
1593       std::move(renderer));
1594 }
1595 
1596 StatusOr<string> RenderGraph(const HloComputation& computation,
1597                              absl::string_view label,
1598                              const DebugOptions& debug_options,
1599                              RenderedGraphFormat format,
1600                              const HloExecutionProfile* hlo_execution_profile,
1601                              bool show_backend_config) {
1602   tensorflow::mutex_lock lock(url_renderer_mu);
1603   if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1604     return Unavailable("Can't render as URL; no URL renderer was registered.");
1605   }
1606 
1607   string rendered_dot =
1608       HloDotDumper(&computation, label, debug_options, show_backend_config,
1609                    hlo_execution_profile, NodeFilter())
1610           .Dump();
1611   return WrapDotInFormat(rendered_dot, format);
1612 }
1613 
1614 StatusOr<string> RenderNeighborhoodAround(
1615     const HloInstruction& node, int radius, RenderedGraphFormat format,
1616     bool show_backend_config,
1617     const absl::flat_hash_set<const HloInstruction*>& boundary) {
1618   tensorflow::mutex_lock lock(url_renderer_mu);
1619   if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1620     return FailedPrecondition(
1621         "Can't render as URL; no URL renderer was registered.");
1622   }
1623 
1624   string label =
1625       StrCat("Neighborhood of ", radius, " nodes around ", node.name());
1626   string rendered_dot =
1627       HloDotDumper(node.parent(), label,
1628                    node.GetModule()->config().debug_options(),
1629                    show_backend_config, /*profile=*/nullptr,
1630                    MakeNodeRadiusAroundFilter(&node, radius, boundary))
1631           .Dump();
1632   return WrapDotInFormat(rendered_dot, format);
1633 }
1634 
1635 StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
1636                                       const HloInstruction& to, int64 max_nodes,
1637                                       RenderedGraphFormat format,
1638                                       bool show_backend_config) {
1639   tensorflow::mutex_lock lock(url_renderer_mu);
1640   if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1641     return FailedPrecondition(
1642         "Can't render as URL; no URL renderer was registered.");
1643   }
1644 
1645   CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!";
1646   auto debug_options = from.GetModule()->config().debug_options();
1647 
1648   bool hit_limit = false;
1649   NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit);
1650   string label;
1651   if (!hit_limit) {
1652     label = StrCat("All paths from ", from.name(), " to ", to.name());
1653   } else {
1654     label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(),
1655                    " to ", to.name(),
1656                    "<br/><br/>***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN "
1657                    "NODES***<br/><br/>");
1658   }
1659   string rendered_dot =
1660       HloDotDumper(from.parent(), label, debug_options, show_backend_config,
1661                    /*profile=*/nullptr, filter)
1662           .Dump();
1663   return WrapDotInFormat(rendered_dot, format);
1664 }
1665 
1666 }  // namespace xla
1667