1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // A tool for interactively exploring graphviz dumps of HLO graphs.
17 //
// Input can be a binary HloSnapshot proto, a binary HloProto proto, a binary
// HloModuleProto proto, or a file containing textual HLO.
//
// Generated visualizations are opened in a new browser window using the
// program given by --browser (default: /usr/bin/sensible-browser).
23 
#include <stdio.h>
#include <unistd.h>

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/tools/hlo_extractor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/subprocess.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/util/command_line_flags.h"
52 #if defined(PLATFORM_GOOGLE)
53 #include "util/readline/readline.h"
54 #endif
55 
56 #if defined(PLATFORM_WINDOWS)
57 #include <io.h>
58 #define isatty _isatty
59 #endif
60 
61 namespace xla {
62 namespace tools {
63 namespace {
64 
// Prints `prompt` and reads one line of user input into `*line`.
//
// Returns false when no further line could be read. On PLATFORM_GOOGLE builds
// this delegates to util::ReadLine; elsewhere it reads from std::cin.
bool ReadLine(const char* prompt, std::string* line) {
#if defined(PLATFORM_GOOGLE)
  return util::ReadLine(prompt, line);
#else
  std::cout << prompt << std::flush;
  // The stream returned by std::getline is truthy whenever a line was
  // extracted, including a final line terminated by EOF rather than '\n'.
  // Checking std::cin.good() instead would report false in that case (eofbit
  // is set) and silently drop the last command of piped input.
  return static_cast<bool>(std::getline(std::cin, *line));
#endif
}
74 
// Command-line options for this tool; main() registers one flag per field.
//
// Exactly one of the four input-path fields is expected to be non-empty
// (CheckFlags() enforces this).
struct Options {
  // Input sources: each holds a file path.
  std::string hlo_snapshot;      // Binary HloSnapshot proto.
  std::string hlo_proto;         // Binary HloProto proto.
  std::string hlo_module_proto;  // Binary HloModule proto.
  std::string hlo_text;          // Textual HLO.

  // When non-empty, the loaded module is compiled for this platform (e.g.
  // "CPU", "CUDA") before it is explored.
  std::string platform;

  // Browser executable used to open rendered graphs.
  std::string browser;
};
85 
// Long-form description of the tool and its command-line usage.
const char* const kUsage = R"(
This tool lets you load an XLA dump and then interactively explore its graphical
representation.

Most models are too large to visualize in their entirety using graphviz, but
it's still useful to be able to look at the nodes "near" a particular node of
interest.

If you pass --platform, this tool will compile the HloModule for the given
platform.  This means that if you acquired your proto from a binary running at a
particular CL, the HLO graph it ran isn't necessarily the same as the one shown
here, unless this program was built at the same CL (and our compiler is
deterministic :).

Be patient when starting this program if you give it a large input; it has to
compile the whole thing.

Usage:

  interactive_graphviz -- \
    --{hlo_snapshot,hlo_proto,hlo_module_proto,hlo_text}=path/to/input \
    [--platform={CUDA,CPU,...}] \
    [--browser=path/to/browser]
)";
109 
// Size of the neighborhood (in nodes) rendered around an instruction when the
// user does not give an explicit <width>.
constexpr int64_t kDefaultWidth{2};

// Default cap on how many nodes "allpaths" prints; the all-paths set is
// truncated when it contains more nodes than this.
constexpr int64_t kDefaultMaxNumNodesInAllPaths{100};
118 
119 using absl::EqualsIgnoreCase;
120 
121 HloRenderOptions hlo_render_options;
122 
FindInstruction(const HloModule & module,std::string node_name)123 HloInstruction* FindInstruction(const HloModule& module,
124                                 std::string node_name) {
125   if (absl::StartsWith(node_name, "%")) {
126     node_name.erase(node_name.begin());
127   }
128   for (const auto& computation : module.computations()) {
129     auto instrs = computation->instructions();
130     auto it = absl::c_find_if(instrs, [&](const HloInstruction* instr) {
131       // Try with and without "%" at the beginning of the node name.
132       return EqualsIgnoreCase(instr->name(), node_name) ||
133              EqualsIgnoreCase(instr->name(), absl::StrCat("%", node_name));
134     });
135     if (it != instrs.end()) {
136       return *it;
137     }
138   }
139   return nullptr;
140 }
141 
FindComputation(const HloModule & module,const std::string & comp_name)142 HloComputation* FindComputation(const HloModule& module,
143                                 const std::string& comp_name) {
144   for (auto* computation : module.computations()) {
145     if (EqualsIgnoreCase(computation->name(), comp_name)) {
146       return computation;
147     }
148   }
149   return nullptr;
150 }
151 
152 // Print a help message describing the various available commands.
DoHelpCommand()153 void DoHelpCommand() {
154   std::cout << R"(Commands:
155   <instruction> [<width>] [/ <boundary_instruction>+]
156     Renders a neighborhood of <width> nodes around <instruction>, without going
157     beyond the optional boundary instructions.  If <width> is not provided,
158     the default value is )"
159             << kDefaultWidth << R"(.
160   allpaths <instruction> <instruction> [<n>]
161     Renders a subset of all paths from one instruction to the other.  Either
162     order of nodes is accepted.  Shows the <n> nodes in the all-paths set on the
163     shortest paths; default is )"
164             << kDefaultMaxNumNodesInAllPaths << R"(.
165   <computation>
166     Renders all nodes in <computation>.
167   backend_config [on|off]
168     Controls whether backend operation configuration information is printed.
169   show_fusion_subcomputations [on|off]
170     Controls whether fusion subcomputations are shown.
171   list [name|op_name|op_type] <pattern>
172     Lists all instructions whose name, metadata op_name, or metadata op_type
173     contains <pattern> as a substring.
174   list computations
175     Lists all computations in the module.
176   info <instruction>
177   info <computation>
178     Prints information about <instruction> or <computation>.
179   extract <instruction> <height>
180     Creates a new HLO module with <instruction> as entry computation root. If
181     <height> is specified, the new computation contains nodes up to <height>
182     nodes above the root.
183   help
184     Prints this usage information.
185   quit
186     Exit the application.)"
187             << std::endl;
188 }
189 
190 // Turn metadata-printing on or off.
DoBackendConfigCommand(const std::vector<std::string> & tokens)191 void DoBackendConfigCommand(const std::vector<std::string>& tokens) {
192   if (tokens.size() == 2 && tokens[1] == "on") {
193     hlo_render_options.show_backend_config = true;
194   } else if (tokens.size() == 2 && tokens[1] == "off") {
195     hlo_render_options.show_backend_config = false;
196   } else if (tokens.size() != 1) {
197     std::cerr << "(Illegal backend_config value.  Use either 'on' or 'off'.)"
198               << std::endl;
199   }
200   std::cout << "Backend configuration display "
201             << (hlo_render_options.show_backend_config ? "ON" : "OFF")
202             << std::endl;
203 }
204 
205 // Turn fusion computation display on or off.
DoShowFusionSubcomputationsCommand(const std::vector<std::string> & tokens)206 void DoShowFusionSubcomputationsCommand(
207     const std::vector<std::string>& tokens) {
208   if (tokens.size() == 2 && tokens[1] == "on") {
209     hlo_render_options.show_fusion_subcomputations = true;
210   } else if (tokens.size() == 2 && tokens[1] == "off") {
211     hlo_render_options.show_fusion_subcomputations = false;
212   } else if (tokens.size() != 1) {
213     std::cerr << "(Illegal show_fusion_subcomputations value.  Use either "
214                  "'on' or 'off'.)"
215               << std::endl;
216   }
217   std::cout << "Fusion subcomputations display "
218             << (hlo_render_options.show_fusion_subcomputations ? "ON" : "OFF")
219             << std::endl;
220 }
221 
222 // List all computations in the module.
DoListComputationsCommand(const HloModule & module,const std::vector<std::string> & tokens)223 void DoListComputationsCommand(const HloModule& module,
224                                const std::vector<std::string>& tokens) {
225   if (tokens.size() > 2) {
226     std::cout << R"(Illegal syntax; "list computations" takes no arguments.)";
227     return;
228   }
229   if (module.entry_computation() != nullptr) {
230     std::cout << "Entry computation:" << std::endl;
231     std::cout << "  " << module.entry_computation()->name() << std::endl
232               << std::endl;
233   }
234   std::cout << "Subcomputations:" << std::endl;
235   std::vector<std::string> names;
236   for (const auto& computation : module.computations()) {
237     if (computation == module.entry_computation()) {
238       continue;
239     }
240     std::cout << "  " << computation->name() << std::endl;
241   }
242 }
243 
244 // List all instructions matching a pattern.
DoListCommand(const HloModule & module,const std::vector<std::string> & tokens)245 void DoListCommand(const HloModule& module,
246                    const std::vector<std::string>& tokens) {
247   std::string pattern = "";
248   std::string type = "name";
249   if (tokens.size() == 2) {
250     pattern = tokens[1];
251   } else if (tokens.size() == 3) {
252     type = tokens[1];
253     pattern = tokens[2];
254   } else {
255     std::cout << "Illegal list query syntax. Use "
256               << R"("list [name|op_name|op_type] pattern".)" << std::endl;
257     return;
258   }
259 
260   std::cout << "Query results:" << std::endl;
261   for (const auto& computation : module.computations()) {
262     for (const auto& instr : computation->instructions()) {
263       if ((type == "name" &&
264            instr->name().find(pattern) != std::string::npos) ||
265           (type == "op_name" &&
266            instr->metadata().op_name().find(pattern) != std::string::npos) ||
267           (type == "op_type" &&
268            instr->metadata().op_type().find(pattern) != std::string::npos)) {
269         std::cout << "  " << instr->name();
270         std::cout << ", op_name '" << instr->metadata().op_name() << "'";
271         std::cout << ", op_type '" << instr->metadata().op_type() << "'";
272         std::cout << std::endl;
273       }
274     }
275   }
276 }
277 
278 // Print info about an instruction or computation.
DoInfoCommand(const HloModule & module,const std::vector<std::string> & tokens)279 void DoInfoCommand(const HloModule& module,
280                    const std::vector<std::string>& tokens) {
281   if (tokens.size() != 2) {
282     std::cerr << "Illegal info query syntax. Use "
283               << R"("info name".)";
284     return;
285   }
286   std::string node_name = tokens[1];
287 
288   const HloInstruction* instr = FindInstruction(module, node_name);
289   const HloComputation* comp = FindComputation(module, node_name);
290   if (!instr && !comp) {
291     std::cerr << "Couldn't find HloInstruction or HloComputation named "
292               << node_name << std::endl;
293     return;
294   }
295 
296   if (comp != nullptr) {
297     std::cout << "HloComputation " << comp->name() << std::endl;
298     if (comp->IsFusionComputation()) {
299       std::cout << "  Fusion instruction: " << comp->FusionInstruction()->name()
300                 << std::endl;
301     }
302     std::cout << "  Parameters:" << std::endl;
303     for (const auto& param : comp->parameter_instructions()) {
304       std::cout << "    " << param->name() << " ("
305                 << ShapeUtil::HumanStringWithLayout(param->shape()) << ")"
306                 << std::endl;
307     }
308     HloInstruction* root = comp->root_instruction();
309     std::cout << "  Root instruction: " << root->name() << " ("
310               << ShapeUtil::HumanStringWithLayout(root->shape()) << ")"
311               << std::endl;
312 
313     auto embedded_computations = comp->MakeEmbeddedComputationsList();
314     std::cout << "  " << embedded_computations.size() << " embedded computation"
315               << (embedded_computations.size() != 1 ? "s" : "")
316               << (!embedded_computations.empty() ? ":" : ".") << std::endl;
317     for (const HloComputation* c : embedded_computations) {
318       std::cout << "    " << c->name() << std::endl;
319     }
320 
321     // Find which computations reference comp as an embedded computation.
322     std::vector<const HloComputation*> users;
323     for (const HloComputation* c : module.computations()) {
324       if (absl::c_linear_search(c->MakeEmbeddedComputationsList(), comp)) {
325         users.push_back(c);
326       }
327     }
328     std::cout << "  Used by " << users.size() << " computation"
329               << (users.size() != 1 ? "s" : "") << (!users.empty() ? ":" : ".");
330     for (const HloComputation* c : users) {
331       std::cout << "    " << c->name() << std::endl;
332     }
333   } else {
334     std::cout << "HloInstruction " << instr->name() << std::endl;
335     std::cout << "  Parent computation: " << instr->parent()->name()
336               << std::endl;
337     std::cout << "  Opcode: " << HloOpcodeString(instr->opcode()) << std::endl;
338     std::cout << "  Shape: " << ShapeUtil::HumanStringWithLayout(instr->shape())
339               << std::endl;
340     std::cout << "  Metadata:" << std::endl;
341     if (!instr->metadata().op_name().empty()) {
342       std::cout << "    Name: " << instr->metadata().op_name() << std::endl;
343     }
344     if (!instr->metadata().op_type().empty()) {
345       std::cout << "    Type: " << instr->metadata().op_type() << std::endl;
346     }
347     if (!instr->raw_backend_config_string().empty()) {
348       std::cout << "  Backend configuration: "
349                 << instr->raw_backend_config_string() << std::endl;
350     }
351     if (instr->opcode() == HloOpcode::kFusion) {
352       std::cout << "  Fusion kind: " << xla::ToString(instr->fusion_kind())
353                 << std::endl;
354       std::cout << "  Fusion computation: "
355                 << instr->fused_instructions_computation()->name() << std::endl;
356       std::cout << "  Fused computation root: "
357                 << instr->fused_expression_root()->name() << std::endl;
358     }
359     std::cout << "  Operands:" << std::endl;
360     for (HloInstruction* operand : instr->operands()) {
361       std::cout << "    " << operand->name() << " ("
362                 << ShapeUtil::HumanStringWithLayout(operand->shape()) << ")"
363                 << std::endl;
364     }
365     std::cout << "  Users:" << std::endl;
366     for (HloInstruction* user : instr->users()) {
367       std::cout << "    " << user->name() << std::endl;
368     }
369     if (instr->parent()->root_instruction() == instr) {
370       std::cout << "  Root instruction of " << instr->parent()->name()
371                 << std::endl;
372     }
373   }
374 }
375 
DoExtractCommand(const HloModule & module,absl::Span<const std::string> tokens)376 void DoExtractCommand(const HloModule& module,
377                       absl::Span<const std::string> tokens) {
378   if (tokens.size() > 3) {
379     std::cerr << R"(Illegal input.  Enter e.g. "extract %fusion.1 2")"
380               << std::endl;
381     return;
382   }
383 
384   // Find the node with the given name.
385   std::string node_name = tokens[1];
386   HloInstruction* instr = FindInstruction(module, node_name);
387   if (!instr) {
388     std::cerr << "Couldn't find HloInstruction named " << node_name << "."
389               << std::endl;
390     return;
391   }
392 
393   int64_t height = -1;
394   if (tokens.size() == 3) {
395     if (!absl::SimpleAtoi(tokens[2], &height)) {
396       std::cerr << "Can't parse '" << tokens[2] << "' as an integer."
397                 << std::endl;
398       return;
399     }
400   }
401 
402   auto extracted_module = ExtractModule(instr, height);
403   std::cout << extracted_module->ToString(
404                    HloPrintOptions::ShortParsable().set_print_backend_config(
405                        hlo_render_options.show_backend_config))
406             << std::endl;
407 }
408 
409 // Checks if there is a use-def path from `from` to `to`.
ExistsPathFromTo(const HloInstruction * from,const HloInstruction * to)410 bool ExistsPathFromTo(const HloInstruction* from, const HloInstruction* to) {
411   absl::flat_hash_set<const HloInstruction*> visited;
412   std::vector<const HloInstruction*> to_visit = {from};
413   while (!to_visit.empty()) {
414     auto* n = to_visit.back();
415     if (n == to) {
416       return true;
417     }
418     to_visit.pop_back();
419     visited.insert(n);
420     for (auto* user : n->users()) {
421       if (!visited.count(user)) {
422         to_visit.push_back(user);
423       }
424     }
425   }
426   return false;
427 }
428 
OpenUrl(const Options & opts,absl::string_view url)429 void OpenUrl(const Options& opts, absl::string_view url) {
430   std::cout << url << std::endl;
431 
432   // If it is a url, try to open it up in the user's browser too.
433   if (absl::StartsWithIgnoreCase(url, "http://") ||
434       absl::StartsWithIgnoreCase(url, "https://") ||
435       absl::StartsWithIgnoreCase(url, "file://")) {
436     const char* browser_bin = opts.browser.empty() ? "/usr/bin/sensible-browser"
437                                                    : opts.browser.c_str();
438     tensorflow::SubProcess p;
439     p.SetProgram(browser_bin, {browser_bin, std::string(url)});
440     p.Start();
441   } else {
442     std::cerr << "\nExpected a URL, but got strange graph result (dumped "
443                  "above).  If this isn't what you expected, maybe file a bug?"
444               << std::endl;
445   }
446 }
447 
448 // Renders a graph by calling `renderer`, and then tries to open it.
449 //
450 // `renderer` is a callback so we can try various formats.  In particular, the
451 // URL format doesn't work out of the box; it requires you to register a plugin.
RenderAndDisplayGraph(const Options & opts,const std::function<StatusOr<std::string> (RenderedGraphFormat)> & renderer)452 void RenderAndDisplayGraph(
453     const Options& opts,
454     const std::function<StatusOr<std::string>(RenderedGraphFormat)>& renderer) {
455   StatusOr<std::string> url_result = renderer(RenderedGraphFormat::kUrl);
456   if (url_result.ok()) {
457     std::string url = url_result.ValueOrDie();
458     OpenUrl(opts, url);
459     return;
460   }
461 
462   // Ignore UNAVAILABLE errors; these are expected when there's no URL renderer
463   // plugin registered.
464   if (url_result.status().code() != tensorflow::error::UNAVAILABLE) {
465     std::cerr << "Unable to render graph as URL: " << url_result.status()
466               << std::endl;
467     std::cerr << "Trying as HTML..." << std::endl;
468   }
469 
470   auto* env = tensorflow::Env::Default();
471   StatusOr<std::string> html_result = renderer(RenderedGraphFormat::kHtml);
472   if (!html_result.ok()) {
473     std::cerr << "Failed to render graph as HTML: " << html_result.status()
474               << std::endl;
475     return;
476   }
477 
478   std::vector<std::string> temp_dirs;
479   env->GetLocalTempDirectories(&temp_dirs);
480   if (temp_dirs.empty()) {
481     std::cerr << "Can't render graph as HTML because we can't find a suitable "
482                  "temp directory.  Try setting $TMPDIR?"
483               << std::endl;
484     return;
485   }
486 
487   // Try to create a unique file inside of temp_dirs.front().  Notably, this
488   // file's name must end with ".html", otherwise web browsers will treat it as
489   // plain text, so we can't use Env::CreateUniqueFileName().
490   std::string temp_file_path = tensorflow::io::JoinPath(
491       temp_dirs.front(),
492       absl::StrFormat("interactive_graphviz.%d.html", env->NowMicros()));
493   auto status = tensorflow::WriteStringToFile(
494       env, temp_file_path, std::move(html_result).ValueOrDie());
495   if (status.ok()) {
496     OpenUrl(opts, absl::StrCat("file://", temp_file_path));
497     return;
498   }
499 
500   std::cerr << "Failed to write rendered HTML graph to " << temp_file_path
501             << ": " << status;
502 
503   // We don't bother trying kDot, because kHTML should always work (or if it
504   // doesn't, we don't have any reason to believe kDot will work better).
505 }
506 
DoAllPathsCommand(const Options & opts,const HloModule & module,const std::vector<std::string> & tokens)507 void DoAllPathsCommand(const Options& opts, const HloModule& module,
508                        const std::vector<std::string>& tokens) {
509   if (tokens.size() > 4) {
510     std::cerr << R"(Illegal input.  Enter e.g. "allpaths %add.4 %subtract.2" or
511 "allpaths add.4 subtract.2 42.)"
512               << std::endl;
513     return;
514   }
515 
516   int64_t max_nodes = kDefaultMaxNumNodesInAllPaths;
517   if (tokens.size() == 4 && !absl::SimpleAtoi(tokens[3], &max_nodes)) {
518     std::cerr << "Can't parse '" << tokens[3] << "' as an integer."
519               << std::endl;
520     return;
521   }
522 
523   const HloInstruction* n1 = FindInstruction(module, tokens[1]);
524   if (!n1) {
525     std::cerr << "Couldn't find HloInstruction named " << tokens[1];
526     return;
527   }
528   const HloInstruction* n2 = FindInstruction(module, tokens[2]);
529   if (!n2) {
530     std::cerr << "Couldn't find HloInstruction named " << tokens[2];
531     return;
532   }
533 
534   // Is there a path from n1 to n2, or vice versa?
535   const HloInstruction* from;
536   const HloInstruction* to;
537   if (ExistsPathFromTo(n1, n2)) {
538     from = n1;
539     to = n2;
540   } else if (ExistsPathFromTo(n2, n1)) {
541     from = n2;
542     to = n1;
543   } else {
544     std::cerr << "No path from/to " << tokens[1] << " to/from " << tokens[2];
545     return;
546   }
547   RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
548     return RenderAllPathsFromTo(*from, *to, max_nodes, format,
549                                 hlo_render_options);
550   });
551 }
552 
553 // Plot a given instruction neighborhood or computation with graphviz.
554 void DoPlotCommand(const Options& opts, const HloModule& module,
555                    const std::vector<std::string>& tokens) {
556   std::string node_name = tokens[0];
557 
558   // Find the node with the given name.
559   const HloInstruction* instr = FindInstruction(module, node_name);
560   const HloComputation* comp = FindComputation(module, node_name);
561   if (!instr && !comp) {
562     std::cerr << "Couldn't find HloInstruction or HloComputation named "
563               << node_name << "." << std::endl;
564     return;
565   }
566 
567   uint64_t graph_width = kDefaultWidth;
568   absl::flat_hash_set<const HloInstruction*> boundary;
569   if (tokens.size() >= 2) {
570     if (comp) {
571       std::cerr << "Can only use graph-size parameter with instructions, but "
572                 << node_name << " is a computation." << std::endl;
573       return;
574     }
575 
576     int bound_index = 1;
577     // Get the <width> if present.
578     if (absl::SimpleAtoi(tokens[bound_index], &graph_width)) {
579       bound_index++;
580     } else {
581       // <width> not found, need to reset graph_width.
582       graph_width = kDefaultWidth;
583     }
584     // Get the '/'.
585     if (bound_index < tokens.size()) {
586       // This token must be a '/'.
587       if (tokens[bound_index] != "/") {
588         std::cerr << "Expect a /, but get a '" << tokens[bound_index] << "'."
589                   << std::endl;
590         return;
591       }
592       bound_index++;
593     }
594     // Get the boundary nodes.
595     while (bound_index < tokens.size()) {
596       std::string bnode_name = tokens[bound_index];
597       const HloInstruction* binstr = FindInstruction(module, bnode_name);
598       if (!binstr) {
599         std::cerr << "Couldn't find HloInstruction named " << bnode_name << "."
600                   << std::endl;
601         return;
602       }
603       boundary.insert(binstr);
604       bound_index++;
605     }
606   }
607 
608   // Generate the graph and print the resulting string, which should be a
609   // graphviz url.
610   if (comp) {
611     RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
612       return RenderGraph(*comp, /*label=*/"",
613                          comp->parent()->config().debug_options(), format,
614                          /*hlo_execution_profile=*/nullptr, hlo_render_options);
615     });
616   } else {
617     RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
618       return RenderNeighborhoodAround(*instr, graph_width, format,
619                                       hlo_render_options,
620                                       /*boundary=*/boundary);
621     });
622   }
623 }
624 
625 // Run the main event loop, reading user commands and processing them.
626 void InteractiveDumpGraphs(const Options& opts, const HloModule& module) {
627   // This is an interactive tool, but some may use `extract` in non-tty
628   // environment anyway. Give them a clean hlo dump.
629   if (isatty(fileno(stdin))) {
630     std::cout << "\n\nLoaded module " << module.name() << "." << std::endl;
631     DoHelpCommand();
632   }
633   for (std::string line; ReadLine("\ncommand: ", &line);) {
634     if (line.empty()) {
635       std::cout << R"(Enter e.g. "fusion.1 3" or "add.8".)" << std::endl
636                 << R"(Enter "help" for help; ^D, "quit", or "exit" to exit.)"
637                 << std::endl;
638       continue;
639     }
640     std::vector<std::string> tokens =
641         absl::StrSplit(line, ' ', absl::SkipEmpty());
642     if (tokens[0] == "quit" || tokens[0] == "exit") {
643       break;
644     } else if (tokens[0] == "help") {
645       DoHelpCommand();
646     } else if (tokens[0] == "backend_config") {
647       DoBackendConfigCommand(tokens);
648     } else if (tokens[0] == "show_fusion_subcomputations") {
649       DoShowFusionSubcomputationsCommand(tokens);
650     } else if (tokens[0] == "list") {
651       if (tokens.size() > 1 && tokens[1] == "computations") {
652         DoListComputationsCommand(module, tokens);
653       } else {
654         DoListCommand(module, tokens);
655       }
656     } else if (tokens[0] == "info") {
657       DoInfoCommand(module, tokens);
658     } else if (tokens[0] == "extract") {
659       DoExtractCommand(module, tokens);
660     } else if (tokens[0] == "allpaths") {
661       DoAllPathsCommand(opts, module, tokens);
662     } else {
663       DoPlotCommand(opts, module, tokens);
664     }
665   }
666 }
667 
668 void CheckFlags(const Options& opts) {
669   int nonempty_flags_amount = 0;
670   if (!opts.hlo_proto.empty()) {
671     ++nonempty_flags_amount;
672   }
673   if (!opts.hlo_snapshot.empty()) {
674     ++nonempty_flags_amount;
675   }
676   if (!opts.hlo_text.empty()) {
677     ++nonempty_flags_amount;
678   }
679   if (!opts.hlo_module_proto.empty()) {
680     ++nonempty_flags_amount;
681   }
682   if (nonempty_flags_amount == 1) {
683     return;
684   }
685   LOG(FATAL) << "Can only specify one and only one of '--hlo_proto', "
686                 "'--hlo_snapshot', '--hlo_text', '--hlo_module_proto' flags.";
687 }
688 
689 void RealMain(const Options& opts) {
690   if (!isatty(fileno(stdin))) {
691     LOG(ERROR) << "\n\n*****************************************\n"
692                << "This is an interactive tool, but stdin is not a tty.\n"
693                << "*****************************************\n\n";
694   }
695 
696   CheckFlags(opts);
697 
698   std::unique_ptr<HloModule> module;
699   if (!opts.hlo_snapshot.empty()) {
700     HloSnapshot snapshot;
701     TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
702                                             opts.hlo_snapshot, &snapshot))
703         << "Can't open, read, or parse HloSnapshot proto at "
704         << opts.hlo_snapshot;
705     auto config =
706         HloModule::CreateModuleConfigFromProto(snapshot.hlo().hlo_module(),
707                                                xla::GetDebugOptionsFromFlags())
708             .ValueOrDie();
709     module = HloModule::CreateFromProto(snapshot.hlo().hlo_module(), config)
710                  .ValueOrDie();
711   } else if (!opts.hlo_proto.empty()) {
712     module = HloRunner::ReadModuleFromBinaryProtoFile(
713                  opts.hlo_proto, xla::GetDebugOptionsFromFlags())
714                  .ValueOrDie();
715   } else if (!opts.hlo_text.empty()) {
716     module = HloRunner::ReadModuleFromHloTextFile(
717                  opts.hlo_text, xla::GetDebugOptionsFromFlags())
718                  .ValueOrDie();
719   } else if (!opts.hlo_module_proto.empty()) {
720     module = HloRunner::ReadModuleFromModuleBinaryProtofile(
721                  opts.hlo_module_proto, xla::GetDebugOptionsFromFlags())
722                  .ValueOrDie();
723   }
724 
725   // If a platform was specified, compile the module for that platform.
726   if (!opts.platform.empty()) {
727     se::Platform* platform =
728         PlatformUtil::GetPlatform(opts.platform).ValueOrDie();
729     LOG(INFO) << "Compiling module for " << platform->Name();
730 
731     se::StreamExecutor* executor =
732         platform->ExecutorForDevice(/*ordinal=*/0).ValueOrDie();
733     auto compiler = Compiler::GetForPlatform(platform).ValueOrDie();
734     module = compiler
735                  ->RunHloPasses(std::move(module), executor,
736                                 /*device_allocator=*/nullptr)
737                  .ValueOrDie();
738     auto executable = compiler
739                           ->RunBackend(std::move(module), executor,
740                                        /*device_allocator=*/nullptr)
741                           .ValueOrDie();
742     InteractiveDumpGraphs(opts, executable->module());
743   } else {
744     InteractiveDumpGraphs(opts, *module);
745   }
746 }
747 
748 }  // namespace
749 }  // namespace tools
750 }  // namespace xla
751 
752 int main(int argc, char** argv) {
753   xla::tools::Options opts;
754   opts.browser = "/usr/bin/sensible-browser";
755   bool need_help = false;
756   const std::vector<tensorflow::Flag> flag_list = {
757       tensorflow::Flag("hlo_snapshot", &opts.hlo_snapshot,
758                        "HloSnapshot proto to interactively dump to graphviz"),
759       tensorflow::Flag("hlo_proto", &opts.hlo_proto,
760                        "XLA hlo proto to interactively dump to graphviz"),
761       tensorflow::Flag("hlo_module_proto", &opts.hlo_module_proto,
762                        "XLA hlomodule proto to interactively dump to graphviz"),
763       tensorflow::Flag("hlo_text", &opts.hlo_text,
764                        "XLA hlo proto to interactively dump to graphviz"),
765       tensorflow::Flag("platform", &opts.platform,
766                        "Platform to compile for: CPU, CUDA, etc"),
767       tensorflow::Flag("browser", &opts.browser,
768                        "Path to web browser used to display produced graphs."),
769       tensorflow::Flag("help", &need_help, "Prints this help message"),
770   };
771   std::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
772   bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
773   tensorflow::port::InitMain(argv[0], &argc, &argv);
774   if (argc != 1 || !parse_ok || need_help) {
775     LOG(QFATAL) << usage;
776   }
777   xla::tools::RealMain(opts);
778   return 0;
779 }
780