• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
17 
18 #include <unistd.h>
19 #include <algorithm>
20 #include <atomic>
21 #include <deque>
22 #include <map>
23 #include <memory>
24 #include <queue>
25 #include <string>
26 #include <tuple>
27 #include <vector>
28 
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/strings/match.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/strings/str_replace.h"
35 #include "absl/types/optional.h"
36 #include "tensorflow/compiler/xla/layout_util.h"
37 #include "tensorflow/compiler/xla/literal.h"
38 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
39 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
40 #include "tensorflow/compiler/xla/service/hlo_module.h"
41 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
42 #include "tensorflow/compiler/xla/shape_util.h"
43 #include "tensorflow/compiler/xla/types.h"
44 #include "tensorflow/compiler/xla/util.h"
45 #include "tensorflow/compiler/xla/window_util.h"
46 #include "tensorflow/core/lib/core/status.h"
47 #include "tensorflow/core/lib/gtl/map_util.h"
48 #include "tensorflow/core/lib/io/path.h"
49 #include "tensorflow/core/lib/strings/numbers.h"
50 #include "tensorflow/core/platform/env.h"
51 #include "tensorflow/core/platform/mutex.h"
52 #include "tensorflow/core/platform/protobuf.h"
53 #include "tensorflow/core/platform/regexp.h"
54 
55 namespace xla {
56 namespace {
57 
58 using absl::nullopt;
59 using absl::optional;
60 using absl::StrAppend;
61 using absl::StrCat;
62 using absl::StrFormat;
63 using absl::StrJoin;
64 using tensorflow::Env;
65 using tensorflow::WriteStringToFile;
66 using tensorflow::io::JoinPath;
67 
// Classification that a NodeFilter assigns to an HloInstruction.  It tells the
// graph-drawing code how to treat the node: draw it normally, hide it, etc.
enum NodeFilterResult {
  kNormalNode = 0,
  kHideNode,
  // Draw the node so that it is easy to spot in the final graph.
  kHighlightNode,
  // "Gray out" the node, signalling that some of its operands were omitted.
  kSomeOperandsOmitted,
  // Styled exactly like kSomeOperandsOmitted, but in addition the node is not
  // connected to its operands, even when they are present in the graph.
  kOmitNodeOperands,
  // Styled like kSomeOperandsOmitted, but signals that some of the node's
  // *users* (rather than its operands) were omitted.
  kSomeUsersOmitted,
};
85 
86 // NodeFilter is essentially a map from HloInstruction*s to NodeFilterResult.
87 // It lets callers tell the graph-drawing routines which nodes they want to be
88 // shown, hidden, or highlighted.
89 class NodeFilter {
90  public:
__anon7102e5d00202(const HloInstruction*) 91   NodeFilter() : filter_([](const HloInstruction*) { return kNormalNode; }) {}
92 
NodeFilter(std::function<NodeFilterResult (const HloInstruction * instr)> filter)93   explicit NodeFilter(
94       std::function<NodeFilterResult(const HloInstruction* instr)> filter)
95       : filter_(std::move(filter)) {}
96 
Show(const HloInstruction * instr) const97   bool Show(const HloInstruction* instr) const {
98     return filter_(instr) != kHideNode;
99   }
Highlight(const HloInstruction * instr) const100   bool Highlight(const HloInstruction* instr) const {
101     return filter_(instr) == kHighlightNode;
102   }
OmitOperands(const HloInstruction * instr) const103   bool OmitOperands(const HloInstruction* instr) const {
104     return filter_(instr) == kOmitNodeOperands;
105   }
SomeOrAllOperandsOmitted(const HloInstruction * instr) const106   bool SomeOrAllOperandsOmitted(const HloInstruction* instr) const {
107     auto result = filter_(instr);
108     return result == kOmitNodeOperands || result == kSomeOperandsOmitted;
109   }
Deemphasized(const HloInstruction * instr) const110   bool Deemphasized(const HloInstruction* instr) const {
111     auto result = filter_(instr);
112     return result == kOmitNodeOperands || result == kSomeOperandsOmitted ||
113            result == kSomeUsersOmitted;
114   }
115 
116  private:
117   std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
118 };
119 
120 // We arbitrarily set this as the boundary between "large" and "small"
121 // instructions.
IsSmall(const HloInstruction * instr)122 bool IsSmall(const HloInstruction* instr) {
123   if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) ||
124       ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
125     return true;
126   }
127   return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
128 }
129 
// Node color schemes, used by NodeColorAttributes.
enum ColorScheme {
  kBlue,
  kBrown,
  kDarkBlue,
  kDarkGreen,
  kDarkOrange,
  kDarkRed,
  kGray,
  kGreen,
  kOrange,
  kPurple,
  kRed,
  kWhite,
  kYellow,

  // Causes the node's border to be a dashed line, and its content to be gray
  // text on a white background, suggesting that this is an "unimportant" node.
  kDashedBorder,
};

// Graphviz attributes/colors that make up a color scheme.  All members point
// at string literals returned by NodeColorsForScheme.
struct NodeColors {
  const char* style;         // value for the graphviz "style" attribute
  const char* fill_color;    // value for "fillcolor"
  const char* stroke_color;  // value for "color" (the outline)
  const char* font_color;    // value for "fontcolor"
};

