• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // A tool for interactively exploring graphviz dumps of HLO graphs.
17 //
18 // Input can be a binary HloSnapshot proto, a binary HloProto proto, or a
19 // textual HLO string.
20 //
21 // Generated visualization is opened in a new default browser window using
22 // /usr/bin/sensible-browser.
23 
#include <stdio.h>
#include <unistd.h>

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/tools/hlo_extractor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/subprocess.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/util/command_line_flags.h"
#if defined(PLATFORM_GOOGLE)
#include "util/readline/readline.h"
#endif
50 
51 #if defined(PLATFORM_WINDOWS)
52 #include <io.h>
53 #define isatty _isatty
54 #endif
55 
56 namespace xla {
57 namespace tools {
58 namespace {
59 
// Displays `prompt` and reads one line of user input into *line.  Returns
// false when no more input is available (e.g. EOF / ^D), which ends the
// command loop in InteractiveDumpGraphs().
bool ReadLine(const char* prompt, string* line) {
#if defined(PLATFORM_GOOGLE)
  // Google-internal readline provides history/editing support.
  return util::ReadLine(prompt, line);
#else
  std::cout << prompt;
  std::getline(std::cin, *line);
  // NOTE: good() is false once eofbit is set, even if this last getline
  // still extracted characters before hitting EOF.
  return std::cin.good();
#endif
}
69 
// Command-line opts to this tool.  See main() for descriptions of these
// fields.  Exactly one of the four input-file fields must be set (enforced by
// CheckFlags()).
struct Options {
  string hlo_snapshot;      // Path to a binary HloSnapshot proto.
  string hlo_proto;         // Path to a binary HloProto proto.
  string hlo_module_proto;  // Path to a binary HloModuleProto proto.
  string hlo_text;          // Path to a textual HLO file.
  string platform;          // If set, compile the module for this platform.
  string browser;           // Browser binary used to open rendered graphs.
};
80 
// Long-form usage/overview text for this tool.
//
// NOTE(review): kUsage does not appear to be referenced anywhere in this
// file (main() builds its usage string via tensorflow::Flags::Usage
// instead) — verify against the full build before removing.
const char* const kUsage = R"(
This tool lets you load an XLA dump and then interactively explore its graphical
representation.

Most models are too large to visualize in their entirety using graphviz, but
it's still useful to be able to look at the nodes "near" a particular node of
interest.

If you pass --platform, this tool will compile the HloModule for the given
platform.  This means that if you acquired your proto from a binary running at a
particular CL, the HLO graph it ran isn't necessarily the same as the one shown
here, unless this program was built at the same CL (and our compiler is
deterministic :).

Be patient when starting this program if you give it a large input; it has to
compile the whole thing.

Usage:

  interactive_graphviz -- \
    --{hlo_snapshot,hlo_proto,hlo_text}=path/to/binary_proto
    --platform={CUDA,CPU,...}
)";
104 
// Unless an explicit width is specified, we will render a neighborhood of
// kDefaultWidth nodes around the requested instruction.
constexpr int64 kDefaultWidth = 2;

// When printing all paths between two nodes, we print out only this many nodes
// by default, truncating the graph if there are more nodes than this in the
// all-paths set.
constexpr int64 kDefaultMaxNumNodesInAllPaths = 100;

using absl::EqualsIgnoreCase;

// Mutable rendering options shared by all commands in this file; toggled at
// runtime by the "backend_config" and "show_fusion_subcomputations" commands.
HloRenderOptions hlo_render_options;
117 
FindInstruction(const HloModule & module,string node_name)118 HloInstruction* FindInstruction(const HloModule& module, string node_name) {
119   if (absl::StartsWith(node_name, "%")) {
120     node_name.erase(node_name.begin());
121   }
122   for (const auto& computation : module.computations()) {
123     auto instrs = computation->instructions();
124     auto it = absl::c_find_if(instrs, [&](const HloInstruction* instr) {
125       // Try with and without "%" at the beginning of the node name.
126       return EqualsIgnoreCase(instr->name(), node_name) ||
127              EqualsIgnoreCase(instr->name(), absl::StrCat("%", node_name));
128     });
129     if (it != instrs.end()) {
130       return *it;
131     }
132   }
133   return nullptr;
134 }
135 
FindComputation(const HloModule & module,const string & comp_name)136 HloComputation* FindComputation(const HloModule& module,
137                                 const string& comp_name) {
138   for (auto* computation : module.computations()) {
139     if (EqualsIgnoreCase(computation->name(), comp_name)) {
140       return computation;
141     }
142   }
143   return nullptr;
144 }
145 
// Print a help message describing the various available commands.
//
// Note the text interpolates kDefaultWidth and kDefaultMaxNumNodesInAllPaths
// so the defaults shown always match the constants above.  The command loop
// also accepts "exit" as a synonym for "quit" (see InteractiveDumpGraphs).
void DoHelpCommand() {
  std::cout << R"(Commands:
  <instruction> [<width>] [/ <boundary_instruction>+]
    Renders a neighborhood of <width> nodes around <instruction>, without going
    beyond the optional boundary instructions.  If <width> is not provided,
    the default value is )"
            << kDefaultWidth << R"(.
  allpaths <instruction> <instruction> [<n>]
    Renders a subset of all paths from one instruction to the other.  Either
    order of nodes is accepted.  Shows the <n> nodes in the all-paths set on the
    shortest paths; default is )"
            << kDefaultMaxNumNodesInAllPaths << R"(.
  <computation>
    Renders all nodes in <computation>.
  backend_config [on|off]
    Controls whether backend operation configuration information is printed.
  show_fusion_subcomputations [on|off]
    Controls whether fusion subcomputations are shown.
  list [name|op_name|op_type] <pattern>
    Lists all instructions whose name, metadata op_name, or metadata op_type
    contains <pattern> as a substring.
  list computations
    Lists all computations in the module.
  info <instruction>
  info <computation>
    Prints information about <instruction> or <computation>.
  extract <instruction> <height>
    Creates a new HLO module with <instruction> as entry computation root. If
    <height> is specified, the new computation contains nodes up to <height>
    nodes above the root.
  help
    Prints this usage information.
  quit
    Exit the application.)"
            << std::endl;
}
183 
184 // Turn metadata-printing on or off.
DoBackendConfigCommand(const std::vector<string> & tokens)185 void DoBackendConfigCommand(const std::vector<string>& tokens) {
186   if (tokens.size() == 2 && tokens[1] == "on") {
187     hlo_render_options.show_backend_config = true;
188   } else if (tokens.size() == 2 && tokens[1] == "off") {
189     hlo_render_options.show_backend_config = false;
190   } else if (tokens.size() != 1) {
191     std::cerr << "(Illegal backend_config value.  Use either 'on' or 'off'.)"
192               << std::endl;
193   }
194   std::cout << "Backend configuration display "
195             << (hlo_render_options.show_backend_config ? "ON" : "OFF")
196             << std::endl;
197 }
198 
199 // Turn fusion computation display on or off.
DoShowFusionSubcomputationsCommand(const std::vector<string> & tokens)200 void DoShowFusionSubcomputationsCommand(const std::vector<string>& tokens) {
201   if (tokens.size() == 2 && tokens[1] == "on") {
202     hlo_render_options.show_fusion_subcomputations = true;
203   } else if (tokens.size() == 2 && tokens[1] == "off") {
204     hlo_render_options.show_fusion_subcomputations = false;
205   } else if (tokens.size() != 1) {
206     std::cerr << "(Illegal show_fusion_subcomputations value.  Use either "
207                  "'on' or 'off'.)"
208               << std::endl;
209   }
210   std::cout << "Fusion subcomputations display "
211             << (hlo_render_options.show_fusion_subcomputations ? "ON" : "OFF")
212             << std::endl;
213 }
214 
215 // List all computations in the module.
DoListComputationsCommand(const HloModule & module,const std::vector<string> & tokens)216 void DoListComputationsCommand(const HloModule& module,
217                                const std::vector<string>& tokens) {
218   if (tokens.size() > 2) {
219     std::cout << R"(Illegal syntax; "list computations" takes no arguments.)";
220     return;
221   }
222   if (module.entry_computation() != nullptr) {
223     std::cout << "Entry computation:" << std::endl;
224     std::cout << "  " << module.entry_computation()->name() << std::endl
225               << std::endl;
226   }
227   std::cout << "Subcomputations:" << std::endl;
228   std::vector<string> names;
229   for (const auto& computation : module.computations()) {
230     if (computation == module.entry_computation()) {
231       continue;
232     }
233     std::cout << "  " << computation->name() << std::endl;
234   }
235 }
236 
237 // List all instructions matching a pattern.
DoListCommand(const HloModule & module,const std::vector<string> & tokens)238 void DoListCommand(const HloModule& module, const std::vector<string>& tokens) {
239   string pattern = "";
240   string type = "name";
241   if (tokens.size() == 2) {
242     pattern = tokens[1];
243   } else if (tokens.size() == 3) {
244     type = tokens[1];
245     pattern = tokens[2];
246   } else {
247     std::cout << "Illegal list query syntax. Use "
248               << R"("list [name|op_name|op_type] pattern".)" << std::endl;
249     return;
250   }
251 
252   std::cout << "Query results:" << std::endl;
253   for (const auto& computation : module.computations()) {
254     for (const auto& instr : computation->instructions()) {
255       if ((type == "name" && instr->name().find(pattern) != string::npos) ||
256           (type == "op_name" &&
257            instr->metadata().op_name().find(pattern) != string::npos) ||
258           (type == "op_type" &&
259            instr->metadata().op_type().find(pattern) != string::npos)) {
260         std::cout << "  " << instr->name();
261         std::cout << ", op_name '" << instr->metadata().op_name() << "'";
262         std::cout << ", op_type '" << instr->metadata().op_type() << "'";
263         std::cout << std::endl;
264       }
265     }
266   }
267 }
268 
269 // Print info about an instruction or computation.
DoInfoCommand(const HloModule & module,const std::vector<string> & tokens)270 void DoInfoCommand(const HloModule& module, const std::vector<string>& tokens) {
271   if (tokens.size() != 2) {
272     std::cerr << "Illegal info query syntax. Use "
273               << R"("info name".)";
274     return;
275   }
276   string node_name = tokens[1];
277 
278   const HloInstruction* instr = FindInstruction(module, node_name);
279   const HloComputation* comp = FindComputation(module, node_name);
280   if (!instr && !comp) {
281     std::cerr << "Couldn't find HloInstruction or HloComputation named "
282               << node_name << std::endl;
283     return;
284   }
285 
286   if (comp != nullptr) {
287     std::cout << "HloComputation " << comp->name() << std::endl;
288     if (comp->IsFusionComputation()) {
289       std::cout << "  Fusion instruction: " << comp->FusionInstruction()->name()
290                 << std::endl;
291     }
292     std::cout << "  Parameters:" << std::endl;
293     for (const auto& param : comp->parameter_instructions()) {
294       std::cout << "    " << param->name() << " ("
295                 << ShapeUtil::HumanStringWithLayout(param->shape()) << ")"
296                 << std::endl;
297     }
298     HloInstruction* root = comp->root_instruction();
299     std::cout << "  Root instruction: " << root->name() << " ("
300               << ShapeUtil::HumanStringWithLayout(root->shape()) << ")"
301               << std::endl;
302 
303     auto embedded_computations = comp->MakeEmbeddedComputationsList();
304     std::cout << "  " << embedded_computations.size() << " embedded computation"
305               << (embedded_computations.size() != 1 ? "s" : "")
306               << (!embedded_computations.empty() ? ":" : ".") << std::endl;
307     for (const HloComputation* c : embedded_computations) {
308       std::cout << "    " << c->name() << std::endl;
309     }
310 
311     // Find which computations reference comp as an embedded computation.
312     std::vector<const HloComputation*> users;
313     for (const HloComputation* c : module.computations()) {
314       if (absl::c_linear_search(c->MakeEmbeddedComputationsList(), comp)) {
315         users.push_back(c);
316       }
317     }
318     std::cout << "  Used by " << users.size() << " computation"
319               << (users.size() != 1 ? "s" : "") << (!users.empty() ? ":" : ".");
320     for (const HloComputation* c : users) {
321       std::cout << "    " << c->name() << std::endl;
322     }
323   } else {
324     std::cout << "HloInstruction " << instr->name() << std::endl;
325     std::cout << "  Parent computation: " << instr->parent()->name()
326               << std::endl;
327     std::cout << "  Opcode: " << HloOpcodeString(instr->opcode()) << std::endl;
328     std::cout << "  Shape: " << ShapeUtil::HumanStringWithLayout(instr->shape())
329               << std::endl;
330     std::cout << "  Metadata:" << std::endl;
331     if (!instr->metadata().op_name().empty()) {
332       std::cout << "    Name: " << instr->metadata().op_name() << std::endl;
333     }
334     if (!instr->metadata().op_type().empty()) {
335       std::cout << "    Type: " << instr->metadata().op_type() << std::endl;
336     }
337     if (!instr->raw_backend_config_string().empty()) {
338       std::cout << "  Backend configuration: "
339                 << instr->raw_backend_config_string() << std::endl;
340     }
341     if (instr->opcode() == HloOpcode::kFusion) {
342       std::cout << "  Fusion kind: " << xla::ToString(instr->fusion_kind())
343                 << std::endl;
344       std::cout << "  Fusion computation: "
345                 << instr->fused_instructions_computation()->name() << std::endl;
346       std::cout << "  Fused computation root: "
347                 << instr->fused_expression_root()->name() << std::endl;
348     }
349     std::cout << "  Operands:" << std::endl;
350     for (HloInstruction* operand : instr->operands()) {
351       std::cout << "    " << operand->name() << " ("
352                 << ShapeUtil::HumanStringWithLayout(operand->shape()) << ")"
353                 << std::endl;
354     }
355     std::cout << "  Users:" << std::endl;
356     for (HloInstruction* user : instr->users()) {
357       std::cout << "    " << user->name() << std::endl;
358     }
359     if (instr->parent()->root_instruction() == instr) {
360       std::cout << "  Root instruction of " << instr->parent()->name()
361                 << std::endl;
362     }
363   }
364 }
365 
DoExtractCommand(const HloModule & module,absl::Span<const string> tokens)366 void DoExtractCommand(const HloModule& module,
367                       absl::Span<const string> tokens) {
368   if (tokens.size() > 3) {
369     std::cerr << R"(Illegal input.  Enter e.g. "extract %fusion.1 2")"
370               << std::endl;
371     return;
372   }
373 
374   // Find the node with the given name.
375   string node_name = tokens[1];
376   HloInstruction* instr = FindInstruction(module, node_name);
377   if (!instr) {
378     std::cerr << "Couldn't find HloInstruction named " << node_name << "."
379               << std::endl;
380     return;
381   }
382 
383   int64 height = -1;
384   if (tokens.size() == 3) {
385     if (!absl::SimpleAtoi(tokens[2], &height)) {
386       std::cerr << "Can't parse '" << tokens[2] << "' as an integer."
387                 << std::endl;
388       return;
389     }
390   }
391 
392   auto extracted_module = ExtractModule(instr, height);
393   std::cout << extracted_module->ToString(
394                    HloPrintOptions::ShortParsable().set_print_backend_config(
395                        hlo_render_options.show_backend_config))
396             << std::endl;
397 }
398 
399 // Checks if there is a use-def path from `from` to `to`.
ExistsPathFromTo(const HloInstruction * from,const HloInstruction * to)400 bool ExistsPathFromTo(const HloInstruction* from, const HloInstruction* to) {
401   std::unordered_set<const HloInstruction*> visited;
402   std::vector<const HloInstruction*> to_visit = {from};
403   while (!to_visit.empty()) {
404     auto* n = to_visit.back();
405     if (n == to) {
406       return true;
407     }
408     to_visit.pop_back();
409     visited.insert(n);
410     for (auto* user : n->users()) {
411       if (!visited.count(user)) {
412         to_visit.push_back(user);
413       }
414     }
415   }
416   return false;
417 }
418 
OpenUrl(const Options & opts,absl::string_view url)419 void OpenUrl(const Options& opts, absl::string_view url) {
420   std::cout << url << std::endl;
421 
422   // If it is a url, try to open it up in the user's browser too.
423   if (absl::StartsWithIgnoreCase(url, "http://") ||
424       absl::StartsWithIgnoreCase(url, "https://") ||
425       absl::StartsWithIgnoreCase(url, "file://")) {
426     const char* browser_bin = opts.browser.empty() ? "/usr/bin/sensible-browser"
427                                                    : opts.browser.c_str();
428     tensorflow::SubProcess p;
429     p.SetProgram(browser_bin, {browser_bin, string(url)});
430     p.Start();
431   } else {
432     std::cerr << "\nExpected a URL, but got strange graph result (dumped "
433                  "above).  If this isn't what you expected, maybe file a bug?"
434               << std::endl;
435   }
436 }
437 
// Renders a graph by calling `renderer`, and then tries to open it.
//
// `renderer` is a callback so we can try various formats.  In particular, the
// URL format doesn't work out of the box; it requires you to register a plugin.
//
// Fallback chain: URL renderer plugin -> HTML file written to a temp dir and
// opened as a file:// URL.  kDot is deliberately never attempted (see the
// final comment).
void RenderAndDisplayGraph(
    const Options& opts,
    const std::function<StatusOr<string>(RenderedGraphFormat)>& renderer) {
  StatusOr<string> url_result = renderer(RenderedGraphFormat::kUrl);
  if (url_result.ok()) {
    string url = url_result.ValueOrDie();
    OpenUrl(opts, url);
    return;
  }

  // Ignore UNAVAILABLE errors; these are expected when there's no URL renderer
  // plugin registered.
  if (url_result.status().code() != tensorflow::error::UNAVAILABLE) {
    std::cerr << "Unable to render graph as URL: " << url_result.status()
              << std::endl;
    std::cerr << "Trying as HTML..." << std::endl;
  }

  auto* env = tensorflow::Env::Default();
  StatusOr<string> html_result = renderer(RenderedGraphFormat::kHtml);
  if (!html_result.ok()) {
    std::cerr << "Failed to render graph as HTML: " << html_result.status()
              << std::endl;
    return;
  }

  std::vector<string> temp_dirs;
  env->GetLocalTempDirectories(&temp_dirs);
  if (temp_dirs.empty()) {
    std::cerr << "Can't render graph as HTML because we can't find a suitable "
                 "temp directory.  Try setting $TMPDIR?"
              << std::endl;
    return;
  }

  // Try to create a unique file inside of temp_dirs.front().  Notably, this
  // file's name must end with ".html", otherwise web browsers will treat it as
  // plain text, so we can't use Env::CreateUniqueFileName().
  string temp_file_path = tensorflow::io::JoinPath(
      temp_dirs.front(),
      absl::StrFormat("interactive_graphviz.%d.html", env->NowMicros()));
  auto status = tensorflow::WriteStringToFile(
      env, temp_file_path, std::move(html_result).ValueOrDie());
  if (status.ok()) {
    OpenUrl(opts, absl::StrCat("file://", temp_file_path));
    return;
  }

  std::cerr << "Failed to write rendered HTML graph to " << temp_file_path
            << ": " << status;

  // We don't bother trying kDot, because kHTML should always work (or if it
  // doesn't, we don't have any reason to believe kDot will work better).
}
496 
DoAllPathsCommand(const Options & opts,const HloModule & module,const std::vector<string> & tokens)497 void DoAllPathsCommand(const Options& opts, const HloModule& module,
498                        const std::vector<string>& tokens) {
499   if (tokens.size() > 4) {
500     std::cerr << R"(Illegal input.  Enter e.g. "allpaths %add.4 %subtract.2" or
501 "allpaths add.4 subtract.2 42.)"
502               << std::endl;
503     return;
504   }
505 
506   int64 max_nodes = kDefaultMaxNumNodesInAllPaths;
507   if (tokens.size() == 4 && !absl::SimpleAtoi(tokens[3], &max_nodes)) {
508     std::cerr << "Can't parse '" << tokens[3] << "' as an integer."
509               << std::endl;
510     return;
511   }
512 
513   const HloInstruction* n1 = FindInstruction(module, tokens[1]);
514   if (!n1) {
515     std::cerr << "Couldn't find HloInstruction named " << tokens[1];
516     return;
517   }
518   const HloInstruction* n2 = FindInstruction(module, tokens[2]);
519   if (!n2) {
520     std::cerr << "Couldn't find HloInstruction named " << tokens[2];
521     return;
522   }
523 
524   // Is there a path from n1 to n2, or vice versa?
525   const HloInstruction* from;
526   const HloInstruction* to;
527   if (ExistsPathFromTo(n1, n2)) {
528     from = n1;
529     to = n2;
530   } else if (ExistsPathFromTo(n2, n1)) {
531     from = n2;
532     to = n1;
533   } else {
534     std::cerr << "No path from/to " << tokens[1] << " to/from " << tokens[2];
535     return;
536   }
537   RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
538     return RenderAllPathsFromTo(*from, *to, max_nodes, format,
539                                 hlo_render_options);
540   });
541 }
542 
// Plot a given instruction neighborhood or computation with graphviz.
//
// tokens[0] is the instruction or computation name; for instructions, the
// remaining tokens are "[<width>] [/ <boundary_instruction>+]".
void DoPlotCommand(const Options& opts, const HloModule& module,
                   const std::vector<string>& tokens) {
  // tokens is never empty here: the command loop only dispatches non-empty
  // token lists.
  string node_name = tokens[0];

  // Find the node with the given name.
  const HloInstruction* instr = FindInstruction(module, node_name);
  const HloComputation* comp = FindComputation(module, node_name);
  if (!instr && !comp) {
    std::cerr << "Couldn't find HloInstruction or HloComputation named "
              << node_name << "." << std::endl;
    return;
  }

  uint64 graph_width = kDefaultWidth;
  absl::flat_hash_set<const HloInstruction*> boundary;
  if (tokens.size() >= 2) {
    // Width/boundary arguments only make sense for an instruction plot.
    if (comp) {
      std::cerr << "Can only use graph-size parameter with instructions, but "
                << node_name << " is a computation." << std::endl;
      return;
    }

    // NOTE(review): `bound_index` is an int compared against tokens.size()
    // (size_t) below — benign for realistic token counts, but a size_t would
    // be cleaner.
    int bound_index = 1;
    // Get the <width> if present.
    if (absl::SimpleAtoi(tokens[bound_index], &graph_width)) {
      bound_index++;
    } else {
      // <width> not found, need to reset graph_width.
      graph_width = kDefaultWidth;
    }
    // Get the '/'.
    if (bound_index < tokens.size()) {
      // This token must be a '/'.
      if (tokens[bound_index] != "/") {
        std::cerr << "Expect a /, but get a '" << tokens[bound_index] << "'."
                  << std::endl;
        return;
      }
      bound_index++;
    }
    // Get the boundary nodes.
    while (bound_index < tokens.size()) {
      string bnode_name = tokens[bound_index];
      const HloInstruction* binstr = FindInstruction(module, bnode_name);
      if (!binstr) {
        std::cerr << "Couldn't find HloInstruction named " << bnode_name << "."
                  << std::endl;
        return;
      }
      boundary.insert(binstr);
      bound_index++;
    }
  }

  // Generate the graph and print the resulting string, which should be a
  // graphviz url.
  if (comp) {
    RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
      return RenderGraph(*comp, /*label=*/"",
                         comp->parent()->config().debug_options(), format,
                         /*hlo_execution_profile=*/nullptr, hlo_render_options);
    });
  } else {
    RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
      return RenderNeighborhoodAround(*instr, graph_width, format,
                                      hlo_render_options,
                                      /*boundary=*/boundary);
    });
  }
}
614 
// Run the main event loop, reading user commands and processing them.
//
// Known command words are dispatched to their handlers; any other non-empty
// input is treated as an instruction/computation name to plot.  The loop ends
// on "quit"/"exit" or when ReadLine reports end of input (^D).
void InteractiveDumpGraphs(const Options& opts, const HloModule& module) {
  // This is an interactive tool, but some may use `extract` in non-tty
  // environment anyway. Give them a clean hlo dump.
  if (isatty(fileno(stdin))) {
    std::cout << "\n\nLoaded module " << module.name() << "." << std::endl;
    DoHelpCommand();
  }
  for (string line; ReadLine("\ncommand: ", &line);) {
    if (line.empty()) {
      std::cout << R"(Enter e.g. "fusion.1 3" or "add.8".)" << std::endl
                << R"(Enter "help" for help; ^D, "quit", or "exit" to exit.)"
                << std::endl;
      continue;
    }
    // Whitespace-split; SkipEmpty guarantees tokens is non-empty here.
    std::vector<string> tokens = absl::StrSplit(line, ' ', absl::SkipEmpty());
    if (tokens[0] == "quit" || tokens[0] == "exit") {
      break;
    } else if (tokens[0] == "help") {
      DoHelpCommand();
    } else if (tokens[0] == "backend_config") {
      DoBackendConfigCommand(tokens);
    } else if (tokens[0] == "show_fusion_subcomputations") {
      DoShowFusionSubcomputationsCommand(tokens);
    } else if (tokens[0] == "list") {
      // "list computations" is a distinct subcommand from "list <pattern>".
      if (tokens.size() > 1 && tokens[1] == "computations") {
        DoListComputationsCommand(module, tokens);
      } else {
        DoListCommand(module, tokens);
      }
    } else if (tokens[0] == "info") {
      DoInfoCommand(module, tokens);
    } else if (tokens[0] == "extract") {
      DoExtractCommand(module, tokens);
    } else if (tokens[0] == "allpaths") {
      DoAllPathsCommand(opts, module, tokens);
    } else {
      // Fallback: treat the whole line as a plot request.
      DoPlotCommand(opts, module, tokens);
    }
  }
}
656 
657 void CheckFlags(const Options& opts) {
658   int nonempty_flags_amount = 0;
659   if (!opts.hlo_proto.empty()) {
660     ++nonempty_flags_amount;
661   }
662   if (!opts.hlo_snapshot.empty()) {
663     ++nonempty_flags_amount;
664   }
665   if (!opts.hlo_text.empty()) {
666     ++nonempty_flags_amount;
667   }
668   if (!opts.hlo_module_proto.empty()) {
669     ++nonempty_flags_amount;
670   }
671   if (nonempty_flags_amount == 1) {
672     return;
673   }
674   LOG(FATAL) << "Can only specify one and only one of '--hlo_proto', "
675                 "'--hlo_snapshot', '--hlo_text', '--hlo_module_proto' flags.";
676 }
677 
// Loads the HLO module from whichever input flag was given, optionally
// compiles it for --platform, and enters the interactive command loop.
void RealMain(const Options& opts) {
  if (!isatty(fileno(stdin))) {
    LOG(ERROR) << "\n\n*****************************************\n"
               << "This is an interactive tool, but stdin is not a tty.\n"
               << "*****************************************\n\n";
  }

  // Aborts unless exactly one input flag is set, so one of the branches
  // below is guaranteed to populate `module`.
  CheckFlags(opts);

  std::unique_ptr<HloModule> module;
  if (!opts.hlo_snapshot.empty()) {
    HloSnapshot snapshot;
    TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
                                            opts.hlo_snapshot, &snapshot))
        << "Can't open, read, or parse HloSnapshot proto at "
        << opts.hlo_snapshot;
    auto config =
        HloModule::CreateModuleConfigFromProto(snapshot.hlo().hlo_module(),
                                               xla::GetDebugOptionsFromFlags())
            .ValueOrDie();
    module = HloModule::CreateFromProto(snapshot.hlo().hlo_module(), config)
                 .ValueOrDie();
  } else if (!opts.hlo_proto.empty()) {
    module = HloRunner::ReadModuleFromBinaryProtoFile(
                 opts.hlo_proto, xla::GetDebugOptionsFromFlags())
                 .ValueOrDie();
  } else if (!opts.hlo_text.empty()) {
    module = HloRunner::ReadModuleFromHloTextFile(
                 opts.hlo_text, xla::GetDebugOptionsFromFlags())
                 .ValueOrDie();
  } else if (!opts.hlo_module_proto.empty()) {
    module = HloRunner::ReadModuleFromModuleBinaryProtofile(
                 opts.hlo_module_proto, xla::GetDebugOptionsFromFlags())
                 .ValueOrDie();
  }

  // If a platform was specified, compile the module for that platform.
  if (!opts.platform.empty()) {
    se::Platform* platform =
        PlatformUtil::GetPlatform(opts.platform).ValueOrDie();
    LOG(INFO) << "Compiling module for " << platform->Name();

    se::StreamExecutor* executor =
        platform->ExecutorForDevice(/*ordinal=*/0).ValueOrDie();
    auto compiler = Compiler::GetForPlatform(platform).ValueOrDie();
    module = compiler
                 ->RunHloPasses(std::move(module), executor,
                                /*device_allocator=*/nullptr)
                 .ValueOrDie();
    // RunBackend consumes `module`; from here on the executable owns the
    // (compiled) module we explore.
    auto executable = compiler
                          ->RunBackend(std::move(module), executor,
                                       /*device_allocator=*/nullptr)
                          .ValueOrDie();
    InteractiveDumpGraphs(opts, executable->module());
  } else {
    InteractiveDumpGraphs(opts, *module);
  }
}
736 
737 }  // namespace
738 }  // namespace tools
739 }  // namespace xla
740 
// Parses command-line flags, initializes TensorFlow, and hands off to
// RealMain.
int main(int argc, char** argv) {
  xla::tools::Options opts;
  // Default browser; overridable with --browser.
  opts.browser = "/usr/bin/sensible-browser";
  bool need_help = false;
  const std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("hlo_snapshot", &opts.hlo_snapshot,
                       "HloSnapshot proto to interactively dump to graphviz"),
      tensorflow::Flag("hlo_proto", &opts.hlo_proto,
                       "XLA hlo proto to interactively dump to graphviz"),
      tensorflow::Flag("hlo_module_proto", &opts.hlo_module_proto,
                       "XLA hlomodule proto to interactively dump to graphviz"),
      tensorflow::Flag("hlo_text", &opts.hlo_text,
                       "XLA hlo proto to interactively dump to graphviz"),
      tensorflow::Flag("platform", &opts.platform,
                       "Platform to compile for: CPU, CUDA, etc"),
      tensorflow::Flag("browser", &opts.browser,
                       "Path to web browser used to display produced graphs."),
      tensorflow::Flag("help", &need_help, "Prints this help message"),
  };
  xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  // Leftover positional args, a parse failure, or --help all print usage and
  // exit (QFATAL logs fatally without a stack trace).
  if (argc != 1 || !parse_ok || need_help) {
    LOG(QFATAL) << usage;
  }
  xla::tools::RealMain(opts);
  return 0;
}
769