// Returns the graphviz style and fill/stroke/font colors for `color`.
//
// The switch deliberately has no `default:` label so that the compiler can
// still warn when a new ColorScheme enumerator is added without a case here.
// Values that are not a named enumerator fall through to the return after the
// switch and get the same plain white scheme as kWhite, instead of running off
// the end of a non-void function (which is undefined behavior).
NodeColors NodeColorsForScheme(ColorScheme color) {
  switch (color) {
    case kBlue:
      return NodeColors{"filled", "#bbdefb", "#8aacc8", "black"};
    case kBrown:
      return NodeColors{"filled", "#bcaaa4", "#8c7b75", "black"};
    case kDarkBlue:
      return NodeColors{"filled", "#1565c0", "#003c8f", "white"};
    case kDarkGreen:
      return NodeColors{"filled", "#2e7d32", "#005005", "white"};
    case kDarkOrange:
      // This is more of a "medium" orange, made to look close to kOrange;
      // there's probably room for a darker weight if desired.
      return NodeColors{"filled", "#ffb74d", "#c88719", "black"};
    case kDarkRed:
      return NodeColors{"filled", "#b71c1c", "#7f0000", "white"};
    case kGray:
      return NodeColors{"filled", "#cfd8dc", "#9ea7aa", "black"};
    case kGreen:
      return NodeColors{"filled", "#c8e6c9", "#97b498", "black"};
    case kOrange:
      return NodeColors{"filled", "#ffe0b2", "#cbae82", "black"};
    case kPurple:
      return NodeColors{"filled", "#e1bee7", "#af8eb5", "black"};
    case kRed:
      return NodeColors{"filled", "#ffcdd2", "#cb9ca1", "black"};
    case kWhite:
      return NodeColors{"filled", "white", "black", "black"};
    case kYellow:
      return NodeColors{"filled", "#fff9c4", "#cbc693", "black"};
    case kDashedBorder:
      // "filled,dashed" looks the same as "dashed", since we have a white
      // background.  But we use "filled,dashed" so that when you hover over
      // any part of the node (not just the text inside the node), our css
      // :hover rule is triggered.
      return NodeColors{"filled,dashed", "white", "#757575", "#757575"};
  }
  // Only reachable for a value that is not a ColorScheme enumerator.
  return NodeColors{"filled", "white", "black", "black"};
}
197 
198 // Given a ColorScheme, returns an attribute string for a node of that color.
199 // Sets the node's style and fill/stroke/text colors.
200 //
201 // Colors are from https://material.io/color.
NodeColorAttributes(ColorScheme color)202 string NodeColorAttributes(ColorScheme color) {
203   NodeColors node_colors = NodeColorsForScheme(color);
204 
205   return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")",
206                    node_colors.style, node_colors.font_color,
207                    node_colors.stroke_color, node_colors.fill_color);
208 }
209 
210 // Replaces <> with &lt;&gt;, so that this string is safe(er) for use in a
211 // graphviz HTML-like string.
HtmlLikeStringSanitize(absl::string_view s)212 string HtmlLikeStringSanitize(absl::string_view s) {
213   return absl::StrReplaceAll(s, {{"<", "&lt;"}, {">", "&gt;"}});
214 }
215 
216 // Tries to generates a human-readable one-word description of the given
217 // computation.
218 //
219 // Currently we support:
220 //
221 //   "return param0 + param1;"      --> "add"
222 //   "return param0 * param1;"      --> "multiply"
223 //   "return min(param0, param1);"  --> "min"
224 //   "return max(param0, param1);"  --> "max"
225 //   "return param0 <= param1;"     --> "less-or-equal"
226 //   "return param0 >= param1;"     --> "greater-or-equal"
227 //   "return param0 >  param1;"     --> "greater-than"
228 //   "return param0 <  param1;"     --> "less-than"
229 //   "return param0 == param1;"     --> "equal-to"
230 //   "return param0 != param1;"     --> "not-equal-to"
231 //
232 // where param0 and param1 are effective scalars.  For the ops that are
233 // commutative, we also support them with param0 and param1 swapped.
234 //
235 // This is useful primarily for reduce and map nodes.  These take a
236 // subcomputation which is almost always one of the above, and pattern matching
237 // it to a short string lets us tell the user what the subcomputation is without
238 // drawing it as a graph.
MatchTrivialComputation(const HloComputation * computation)239 optional<string> MatchTrivialComputation(const HloComputation* computation) {
240   namespace m = match;
241 
242   if (computation->instruction_count() != 3) {
243     return nullopt;
244   }
245   HloInstruction* root = computation->root_instruction();
246   const HloInstruction *param0, *param1;
247   if (!Match(root, m::Op()
248                        .WithNumOperands(2)
249                        .WithShape(m::Shape().IsEffectiveScalar())
250                        .WithBinaryOperandsAnyOrder(
251                            m::Parameter(&param0, 0)
252                                .WithShape(m::Shape().IsEffectiveScalar()),
253                            m::Parameter(&param1, 1)
254                                .WithShape(m::Shape().IsEffectiveScalar())))) {
255     return nullopt;
256   }
257 
258   // If the params are reversed (i.e. operand0 is param1 and operand1 is
259   // param0), check that the operation being performed is commutative.
260   if (root->operand(0) == param1) {
261     CHECK_EQ(root->operand(1), param0);
262     if (root->opcode() == HloOpcode()) {
263       switch (root->comparison_direction()) {
264         case ComparisonDirection::kLe:
265         case ComparisonDirection::kGe:
266         case ComparisonDirection::kGt:
267         case ComparisonDirection::kLt:
268           return nullopt;
269         default:
270           break;
271       }
272     }
273   }
274 
275   // If we recognize the root's opcode, we've successfully pattern-matched!
276   switch (root->opcode()) {
277     case HloOpcode::kAdd:
278       return "add";
279     case HloOpcode::kMultiply:
280       return "multiply";
281     case HloOpcode::kMinimum:
282       return "min";
283     case HloOpcode::kMaximum:
284       return "max";
285     case HloOpcode::kCompare: {
286       switch (root->comparison_direction()) {
287         case ComparisonDirection::kLe:
288           return "less-or-equal";
289         case ComparisonDirection::kGe:
290           return "greater-or-equal";
291         case ComparisonDirection::kGt:
292           return "greater-than";
293         case ComparisonDirection::kLt:
294           return "less-than";
295         case ComparisonDirection::kEq:
296           return "equal-to";
297         case ComparisonDirection::kNe:
298           return "not-equal-to";
299       }
300     }
301     default:
302       return nullopt;
303   }
304 }
305 
306 // Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax).
307 class HloDotDumper {
308  public:
HloDotDumper(const HloComputation * computation,absl::string_view label,const DebugOptions & debug_options,bool show_backend_config,const HloExecutionProfile * profile,NodeFilter filter)309   HloDotDumper(const HloComputation* computation, absl::string_view label,
310                const DebugOptions& debug_options, bool show_backend_config,
311                const HloExecutionProfile* profile, NodeFilter filter)
312       : computation_(computation),
313         label_(label),
314         debug_options_(debug_options),
315         show_backend_config_(show_backend_config),
316         profile_(profile),
317         filter_(std::move(filter)) {}
318 
319   string Dump();
320 
321  private:
322   // Returns the dot graph identifier for the given instruction.
InstructionId(const HloInstruction * instruction)323   string InstructionId(const HloInstruction* instruction) {
324     return StrCat(reinterpret_cast<uint64>(instruction));
325   }
326 
327   // Returns the dot graph identifier for the given computation.
SubcomputationId(const HloComputation * computation)328   string SubcomputationId(const HloComputation* computation) {
329     return StrCat("cluster_", reinterpret_cast<uint64>(computation));
330   }
331 
332   // Generates graph header/footer.  These should be called *after* dumping all
333   // of the instructions and subcomputations for the graph, as they both use
334   // data generated while dumping the graph.
335   string Header();
336   string Footer();
337 
338   bool ShouldShowSubcomputation(const HloComputation* subcomp);
339   bool ShouldShowFusionSubcomputation(const HloInstruction* instr);
340 
341   // We omit some nodes from the graph, instead drawing them inlined into the
342   // nodes that use them.
343   bool ShouldMergeIntoUsers(const HloInstruction* instr) const;
344 
345   string DumpSubcomputation(const HloComputation* subcomp,
346                             const HloInstruction* parent_instr);
347   string DumpComputation(const HloComputation* comp);
348   string DumpRootTag();
349   string DumpInstruction(const HloInstruction* instr);
350   ColorScheme GetInstructionColor(const HloInstruction* instr);
351   string GetInstructionNodeShape(const HloInstruction* instr);
352   string GetInstructionNodeLabel(const HloInstruction* instr);
353   string GetInstructionNodeMetadata(const HloInstruction* instr);
354   string GetInstructionNodeBackendConfig(const HloInstruction* instr);
355   string GetInstructionNodeExtraInfo(const HloInstruction* instr);
356   string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
357   void AddInstructionIncomingEdges(const HloInstruction* instr);
358 
359   // For most instructions, GetNodeForEdge(instr) returns instr.
360   //
361   // The exception is fusion nodes.  For these, we walk up the chain of nested
362   // fusion nodes starting at instr until we reach a node that either (a) isn't
363   // a fusion node, or (b) is a fusion node for which
364   // ShouldShowFusionSubcomputation is false.
365   //
366   // We do this because fusion nodes are expanded inline -- if
367   // ShouldShowFusionSubcomputation is true, the fusion node won't be present in
368   // the graph.
369   //
370   // In general when you want to draw an edge from A to B, you should actually
371   // draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B).
372   const HloInstruction* GetNodeForEdge(const HloInstruction* instr);
373 
374   // If instr has just one computation and it's trivial (e.g. "return param0 +
375   // param1"), returns a string you can put into the node's body that names the
376   // subcomputation, e.g. "Subcomputation: <b>add</b>".
377   string GetInstructionTrivialComputationStr(const HloInstruction* instr);
378 
379   const HloComputation* computation_;  // never null
380   const string label_;                 // overall name for the graph
381   const DebugOptions& debug_options_;
382   const bool show_backend_config_;
383   const HloExecutionProfile* profile_;  // may be null
384   const NodeFilter filter_;
385 
386   // Each HloInstruction dumped gets a monotically-increasing node ID.  This
387   // must start at 1, because that's where graphviz's accounting starts.
388   int64 next_node_id_ = 1;
389   absl::flat_hash_map<const HloInstruction*, int64> node_ids_;
390 
391   // The "root" tag doesn't have an associated HloInstruction pointer, so we
392   // need to store it outside the map.
393   int64 root_node_id_;
394 
395   // Each (from, to) edge gets a monotonically-increasing ID.  This is a
396   // multimap because it's possible for the same edge to appear multiple times
397   // in the graph (e.g. x^2 may be represented as mul(x, x)).
398   int64 next_edge_id_ = 1;
399   std::unordered_multimap<
400       std::pair<const HloInstruction*, const HloInstruction*>, int64,
401       tensorflow::hash<std::pair<const HloInstruction*, const HloInstruction*>>>
402       edge_ids_;
403 
404   // Each HloComputation that's emitted gets a monotonically-increasing ID.
405   int64 next_cluster_id_ = 1;
406   absl::flat_hash_map<const HloComputation*, int64> cluster_ids_;
407 
408   // Edges to print from Footer().  Edges come at the end because graphviz is
409   // unhappy if an edge from a subcomputation to a node in the outer computation
410   // appears before both the inner computation and the destination node are
411   // defined.
412   std::vector<string> edges_;
413 
414   // When coloring by sharding information, we track the sharding string
415   // representation to color association, by round-robin the color schemes.
416   absl::flat_hash_map<HloSharding, ColorScheme, HloSharding::Hasher>
417       sharding_colors_;
418   int64 next_shard_color_ = 0;
419 };
420 
Dump()421 string HloDotDumper::Dump() {
422   string body;
423   StrAppend(&body, DumpComputation(computation_));
424   StrAppend(&body, DumpRootTag());
425 
426   // By contract, Header() and Footer() have to be called after we've dumped all
427   // our instructions, because they use state generated during that process.
428   string g = Header();
429   StrAppend(&g, body);
430   StrAppend(&g, Footer());
431   return g;
432 }
433 
Header()434 string HloDotDumper::Header() {
435   constexpr char fmt[] = R"(digraph G {
436 rankdir = TB;
437 compound = true;
438 label = <<b>%s</b>>;
439 labelloc = t;
440 // Disable the tooltip.  Interestingly, "" doesn't work!
441 tooltip = " ";
442 // DOT graphs accept a stylesheet as a URI.  So naturally, an inline
443 // stylesheet is a data URI!
444 stylesheet=<
445   data:text/css,
446   @import url(https://fonts.googleapis.com/css?family=Roboto:400,700);
447   svg text {
448     font-family: 'Roboto';
449     font-size: 12px;
450   }
451 
452   %s
453 >
454 
455 )";
456 
457   VLOG(3) << "Generating Header";
458 
459   string graph_label =
460       StrCat(label_, "<br/>Computation ", computation_->name());
461   if (computation_->IsFusionComputation()) {
462     StrAppend(&graph_label, " (in fusion instruction ",
463               computation_->FusionInstruction()->name(), ")");
464   }
465   if (profile_ != nullptr) {
466     auto cycles = profile_->total_cycles_executed(*computation_);
467     absl::StrAppendFormat(&graph_label, "<br/>total cycles = %d (%s)", cycles,
468                           tensorflow::strings::HumanReadableNum(cycles));
469   }
470 
471   // Create CSS rules that say, when you hover over the given node or cluster,
472   // turn the given edge the given color.
473   //
474   // We rely on a few properties of how graphviz generates SVGs:
475   //
476   //  - Nodes are named "nodeN", where N corresponds to the 1-based index of
477   //    the node in our DOT (i.e. the first node in the DOT is "node1", etc.).
478   //    Edges are similarly named "edgeN", and clusters are named "clustN".
479   //  - Nodes come before their in- and out-edges in the SVG.  We need this
480   //    because the "X ~ Y" CSS selector finds a sibling of X that *comes
481   //    after X in the DOM* and matches Y.
482   std::vector<string> edge_css_rules;
483   const char* kBlue = "#1976d2";
484   const char* kRed = "#d32f2f";
485   for (const auto& kv : edge_ids_) {
486     const HloInstruction* from_node = kv.first.first;
487     const HloInstruction* to_node = kv.first.second;
488     int64 edge_id = kv.second;
489 
490     auto add_hover_css_rule = [&](string elem_type, int64 elem_id,
491                                   const char* color) {
492       // One could imagine other ways of writing this CSS rule that involve
493       // less duplication, but this way seems to be relatively performant.
494       edge_css_rules.push_back(
495           StrFormat("  #%s%d:hover ~ #edge%d text { fill: %s; }\n"
496                     "  #%s%d:hover ~ #edge%d path { "
497                     "stroke: %s; stroke-width: .2em; }\n"
498                     "  #%s%d:hover ~ #edge%d polygon { "
499                     "fill: %s; stroke: %s; stroke-width: .2em; }\n",
500                     elem_type, elem_id, edge_id, color,  //
501                     elem_type, elem_id, edge_id, color,  //
502                     elem_type, elem_id, edge_id, color, color));
503     };
504 
505     // The "to_node" value may be a NULL, indicating that this points to the
506     // "root" tag rather than a normal node.
507     int64 from_node_id =
508         tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1);
509     if (from_node_id == -1) {
510       LOG(FATAL) << from_node->name() << " was added to edges but not to nodes";
511     }
512     int64 to_node_id =
513         to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1)
514                 : root_node_id_;
515     if (to_node != nullptr && to_node_id == -1) {
516       LOG(FATAL) << to_node->name() << " was added to edges but not to nodes";
517     }
518 
519     add_hover_css_rule("node", from_node_id, kBlue);
520     add_hover_css_rule("node", to_node_id, kRed);
521 
522     if (to_node) {
523       VLOG(3) << "Adding css for edge " << edge_id << " from node "
524               << from_node->name() << " to node " << to_node->name();
525     } else {
526       VLOG(3) << "Adding css for edge " << edge_id << " from node "
527               << from_node->name() << " to root tag";
528     }
529 
530     // If this edge crosses a fusion cluster boundary, highlight it when the
531     // cluster is hovered over.
532     if (to_node) {
533       if (from_node->IsFused() &&
534           from_node->parent()->root_instruction() == from_node) {
535         int64 cluster_id = cluster_ids_.at(from_node->parent());
536         add_hover_css_rule("clust", cluster_id, kBlue);
537       }
538       if (to_node->IsFused() && to_node->opcode() == HloOpcode::kParameter) {
539         int64 cluster_id = cluster_ids_.at(to_node->parent());
540         add_hover_css_rule("clust", cluster_id, kRed);
541       }
542     }
543   }
544 
545   // Browsers require that we URI-encode the contents of our data URI.  (It
546   // seems this was a relatively recent change?) In practice, this means that we
547   // need to escape '#'.
548   return StrFormat(
549       fmt, graph_label,
550       absl::StrReplaceAll(StrJoin(edge_css_rules, "\n"), {{"#", "%23"}}));
551 }
552 
Footer()553 string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); }
554 
ShouldShowFusionSubcomputation(const HloInstruction * instr)555 bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
556   CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
557   return ShouldShowSubcomputation(instr->fused_instructions_computation());
558 }
559 
ShouldShowSubcomputation(const HloComputation * subcomp)560 bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
561   if (subcomp->IsFusionComputation()) {
562     const HloInstruction* fusion = subcomp->FusionInstruction();
563     if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) {
564       return false;
565     }
566   }
567 
568   // Don't show trivial subcomputations on non-fusion nodes -- these are inlined
569   // into the graph.
570   if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) {
571     return false;
572   }
573 
574   // Show the subcomputation if we're showing any of its members.
575   return absl::c_any_of(
576       subcomp->instructions(),
577       [&](const HloInstruction* instr) { return filter_.Show(instr); });
578 }
579 
DumpSubcomputation(const HloComputation * subcomp,const HloInstruction * parent_instr)580 string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
581                                         const HloInstruction* parent_instr) {
582   VLOG(2) << "Dumping subcomputation " << subcomp->name();
583   // Add an edge from the subcomputation to its parent node.  If subcomp
584   // belongs to a fusion node, it's drawn in place of the fusion instruction,
585   // so there's no need to link those.
586   if (parent_instr->opcode() != HloOpcode::kFusion) {
587     const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
588     VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
589             << " as " << next_edge_id_;
590     edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
591     constexpr char edge_fmt[] =
592         R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
593     edges_.push_back(StrFormat(
594         edge_fmt, InstructionId(from), InstructionId(parent_instr),
595         SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
596   }
597 
598   // Have we already dumped this subcomputation?  If so, generating the edge
599   // linking it and parent_instr is all we want to do in this function.
600   if (cluster_ids_.find(subcomp) != cluster_ids_.end()) {
601     return "";
602   }
603 
604   cluster_ids_[subcomp] = next_cluster_id_++;
605 
606   string id = SubcomputationId(subcomp);
607 
608   string subcomp_label, style;
609   if (parent_instr->opcode() == HloOpcode::kFusion) {
610     subcomp_label =
611         StrFormat("Fused expression for <b>%s</b><br/>%s",
612                   HtmlLikeStringSanitize(parent_instr->name()),
613                   HtmlLikeStringSanitize(parent_instr->ToCategory()));
614     string extra_info = GetInstructionNodeExtraInfo(parent_instr);
615     if (!extra_info.empty()) {
616       StrAppend(&subcomp_label, "<br/>", extra_info);
617     }
618     string node_backend_config = GetInstructionNodeBackendConfig(parent_instr);
619     if (!node_backend_config.empty()) {
620       StrAppend(&subcomp_label, "<br/>", node_backend_config);
621     }
622 
623     bool highlight = filter_.Highlight(parent_instr);
624     const char* fillcolor;
625     const char* strokecolor;
626     if (debug_options_.xla_hlo_graph_sharding_color() && !highlight) {
627       // Use the sharding color, if the node isn't highlighted.
628       NodeColors node_colors =
629           NodeColorsForScheme(GetInstructionColor(parent_instr));
630       fillcolor = node_colors.fill_color;
631       strokecolor = node_colors.stroke_color;
632     } else {
633       // Subcomputation's fill/stroke color is light/dark red/gray, depending on
634       // whether or not the subcomputation's fusion node is highlighted.
635       fillcolor = highlight ? "#ffcdd2" : "#f5f5f5";
636       strokecolor = highlight ? "#b71c1c" : "#c2c2c2";
637     }
638     style =
639         StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")",
640                   fillcolor, strokecolor);
641   } else {
642     subcomp_label = StrFormat("Subcomputation for <b>%s</b><br/>%s",
643                               HtmlLikeStringSanitize(parent_instr->name()),
644                               HtmlLikeStringSanitize(subcomp->name()));
645     style = "style=rounded; color=black;";
646   }
647 
648   string comp_body = DumpComputation(subcomp);
649 
650   constexpr char computation_fmt[] = R"(subgraph %s {
651 %s
652 label = <%s>;
653 labelloc = t;
654 tooltip = " ";
655 %s
656 }  // %s
657 
658 )";
659   return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id);
660 }
661 
DumpComputation(const HloComputation * comp)662 string HloDotDumper::DumpComputation(const HloComputation* comp) {
663   string g;
664   for (const auto* instr : comp->instructions()) {
665     if (!filter_.Show(instr)) {
666       continue;
667     }
668 
669     // Dump subcomputations within instr.
670     for (const HloComputation* subcomp : instr->called_computations()) {
671       if (ShouldShowSubcomputation(subcomp)) {
672         StrAppend(&g, DumpSubcomputation(subcomp, instr));
673       }
674     }
675 
676     StrAppend(&g, DumpInstruction(instr));
677   }
678   return g;
679 }
680 
DumpRootTag()681 string HloDotDumper::DumpRootTag() {
682   const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());
683 
684   // We didn't display constants as separate nodes; so if the root is a
685   // constant, we don't add root tag or edge for it.
686   if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) {
687     return "";
688   }
689 
690   auto from_id = InstructionId(from);
691 
692   // The ID of the root computation is otherwise unused, so it makes a good ID
693   // to use for the root-tag node.  However, the edge_ids_ map requires a
694   // HloInstruction* pointer for the 'to' value, so we use a NULL value there
695   // (rather than a pointer type-cast) to make it obvious if it is erroneously
696   // dereferenced.
697   HloInstruction* to = nullptr;
698   auto to_id = SubcomputationId(computation_);
699 
700   string node_body = "ROOT";
701   string node_shape = "circle";
702   ColorScheme color = kBrown;
703 
704   VLOG(2) << "Adding root tag as node " << next_node_id_;
705   root_node_id_ = next_node_id_++;
706 
707   VLOG(2) << "Adding edge from " << from->name() << " to root tag as "
708           << next_edge_id_;
709   edge_ids_.insert({{from, to}, next_edge_id_++});
710   edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id));
711 
712   return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)"
713                    "\n",
714                    to_id, node_body, node_shape, NodeColorAttributes(color));
715 }
716 
TryGetFusionParameterConstant(const HloInstruction * instr)717 static const HloConstantInstruction* TryGetFusionParameterConstant(
718     const HloInstruction* instr) {
719   if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
720     return nullptr;
721   }
722   const HloInstruction* fusion = instr->parent()->FusionInstruction();
723   const HloInstruction* operand = fusion->operand(instr->parameter_number());
724   return DynCast<HloConstantInstruction>(operand);
725 }
726 
ShouldMergeIntoUsers(const HloInstruction * instr) const727 bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
728   // If a node:
729   //
730   //  - is a parameter of a fusion node which is bound to a constant,
731   //
732   // or
733   //
734   //  - is a tuple-shaped parameter, and
735   //  - is not a parameter to a fusion node, and
736   //  - has at least kMinUsersToOmit users shown, and
737   //  - all of the shown users are get-tuple-elements,
738   //
739   // then we omit it from the graph, merging it with its users.
740   //
741   // This helps us handle the common case where a while loop body has one big
742   // tuple-shaped parameter.
743   if (TryGetFusionParameterConstant(instr) != nullptr) {
744     return true;
745   }
746   const int kMinUsersToOmit = 3;
747   return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() &&
748          !instr->IsFused() &&
749          absl::c_count_if(instr->users(),
750                           [&](const HloInstruction* user) {
751                             return filter_.Show(user);
752                           }) > kMinUsersToOmit &&
753          absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
754            return !filter_.Show(user) ||
755                   user->opcode() == HloOpcode::kGetTupleElement;
756          });
757 }
758 
DumpInstruction(const HloInstruction * instr)759 string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
760   // We don't display constants as separate nodes; they're merged into their
761   // users.
762   if (instr->opcode() == HloOpcode::kConstant) {
763     return "";
764   }
765   // Skip this node if it's merged into its users.
766   if (ShouldMergeIntoUsers(instr)) {
767     return "";
768   }
769   // Omit the fusion node if its subcomputation is drawn, since the
770   // subcomputation will be drawn inline.
771   if (instr->opcode() == HloOpcode::kFusion &&
772       ShouldShowFusionSubcomputation(instr)) {
773     return "";
774   }
775 
776   VLOG(2) << "Adding node " << instr->name() << " as " << next_node_id_;
777   node_ids_[instr] = next_node_id_++;
778 
779   ColorScheme color = GetInstructionColor(instr);
780   string node_shape = GetInstructionNodeShape(instr);
781   string node_label = GetInstructionNodeLabel(instr);
782   string node_metadata = GetInstructionNodeMetadata(instr);
783   string node_backend_config = GetInstructionNodeBackendConfig(instr);
784   string extra_info = GetInstructionNodeExtraInfo(instr);
785   string inlined_constants = GetInstructionNodeInlinedOperands(instr);
786   string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
787   AddInstructionIncomingEdges(instr);
788 
789   if (!debug_options_.xla_hlo_graph_sharding_color()) {
790     // Override the node's styling if it should be (de-)emphasized.
791     if (filter_.Deemphasized(instr)) {
792       color = kDashedBorder;
793     }
794     if (filter_.Highlight(instr)) {
795       node_shape = "diamond";
796       color = kDarkRed;
797     }
798   }
799   // Build the text that will be displayed inside the node.
800   string node_body = node_label;
801   for (const string& s : {trivial_subcomputation, node_backend_config,
802                           extra_info, inlined_constants}) {
803     if (!s.empty()) {
804       StrAppend(&node_body, "<br/>", s);
805     }
806   }
807 
808   return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)"
809                    "\n",
810                    InstructionId(instr), node_body, node_shape, node_metadata,
811                    NodeColorAttributes(color));
812 }
813 
GetInstructionNodeInlinedOperands(const HloInstruction * instr)814 string HloDotDumper::GetInstructionNodeInlinedOperands(
815     const HloInstruction* instr) {
816   auto stringify_constant = [](const HloConstantInstruction* constant) {
817     const auto& shape = constant->shape();
818 
819     // If the shape has a dimension of size zero, print it as e.g.
820     // "{} (f32[42, 0, 10])".  The alternative, calling Literal::ToString(),
821     // enumerates all of its empty dimensions (e.g.  "{ { {}, {} }, ..."), which
822     // is just noise.
823     if (ShapeUtil::IsZeroElementArray(shape)) {
824       return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape()));
825     }
826 
827     // Print the literal value of constants with <= K elements.
828     optional<int64> elem_count;
829     if (shape.IsArray()) {
830       elem_count = 1;
831       for (int64 dim : shape.dimensions()) {
832         *elem_count *= dim;
833       }
834     }
835     // Allow HloDotDumper to print HloInstruction reconstructed from HloProto
836     // collected from profiling tools. Those constants may not have a valid
837     // literal.
838     if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
839       return constant->literal().ToString();
840     }
841 
842     // Otherwise, print e.g. "%constant.42 (s32[100])".
843     string constant_name;
844     if (absl::StartsWith(constant->name(), "constant")) {
845       constant_name = constant->name();
846     } else {
847       constant_name = StrCat("constant ", constant->name());
848     }
849     return StrFormat("%s %s", constant_name,
850                      ShapeUtil::HumanString(constant->shape()));
851   };
852 
853   std::vector<string> lines;
854   for (int64 i = 0; i < instr->operand_count(); ++i) {
855     const HloInstruction* operand = instr->operand(i);
856     const auto* constant_operand = DynCast<HloConstantInstruction>(operand);
857     optional<string> operand_str;
858     if (constant_operand != nullptr) {
859       operand_str = stringify_constant(constant_operand);
860     } else if (ShouldMergeIntoUsers(operand)) {
861       // Special case: If the operand is a parameter to a fusion node and it
862       // always has a constant value, display it like a regular constant.
863       //
864       // For other parameters, use the parameter number rather than the proper
865       // name, because that's generally how people think of the node.
866       if (operand->opcode() == HloOpcode::kParameter) {
867         if (const HloConstantInstruction* constant =
868                 TryGetFusionParameterConstant(operand)) {
869           operand_str = stringify_constant(constant);
870         } else {
871           operand_str = StrFormat("Parameter %d", operand->parameter_number());
872         }
873       } else {
874         operand_str = operand->name();
875       }
876     }
877 
878     if (operand_str) {
879       if (instr->operand_count() > 1) {
880         lines.push_back(StrFormat("<b>operand %d</b> = %s", i, *operand_str));
881       } else {
882         lines.push_back(StrFormat("<b>operand</b> = %s", *operand_str));
883       }
884     }
885   }
886   return StrJoin(lines, "<br/>");
887 }
888 
GetInstructionColor(const HloInstruction * instr)889 ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
890   if (debug_options_.xla_hlo_graph_sharding_color()) {
891     if (!instr->has_sharding()) {
892       return kDashedBorder;
893     }
894     auto it = sharding_colors_.find(instr->sharding());
895     if (it != sharding_colors_.end()) {
896       return it->second;
897     }
898     ColorScheme color = static_cast<ColorScheme>(
899         kBlue + (next_shard_color_++ % (kDashedBorder - kBlue)));
900     sharding_colors_.emplace(instr->sharding(), color);
901     return color;
902   }
903 
904   // Choose different weights of orange for small vs large parameters.  This
905   // distinction is often important, especially in fusion nodes.
906   auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange;
907 
908   // Special case: If this instruction has a parameter merged into it, paint it
909   // the same color as a parameter.  Unless the merged-in parameter is a
910   // parameter to a fusion node that is bound to a constant -- these aren't
911   // "real" parameters from the user's perspective.
912   if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) {
913         return operand->opcode() == HloOpcode::kParameter &&
914                ShouldMergeIntoUsers(operand) &&
915                TryGetFusionParameterConstant(operand) == nullptr;
916       })) {
917     return parameter_color;
918   }
919 
920   // Pick different colors or shapes for instructions which are particularly
921   // expensive (eg, dot) and those which are unusual in some way or unique
922   // (eg, parameter).
923   switch (instr->opcode()) {
924     case HloOpcode::kAbs:
925     case HloOpcode::kAdd:
926     case HloOpcode::kAnd:
927     case HloOpcode::kAtan2:
928     case HloOpcode::kBitcastConvert:
929     case HloOpcode::kCeil:
930     case HloOpcode::kClamp:
931     case HloOpcode::kClz:
932     case HloOpcode::kCompare:
933     case HloOpcode::kComplex:
934     case HloOpcode::kConvert:
935     case HloOpcode::kCos:
936     case HloOpcode::kDivide:
937     case HloOpcode::kExp:
938     case HloOpcode::kExpm1:
939     case HloOpcode::kFloor:
940     case HloOpcode::kImag:
941     case HloOpcode::kIota:
942     case HloOpcode::kIsFinite:
943     case HloOpcode::kLog:
944     case HloOpcode::kLog1p:
945     case HloOpcode::kMaximum:
946     case HloOpcode::kMinimum:
947     case HloOpcode::kMultiply:
948     case HloOpcode::kNegate:
949     case HloOpcode::kNot:
950     case HloOpcode::kOr:
951     case HloOpcode::kXor:
952     case HloOpcode::kPower:
953     case HloOpcode::kReal:
954     case HloOpcode::kRemainder:
955     case HloOpcode::kRng:
956     case HloOpcode::kRoundNearestAfz:
957     case HloOpcode::kRsqrt:
958     case HloOpcode::kSelect:
959     case HloOpcode::kShiftLeft:
960     case HloOpcode::kShiftRightArithmetic:
961     case HloOpcode::kShiftRightLogical:
962     case HloOpcode::kSign:
963     case HloOpcode::kSin:
964     case HloOpcode::kSlice:
965     case HloOpcode::kSort:
966     case HloOpcode::kSqrt:
967     case HloOpcode::kSubtract:
968     case HloOpcode::kTanh:
969       // De-emphasize scalar-shaped elementwise ops -- they're generally
970       // uninteresting.
971       if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
972         return kWhite;
973       }
974       return kYellow;
975     case HloOpcode::kBitcast:
976     case HloOpcode::kGetTupleElement:
977     case HloOpcode::kTrace:
978     case HloOpcode::kAfterAll:
979     case HloOpcode::kAddDependency:
980     case HloOpcode::kTuple:
981       return kWhite;
982     case HloOpcode::kBroadcast:
983       // De-emphasize nodes which broadcast a scalar within a fusion node --
984       // these are essentially free.
985       if (instr->IsFused() &&
986           ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) {
987         return kWhite;
988       }
989       return kGreen;
990     case HloOpcode::kConcatenate:
991     case HloOpcode::kDynamicSlice:
992     case HloOpcode::kGather:
993     case HloOpcode::kPad:
994     case HloOpcode::kReshape:
995     case HloOpcode::kReverse:
996     case HloOpcode::kTupleSelect:
997     case HloOpcode::kTranspose:
998       // De-emphasize scalar-shaped data movement ops and all data movement ops
999       // inside fusion nodes, both of which are essentially free.
1000       if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) {
1001         return kWhite;
1002       }
1003       return kGreen;
1004     case HloOpcode::kDynamicUpdateSlice:
1005       // Unlike the data-movement ops above, dynamic-update-slice is not ~free
1006       // inside of fusion nodes, so we de-emphasize it only if it's
1007       // scalar-shaped.
1008       if (ShapeUtil::IsEffectiveScalar(instr->shape())) {
1009         return kWhite;
1010       }
1011       return kGreen;
1012     case HloOpcode::kScatter:
1013       // Do not de-emphasize Scatter, since it involves significant work.
1014     case HloOpcode::kCopy:
1015       // Emphasize copy nodes, which are either physical transposes (and thus
1016       // significant), or copies of read-only buffers (and thus dead weight).
1017       return kGreen;
1018     case HloOpcode::kConvolution:
1019     case HloOpcode::kDot:
1020     case HloOpcode::kFft:
1021     case HloOpcode::kTriangularSolve:
1022     case HloOpcode::kCholesky:
1023       return kDarkBlue;
1024     case HloOpcode::kReducePrecision:
1025       return kRed;
1026     case HloOpcode::kParameter:
1027       return parameter_color;
1028     case HloOpcode::kBatchNormGrad:
1029     case HloOpcode::kBatchNormInference:
1030     case HloOpcode::kBatchNormTraining:
1031     case HloOpcode::kReduce:
1032     case HloOpcode::kReduceWindow:
1033     case HloOpcode::kSelectAndScatter:
1034       return kPurple;
1035     case HloOpcode::kDomain:
1036     case HloOpcode::kFusion:
1037     case HloOpcode::kMap:
1038     case HloOpcode::kGetDimensionSize:
1039       return kGray;
1040     case HloOpcode::kAllReduce:
1041     case HloOpcode::kAllToAll:
1042     case HloOpcode::kCollectivePermute:
1043     case HloOpcode::kInfeed:
1044     case HloOpcode::kOutfeed:
1045     case HloOpcode::kRecv:
1046     case HloOpcode::kRecvDone:
1047     case HloOpcode::kSend:
1048     case HloOpcode::kSendDone:
1049     case HloOpcode::kReplicaId:
1050       return kBrown;
1051     case HloOpcode::kCall:
1052     case HloOpcode::kConditional:
1053     case HloOpcode::kCustomCall:
1054     case HloOpcode::kWhile:
1055       return kDarkGreen;
1056     case HloOpcode::kConstant:
1057       LOG(FATAL) << "Constants don't get their own nodes in the graph.";
1058   }
1059 }
1060 
GetInstructionNodeShape(const HloInstruction * instr)1061 string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) {
1062   // Give while loops a different shape so they're easier to pick out.
1063   switch (instr->opcode()) {
1064     case HloOpcode::kWhile:
1065       return "ellipse";
1066     default:
1067       return "rect";
1068   }
1069 }
1070 
GetInstructionNodeLabel(const HloInstruction * instr)1071 string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) {
1072   // If we have a parameter, put the param number in the name.
1073   if (instr->opcode() == HloOpcode::kParameter) {
1074     return StrFormat("<b>Parameter %d</b>", instr->parameter_number());
1075   }
1076 
1077   // The HLO instruction name contains usually the opcode, e.g. "%add.42" is
1078   // an add instruction.  In this case we render just the name.
1079   if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) {
1080     return StrFormat("<b>%s</b>", HtmlLikeStringSanitize(instr->name()));
1081   }
1082   string extended_opcode =
1083       StrCat(HloOpcodeString(instr->opcode()),
1084              instr->opcode() != HloOpcode::kFusion
1085                  ? ""
1086                  : StrCat(":", xla::ToString(instr->fusion_kind())));
1087   // If the name does not contain the opcode, render both.
1088   return StrFormat("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode),
1089                    HtmlLikeStringSanitize(instr->name()));
1090 }
1091 
GetInstructionNodeMetadata(const HloInstruction * instr)1092 string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
1093   std::vector<string> lines;
1094   if (!instr->metadata().op_name().empty()) {
1095     lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name()));
1096   }
1097   if (!instr->metadata().op_type().empty()) {
1098     lines.push_back(StrFormat(
1099         "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type())));
1100   }
1101   if (!instr->metadata().source_file().empty() &&
1102       instr->metadata().source_line() != 0) {
1103     lines.push_back(StrFormat("op_type: %s:%d", instr->metadata().source_file(),
1104                               instr->metadata().source_line()));
1105   }
1106 
1107   return StrJoin(lines, "\n");
1108 }
1109 
GetInstructionNodeBackendConfig(const HloInstruction * instr)1110 string HloDotDumper::GetInstructionNodeBackendConfig(
1111     const HloInstruction* instr) {
1112   if (!show_backend_config_ || instr->raw_backend_config_string().empty()) {
1113     return "";
1114   }
1115 
1116   return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\"");
1117 }
1118 
GetInstructionNodeExtraInfo(const HloInstruction * instr)1119 string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
1120   std::vector<string> lines;
1121 
1122   // Get the instruction's extra attributes excluding the names of its
1123   // subcomputations, since those are drawn explicitly in the graph.
1124   for (const auto& line : instr->ExtraAttributesToString(
1125            HloPrintOptions().set_print_subcomputation_mode(
1126                HloPrintOptions::PrintSubcomputationMode::kOff))) {
1127     lines.push_back(HtmlLikeStringSanitize(line));
1128   }
1129 
1130   // Show the shape and layout of the instruction, unless it's an inlined fusion
1131   // node -- there the shape and layout is present in the output node.
1132   if (instr->opcode() != HloOpcode::kFusion ||
1133       !ShouldShowFusionSubcomputation(instr)) {
1134     // Show layout of instructions with more than one dimension.  Don't show
1135     // layout on tuples or tensors with just one dimension (which only have one
1136     // possible layout) to avoid visual noise.
1137     bool shape_is_multidim = false;
1138     ShapeUtil::ForEachSubshape(instr->shape(),
1139                                [&](const Shape& s, const ShapeIndex&) {
1140                                  shape_is_multidim |= s.dimensions_size() > 1;
1141                                });
1142     string instr_shape;
1143     if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) {
1144       instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape());
1145     } else {
1146       instr_shape = ShapeUtil::HumanString(instr->shape());
1147     }
1148 
1149     // Some instructions have giant tuples as their shapes, so truncate the
1150     // HLO's shape to kMaxShapeLen characters.
1151     constexpr int kMaxShapeLen = 64;
1152     if (instr_shape.length() > kMaxShapeLen) {
1153       instr_shape = StrCat(
1154           absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "...");
1155     }
1156     lines.push_back(instr_shape);
1157   }
1158   if (debug_options_.xla_hlo_graph_addresses()) {
1159     lines.push_back(StrFormat("[%p]", instr));
1160   }
1161   if (profile_ != nullptr) {
1162     double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr);
1163     double total_cycles_executed =
1164         profile_->total_cycles_executed(*instr->parent());
1165     if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
1166       lines.push_back(
1167           StrFormat("%% of cycles executed=%.2f",
1168                     100 * hlo_cycles_executed / total_cycles_executed));
1169     }
1170   }
1171   return StrJoin(lines, "<br/>");
1172 }
1173 
AddInstructionIncomingEdges(const HloInstruction * instr)1174 void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
1175   auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
1176                       int64 operand_num, bool control_edge = false) {
1177     from = GetNodeForEdge(from);
1178 
1179     if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
1180         ShouldMergeIntoUsers(from)) {
1181       return;
1182     }
1183     VLOG(2) << "Adding edge from " << from->name() << " to " << to->name()
1184             << " as " << next_edge_id_;
1185     edge_ids_.insert({{from, to}, next_edge_id_++});
1186 
1187     string edge_label;
1188     if (instr->operand_count() > 1 && !control_edge) {
1189       edge_label =
1190           StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num);
1191     } else if (control_edge) {
1192       edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\"";
1193     }
1194 
1195     // We print "small" arrays using a hollow arrowhead and "large" arrays using
1196     // a filled arrowhead.
1197     constexpr char kEdgeFmt[] =
1198         R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
1199     edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to),
1200                                (IsSmall(from) ? "empty" : "normal"),
1201                                from->name(), to->name(), edge_label));
1202   };
1203 
1204   // Add edges from instr's operands to instr.  Parameters within fusion
1205   // expressions are handled specially -- we draw an edge from the corresponding
1206   // operand on the fusion node itself to the parameter.
1207   if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
1208     // Only add the edge if this is not the outermost computation; otherwise it
1209     // will lead from a node we're not drawing.
1210     if (instr->parent() != computation_) {
1211       const HloInstruction* fusion = instr->parent()->FusionInstruction();
1212       add_edge(fusion->operand(instr->parameter_number()), instr,
1213                /*operand_num=*/0);
1214     }
1215   } else {
1216     for (int64 i = 0; i < instr->operand_count(); ++i) {
1217       add_edge(instr->operand(i), instr, i);
1218     }
1219     for (const HloInstruction* pred : instr->control_predecessors()) {
1220       add_edge(pred, instr, /*operand_num=*/0, /*control_edge=*/true);
1221     }
1222   }
1223 }
1224 
GetInstructionTrivialComputationStr(const HloInstruction * instr)1225 string HloDotDumper::GetInstructionTrivialComputationStr(
1226     const HloInstruction* instr) {
1227   // called_computations() on a fusion node "inherits" any called computations
1228   // of the fused root, which isn't what we want.  Just ignore fusion nodes
1229   // here; they're handled separately.
1230   if (instr->opcode() == HloOpcode::kFusion) {
1231     return "";
1232   }
1233 
1234   std::vector<string> lines;
1235   for (int64 i = 0; i < instr->called_computations().size(); ++i) {
1236     optional<string> computation_type =
1237         MatchTrivialComputation(instr->called_computations()[i]);
1238     if (!computation_type) {
1239       continue;
1240     }
1241     if (instr->called_computations().size() == 1) {
1242       lines.push_back(StrFormat("Subcomputation: <b>%s</b>",
1243                                 HtmlLikeStringSanitize(*computation_type)));
1244     } else {
1245       lines.push_back(StrFormat("Subcomputation %d: <b>%s</b>", i,
1246                                 HtmlLikeStringSanitize(*computation_type)));
1247     }
1248   }
1249   return StrJoin(lines, "<br/>");
1250 }
1251 
// Returns the instruction whose node an edge touching `instr` should attach
// to.  A fusion instruction whose subcomputation is drawn is replaced by its
// fused expression root, repeatedly, until that no longer applies.
const HloInstruction* HloDotDumper::GetNodeForEdge(
    const HloInstruction* instr) {
  const HloInstruction* node = instr;
  for (;;) {
    const bool is_inlined_fusion = node->opcode() == HloOpcode::kFusion &&
                                   ShouldShowFusionSubcomputation(node);
    if (!is_inlined_fusion) {
      return node;
    }
    node = node->fused_expression_root();
  }
}
1260 
1261 // Gets a NodeFilter that includes roughly all instructions whose distance from
1262 // root is <= radius.
MakeNodeRadiusAroundFilter(const HloInstruction * root,int64 radius,const absl::flat_hash_set<const HloInstruction * > & boundary)1263 NodeFilter MakeNodeRadiusAroundFilter(
1264     const HloInstruction* root, int64 radius,
1265     const absl::flat_hash_set<const HloInstruction*>& boundary) {
1266   // First, find the neighborhood of nodes with distance from root <= radius.
1267   // These nodes are our initial set of "normal" nodes.
1268   absl::flat_hash_map<const HloInstruction*, NodeFilterResult> nodes;
1269   std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
1270   worklist.push_back({root, 0});
1271   while (!worklist.empty()) {
1272     const HloInstruction* instr;
1273     int64 depth;
1274     std::tie(instr, depth) = worklist.front();
1275     worklist.pop_front();
1276 
1277     nodes[instr] = kNormalNode;
1278     if (depth == radius) {
1279       continue;
1280     }
1281     if (boundary.contains(instr)) {
1282       continue;
1283     }
1284 
1285     // Traverse into instr's operands.
1286     //
1287     // Don't traverse into tuples' operands unless the tuple is the root.
1288     // Usually a tuple is the bottommost node in the graph, and so its operands
1289     // are not interesting to the graph at hand.
1290     if (instr == root || instr->opcode() != HloOpcode::kTuple) {
1291       for (const HloInstruction* operand : instr->operands()) {
1292         if (!nodes.contains(operand)) {
1293           worklist.push_back({operand, depth + 1});
1294         }
1295       }
1296     }
1297 
1298     // Traverse into instr's nested computations.
1299     for (const HloComputation* computation : instr->called_computations()) {
1300       worklist.push_back({computation->root_instruction(), depth + 1});
1301     }
1302 
1303     // Traverse into instr's users, unless:
1304     //
1305     //  - there are a ton of them, in which case they're probably not
1306     //    interesting (and anyway, rendering them all would make the graph
1307     //    unreadable), or
1308     //  - instr is a constant, in which case its users are probably not
1309     //    interesting.
1310     if (instr->opcode() == HloOpcode::kConstant) {
1311       continue;
1312     }
1313     constexpr int kMaxUsersToRender = 16;
1314     if (instr->user_count() > kMaxUsersToRender) {
1315       // If we're going to skip this node's users, style it as such.
1316       nodes[instr] = kSomeUsersOmitted;
1317       continue;
1318     }
1319     for (const HloInstruction* user : instr->users()) {
1320       if (!nodes.contains(user)) {
1321         worklist.push_back({user, depth + 1});
1322       }
1323     }
1324   }
1325 
1326   auto is_displayed = [&](const HloInstruction* instr) {
1327     // Constants are displayed inline with their users; they're never omitted.
1328     // Nodes in subcomputations are always shown.
1329     return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant ||
1330            instr->parent() != root->parent();
1331   };
1332 
1333   // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we
1334   // know which nodes will be included in the graph.
1335   for (auto& kv : nodes) {
1336     const HloInstruction* instr = kv.first;
1337     NodeFilterResult& filter_result = kv.second;
1338     const auto& operands = instr->operands();
1339 
1340     if (absl::c_any_of(operands, is_displayed) &&
1341         !absl::c_all_of(operands, is_displayed)) {
1342       // Mark nodes with some operands omitted appropriately.
1343       filter_result = kSomeOperandsOmitted;
1344     } else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) {
1345       // Mark nodes with *all* operands omitted appropriately.
1346       filter_result = kOmitNodeOperands;
1347     }
1348 
1349     // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
1350     // users made it into the graph.
1351     if (filter_result == kSomeUsersOmitted &&
1352         absl::c_all_of(instr->users(), is_displayed)) {
1353       filter_result = kNormalNode;
1354     }
1355   }
1356 
1357   // Highlight the root node.
1358   nodes[root] = kHighlightNode;
1359 
1360   return NodeFilter([=](const HloInstruction* instr) {
1361     auto it = nodes.find(instr);
1362     if (it != nodes.end()) {
1363       return it->second;
1364     }
1365     // Show all nodes in subcomputations.
1366     if (instr->parent() != root->parent()) {
1367       return kNormalNode;
1368     }
1369     return kHideNode;
1370   });
1371 }
1372 
1373 // Gets a node filter that includes nodes on all paths from `from` to `to`.  If
1374 // the all-paths set contains more than max_nodes elements, includes the nodes
1375 // on the shortest paths and sets hit_limit to true.
MakeNodeFromToFilter(const HloInstruction * from,const HloInstruction * to,int64 max_nodes,bool * hit_limit)1376 NodeFilter MakeNodeFromToFilter(const HloInstruction* from,
1377                                 const HloInstruction* to, int64 max_nodes,
1378                                 bool* hit_limit) {
1379   *hit_limit = false;
1380 
1381   // Elements in the queue are paths through the graph.
1382   std::deque<std::vector<const HloInstruction*>> queue;
1383   queue.push_front({from});
1384 
1385   // Compute the set of nodes we want to show using a slightly-modified
1386   // Djikstra's algorithm.  The only real difference is, rather than stopping
1387   // when we find a (shortest) path, we continue until we've found max_nodes
1388   // nodes on some path.
1389   std::unordered_set<const HloInstruction*> visited;
1390   std::unordered_set<const HloInstruction*> to_display = {from, to};
1391   while (!queue.empty() && to_display.size() < max_nodes) {
1392     std::vector<const HloInstruction*> path = std::move(queue.front());
1393     queue.pop_front();
1394     if (!visited.insert(path.back()).second) {
1395       continue;
1396     }
1397 
1398     for (const auto* user : path.back()->users()) {
1399       if (user == to) {
1400         auto it = path.begin();
1401         for (; it != path.end() && to_display.size() < max_nodes; ++it) {
1402           to_display.insert(*it);
1403         }
1404         if (it != path.end()) {
1405           *hit_limit = true;
1406         }
1407       } else if (!visited.count(user)) {
1408         auto new_path = path;
1409         new_path.push_back(user);
1410         queue.push_back(std::move(new_path));
1411       }
1412     }
1413   }
1414 
1415   return NodeFilter([=](const HloInstruction* instr) {
1416     if (instr == from || instr == to) {
1417       return kHighlightNode;
1418     }
1419     return to_display.count(instr) ? kNormalNode : kHideNode;
1420   });
1421 }
1422 
WrapDotInHtml(absl::string_view dot)1423 string WrapDotInHtml(absl::string_view dot) {
1424   static const char html_prefix[] = R"html(
1425 <!DOCTYPE html>
1426 <html>
1427 <head>
1428   <meta charset="utf-8">
1429   <style type="text/css">
1430     body {
1431       height: 100vh;
1432       margin: 0;
1433     }
1434   </style>
1435 </head>
1436 <body>
1437   <!-- Integrity hash is generated by https://www.srihash.org/ -->
1438   <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/viz.js"
1439      integrity="sha384-aD1MJYb0WKIUT+CtwJp5LTuV3U4pLAS6B/nUxL7ECimC2pN9N8vjlMr/yQCAkzxE"
1440      crossorigin="anonymous"></script>
1441   <script src="https://cdn.jsdelivr.net/npm/viz.js@2.1.1/full.render.js"
1442      integrity="sha384-bAixY275aIpCj6Te19y0MILZ4V+VEC8CVFujFEH+Lf7W+4XYYeYLwW5IBI6yQmMT"
1443      crossorigin="anonymous"></script>
1444   <script src="https://cdn.jsdelivr.net/npm/svg-pan-zoom@3.6.0/dist/svg-pan-zoom.min.js"
1445      integrity="sha384-3008WpYB2pOBvE7lwkrKf+qTmbTPGGPYxA9C1YVhvbPukns4ZFj7E98QPLkNW9dS"
1446      crossorigin="anonymous"></script>
1447   <div id="container" style="height:95vh; border:1px solid black; "></div>
1448   <script>
1449     var data = `
1450 )html";
1451 
1452   static const char html_suffix[] = R"html(
1453 `;
1454     var cssregex = new RegExp('stylesheet=<([^]*)\n>\n', 'gm');
1455     var results = cssregex.exec(data)
1456     // graphviz has problem dealing with large stylesheets.
1457     // https://github.com/tensorflow/tensorflow/issues/17220#issuecomment-369228492
1458     // In order to avoid the problem, remove the stylesheet from the dot and
1459     // insert it directly info the rendered SVG.
1460     var dot_data = data;
1461     var css_data = ''
1462     if (results !== null) {
1463         css_data = results[1].replace(/\s*data:.*\s*,/,''); // Strip content-type field.
1464         // CSS inside DOT is URL-escaped, so we must unescape it
1465         // before we can insert it into SVG.
1466         css_data = unescape(css_data);
1467         dot_data = data.replace(cssregex, ''); // Remove the stylesheet
1468     }
1469 
1470     var render_start = performance.now()
1471     function add_controls(svg) {
1472         var htmlblob = new Blob([document.documentElement.innerHTML],
1473                                 {type: 'text/html'});
1474         var savehtml = document.createElement('a');
1475         savehtml.setAttribute('href', URL.createObjectURL(htmlblob));
1476         savehtml.setAttribute('download', 'graph.html');
1477         savehtml.innerHTML = " [Save HTML+SVG] ";
1478         document.body.append(savehtml);
1479         var svgblob = new Blob([svg.outerHTML], {type: 'image/svg'});
1480         var savesvg = document.createElement('a');
1481         savesvg.setAttribute('href', URL.createObjectURL(svgblob));
1482         savesvg.setAttribute('download', 'graph.svg');
1483         savesvg.innerHTML = " [Save SVG] ";
1484         document.body.append(savesvg);
1485         var dotblob =  new Blob([data], {type: 'text/dot'});
1486         var savedot = document.createElement('a');
1487         savedot.setAttribute('href', URL.createObjectURL(dotblob));
1488         savedot.setAttribute('download', 'graph.dot');
1489         savedot.innerHTML = " [Save DOT] ";
1490         document.body.append(savedot);
1491         // Will get called after embed element was loaded
1492         var panzoom = svgPanZoom(svg, {
1493             zoomEnabled: true,
1494             controlIconsEnabled: true,
1495         });
1496         document.getElementsByTagName("BODY")[0].onresize = function() {
1497             panzoom.resize();
1498             panzoom.fit();
1499             panzoom.center();
1500         };
1501         var render_end = performance.now();
1502         var render_note = document.createElement('div')
1503         render_note.innerHTML = 'Rendering took '
1504                                 + (render_end - render_start).toFixed(2) + "ms."
1505         document.body.append(render_note);
1506     }
1507     var svg = document.getElementById('graph')
1508     if (svg == null) {
1509         // Need to render SVG first.
1510         var viz = new Viz();
1511         viz.renderSVGElement(dot_data)
1512             .then(function(svg){
1513                 var container = document.getElementById('container')
1514                 var style = document.createElementNS('http://www.w3.org/2000/svg', 'style');
1515                 var node = document.createTextNode(css_data);
1516                 style.appendChild(node);
1517                 svg.setAttribute('width', '100%');
1518                 svg.setAttribute('height', '100%');
1519                 svg.setAttribute('id', 'graph');
1520                 svg.appendChild(style);
1521                 container.appendChild(svg);
1522                 add_controls(svg);
1523             })
1524     } else {
1525         // HTML already has rendered SVG embedded, so we just need to add
1526         // controls.
1527         add_controls(svg);
1528     }
1529   </script>
1530 </body>
1531 </html>
1532 )html";
1533 
1534   return absl::StrCat(html_prefix, dot, html_suffix);
1535 }
1536 
1537 tensorflow::mutex url_renderer_mu(tensorflow::LINKER_INITIALIZED);
1538 std::function<StatusOr<string>(absl::string_view)>* url_renderer
1539     GUARDED_BY(url_renderer_mu) = nullptr;
1540 
1541 // Precondition: url_renderer != nullptr.
1542 //
1543 // (We specify this as a precondition rather than checking it in here and
1544 // returning an error because we want to fail quickly when there's no URL
1545 // renderer available, and this function runs only after we've done all the work
1546 // of producing dot for the graph.)
1547 StatusOr<string> WrapDotInFormat(absl::string_view dot,
1548                                  RenderedGraphFormat format)
1549     EXCLUSIVE_LOCKS_REQUIRED(url_renderer_mu) {
1550   switch (format) {
1551     case RenderedGraphFormat::kUrl:
1552       CHECK(url_renderer != nullptr)
1553           << "Should have checked url_renderer != null before calling.";
1554       return (*url_renderer)(dot);
1555     case RenderedGraphFormat::kHtml:
1556       return WrapDotInHtml(dot);
1557     case RenderedGraphFormat::kDot:
1558       return string(dot);
1559   }
1560 }
1561 
1562 }  // namespace
1563 
1564 void RegisterGraphToURLRenderer(
1565     std::function<StatusOr<string>(absl::string_view)> renderer) {
1566   tensorflow::mutex_lock lock(url_renderer_mu);
1567   if (url_renderer != nullptr) {
1568     LOG(WARNING) << "Multiple calls to RegisterGraphToURLRenderer.  Last call "
1569                     "wins, but because order of initialization in C++ is "
1570                     "nondeterministic, this may not be what you want.";
1571   }
1572   delete url_renderer;
1573   url_renderer = new std::function<StatusOr<string>(absl::string_view)>(
1574       std::move(renderer));
1575 }
1576 
1577 StatusOr<string> RenderGraph(const HloComputation& computation,
1578                              absl::string_view label,
1579                              const DebugOptions& debug_options,
1580                              RenderedGraphFormat format,
1581                              const HloExecutionProfile* hlo_execution_profile,
1582                              bool show_backend_config) {
1583   tensorflow::mutex_lock lock(url_renderer_mu);
1584   if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1585     return Unavailable("Can't render as URL; no URL renderer was registered.");
1586   }
1587 
1588   string rendered_dot =
1589       HloDotDumper(&computation, label, debug_options, show_backend_config,
1590                    hlo_execution_profile, NodeFilter())
1591           .Dump();
1592   return WrapDotInFormat(rendered_dot, format);
1593 }
1594 
1595 StatusOr<string> RenderNeighborhoodAround(
1596     const HloInstruction& node, int radius, RenderedGraphFormat format,
1597     bool show_backend_config,
1598     const absl::flat_hash_set<const HloInstruction*>& boundary) {
1599   tensorflow::mutex_lock lock(url_renderer_mu);
1600   if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1601     return FailedPrecondition(
1602         "Can't render as URL; no URL renderer was registered.");
1603   }
1604 
1605   string label =
1606       StrCat("Neighborhood of ", radius, " nodes around ", node.name());
1607   string rendered_dot =
1608       HloDotDumper(node.parent(), label,
1609                    node.GetModule()->config().debug_options(),
1610                    show_backend_config, /*profile=*/nullptr,
1611                    MakeNodeRadiusAroundFilter(&node, radius, boundary))
1612           .Dump();
1613   return WrapDotInFormat(rendered_dot, format);
1614 }
1615 
1616 StatusOr<string> RenderAllPathsFromTo(const HloInstruction& from,
1617                                       const HloInstruction& to, int64 max_nodes,
1618                                       RenderedGraphFormat format,
1619                                       bool show_backend_config) {
1620   tensorflow::mutex_lock lock(url_renderer_mu);
1621   if (format == RenderedGraphFormat::kUrl && url_renderer == nullptr) {
1622     return FailedPrecondition(
1623         "Can't render as URL; no URL renderer was registered.");
1624   }
1625 
1626   CHECK_EQ(from.parent(), to.parent()) << "Nodes must be in same computation!";
1627   auto debug_options = from.GetModule()->config().debug_options();
1628 
1629   bool hit_limit = false;
1630   NodeFilter filter = MakeNodeFromToFilter(&from, &to, max_nodes, &hit_limit);
1631   string label;
1632   if (!hit_limit) {
1633     label = StrCat("All paths from ", from.name(), " to ", to.name());
1634   } else {
1635     label = StrCat(max_nodes, " nodes on the shortest paths from ", from.name(),
1636                    " to ", to.name(),
1637                    "<br/><br/>***SHOWING ONLY A SUBSET OF ALL PATHS BETWEEN "
1638                    "NODES***<br/><br/>");
1639   }
1640   string rendered_dot =
1641       HloDotDumper(from.parent(), label, debug_options, show_backend_config,
1642                    /*profile=*/nullptr, filter)
1643           .Dump();
1644   return WrapDotInFormat(rendered_dot, format);
1645 }
1646 
1647 }  // namespace xla
1648