1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // A tool for interactively exploring graphviz dumps of HLO graphs.
17 //
// Input can be a binary HloSnapshot proto, a binary HloProto proto, a binary
// HloModuleProto proto, or a file containing textual HLO.
//
// Generated visualization is opened in a new browser window, using
// /usr/bin/sensible-browser by default (override with --browser).
23
24 #include <stdio.h>
25 #include <unistd.h>
26
27 #include "absl/algorithm/container.h"
28 #include "absl/strings/match.h"
29 #include "absl/strings/numbers.h"
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/strings/str_split.h"
33 #include "tensorflow/compiler/xla/client/client_library.h"
34 #include "tensorflow/compiler/xla/client/local_client.h"
35 #include "tensorflow/compiler/xla/service/compiler.h"
36 #include "tensorflow/compiler/xla/service/hlo.pb.h"
37 #include "tensorflow/compiler/xla/service/hlo_runner.h"
38 #include "tensorflow/compiler/xla/service/local_service.h"
39 #include "tensorflow/compiler/xla/service/platform_util.h"
40 #include "tensorflow/compiler/xla/tools/hlo_extractor.h"
41 #include "tensorflow/core/lib/io/path.h"
42 #include "tensorflow/core/platform/init_main.h"
43 #include "tensorflow/core/platform/logging.h"
44 #include "tensorflow/core/platform/subprocess.h"
45 #include "tensorflow/core/protobuf/error_codes.pb.h"
46 #include "tensorflow/core/util/command_line_flags.h"
47 #if defined(PLATFORM_GOOGLE)
48 #include "util/readline/readline.h"
49 #endif
50
51 #if defined(PLATFORM_WINDOWS)
52 #include <io.h>
53 #define isatty _isatty
54 #endif
55
56 namespace xla {
57 namespace tools {
58 namespace {
59
ReadLine(const char * prompt,string * line)60 bool ReadLine(const char* prompt, string* line) {
61 #if defined(PLATFORM_GOOGLE)
62 return util::ReadLine(prompt, line);
63 #else
64 std::cout << prompt;
65 std::getline(std::cin, *line);
66 return std::cin.good();
67 #endif
68 }
69
// Command-line options for this tool, filled in from the flags declared in
// main(). CheckFlags() requires exactly one of hlo_snapshot, hlo_proto,
// hlo_module_proto and hlo_text to be non-empty.
struct Options {
  // Path to a binary HloSnapshot proto (--hlo_snapshot).
  string hlo_snapshot;
  // Path to a binary HloProto (--hlo_proto).
  string hlo_proto;
  // Path to a binary HloModuleProto (--hlo_module_proto).
  string hlo_module_proto;
  // Path to a textual HLO file (--hlo_text).
  string hlo_text;
  // Optional platform name (--platform, e.g. CPU or CUDA). When non-empty,
  // the module is compiled for this platform before being explored.
  string platform;
  // Browser binary used to open rendered graphs (--browser).
  string browser;
};
80
// Long-form description of the tool and how to invoke it.
const char* const kUsage = R"(
This tool lets you load an XLA dump and then interactively explore its graphical
representation.

Most models are too large to visualize in their entirety using graphviz, but
it's still useful to be able to look at the nodes "near" a particular node of
interest.

If you pass --platform, this tool will compile the HloModule for the given
platform. This means that if you acquired your proto from a binary running at a
particular CL, the HLO graph it ran isn't necessarily the same as the one shown
here, unless this program was built at the same CL (and our compiler is
deterministic :).

Be patient when starting this program with --platform and a large input; it has
to compile the whole thing.

Usage:

  interactive_graphviz -- \
    --{hlo_snapshot,hlo_proto,hlo_module_proto,hlo_text}=path/to/input \
    --platform={CUDA,CPU,...} \
    [--browser=path/to/browser]
)";
104
// Unless an explicit width is specified, we will render a neighborhood of
// kDefaultWidth nodes around the requested instruction.
constexpr int64 kDefaultWidth = 2;

// When printing all paths between two nodes, we print out only this many nodes
// by default, truncating the graph if there are more nodes than this in the
// all-paths set.
constexpr int64 kDefaultMaxNumNodesInAllPaths = 100;

// Instruction and computation names are matched case-insensitively.
using absl::EqualsIgnoreCase;

// Rendering options shared by every graph-producing command. This is mutable
// global state: the "backend_config" and "show_fusion_subcomputations"
// commands change it, and subsequent renders use the updated values.
HloRenderOptions hlo_render_options;
117
FindInstruction(const HloModule & module,string node_name)118 HloInstruction* FindInstruction(const HloModule& module, string node_name) {
119 if (absl::StartsWith(node_name, "%")) {
120 node_name.erase(node_name.begin());
121 }
122 for (const auto& computation : module.computations()) {
123 auto instrs = computation->instructions();
124 auto it = absl::c_find_if(instrs, [&](const HloInstruction* instr) {
125 // Try with and without "%" at the beginning of the node name.
126 return EqualsIgnoreCase(instr->name(), node_name) ||
127 EqualsIgnoreCase(instr->name(), absl::StrCat("%", node_name));
128 });
129 if (it != instrs.end()) {
130 return *it;
131 }
132 }
133 return nullptr;
134 }
135
FindComputation(const HloModule & module,const string & comp_name)136 HloComputation* FindComputation(const HloModule& module,
137 const string& comp_name) {
138 for (auto* computation : module.computations()) {
139 if (EqualsIgnoreCase(computation->name(), comp_name)) {
140 return computation;
141 }
142 }
143 return nullptr;
144 }
145
146 // Print a help message describing the various available commands.
DoHelpCommand()147 void DoHelpCommand() {
148 std::cout << R"(Commands:
149 <instruction> [<width>] [/ <boundary_instruction>+]
150 Renders a neighborhood of <width> nodes around <instruction>, without going
151 beyond the optional boundary instructions. If <width> is not provided,
152 the default value is )"
153 << kDefaultWidth << R"(.
154 allpaths <instruction> <instruction> [<n>]
155 Renders a subset of all paths from one instruction to the other. Either
156 order of nodes is accepted. Shows the <n> nodes in the all-paths set on the
157 shortest paths; default is )"
158 << kDefaultMaxNumNodesInAllPaths << R"(.
159 <computation>
160 Renders all nodes in <computation>.
161 backend_config [on|off]
162 Controls whether backend operation configuration information is printed.
163 show_fusion_subcomputations [on|off]
164 Controls whether fusion subcomputations are shown.
165 list [name|op_name|op_type] <pattern>
166 Lists all instructions whose name, metadata op_name, or metadata op_type
167 contains <pattern> as a substring.
168 list computations
169 Lists all computations in the module.
170 info <instruction>
171 info <computation>
172 Prints information about <instruction> or <computation>.
173 extract <instruction> <height>
174 Creates a new HLO module with <instruction> as entry computation root. If
175 <height> is specified, the new computation contains nodes up to <height>
176 nodes above the root.
177 help
178 Prints this usage information.
179 quit
180 Exit the application.)"
181 << std::endl;
182 }
183
184 // Turn metadata-printing on or off.
DoBackendConfigCommand(const std::vector<string> & tokens)185 void DoBackendConfigCommand(const std::vector<string>& tokens) {
186 if (tokens.size() == 2 && tokens[1] == "on") {
187 hlo_render_options.show_backend_config = true;
188 } else if (tokens.size() == 2 && tokens[1] == "off") {
189 hlo_render_options.show_backend_config = false;
190 } else if (tokens.size() != 1) {
191 std::cerr << "(Illegal backend_config value. Use either 'on' or 'off'.)"
192 << std::endl;
193 }
194 std::cout << "Backend configuration display "
195 << (hlo_render_options.show_backend_config ? "ON" : "OFF")
196 << std::endl;
197 }
198
199 // Turn fusion computation display on or off.
DoShowFusionSubcomputationsCommand(const std::vector<string> & tokens)200 void DoShowFusionSubcomputationsCommand(const std::vector<string>& tokens) {
201 if (tokens.size() == 2 && tokens[1] == "on") {
202 hlo_render_options.show_fusion_subcomputations = true;
203 } else if (tokens.size() == 2 && tokens[1] == "off") {
204 hlo_render_options.show_fusion_subcomputations = false;
205 } else if (tokens.size() != 1) {
206 std::cerr << "(Illegal show_fusion_subcomputations value. Use either "
207 "'on' or 'off'.)"
208 << std::endl;
209 }
210 std::cout << "Fusion subcomputations display "
211 << (hlo_render_options.show_fusion_subcomputations ? "ON" : "OFF")
212 << std::endl;
213 }
214
215 // List all computations in the module.
DoListComputationsCommand(const HloModule & module,const std::vector<string> & tokens)216 void DoListComputationsCommand(const HloModule& module,
217 const std::vector<string>& tokens) {
218 if (tokens.size() > 2) {
219 std::cout << R"(Illegal syntax; "list computations" takes no arguments.)";
220 return;
221 }
222 if (module.entry_computation() != nullptr) {
223 std::cout << "Entry computation:" << std::endl;
224 std::cout << " " << module.entry_computation()->name() << std::endl
225 << std::endl;
226 }
227 std::cout << "Subcomputations:" << std::endl;
228 std::vector<string> names;
229 for (const auto& computation : module.computations()) {
230 if (computation == module.entry_computation()) {
231 continue;
232 }
233 std::cout << " " << computation->name() << std::endl;
234 }
235 }
236
237 // List all instructions matching a pattern.
DoListCommand(const HloModule & module,const std::vector<string> & tokens)238 void DoListCommand(const HloModule& module, const std::vector<string>& tokens) {
239 string pattern = "";
240 string type = "name";
241 if (tokens.size() == 2) {
242 pattern = tokens[1];
243 } else if (tokens.size() == 3) {
244 type = tokens[1];
245 pattern = tokens[2];
246 } else {
247 std::cout << "Illegal list query syntax. Use "
248 << R"("list [name|op_name|op_type] pattern".)" << std::endl;
249 return;
250 }
251
252 std::cout << "Query results:" << std::endl;
253 for (const auto& computation : module.computations()) {
254 for (const auto& instr : computation->instructions()) {
255 if ((type == "name" && instr->name().find(pattern) != string::npos) ||
256 (type == "op_name" &&
257 instr->metadata().op_name().find(pattern) != string::npos) ||
258 (type == "op_type" &&
259 instr->metadata().op_type().find(pattern) != string::npos)) {
260 std::cout << " " << instr->name();
261 std::cout << ", op_name '" << instr->metadata().op_name() << "'";
262 std::cout << ", op_type '" << instr->metadata().op_type() << "'";
263 std::cout << std::endl;
264 }
265 }
266 }
267 }
268
269 // Print info about an instruction or computation.
DoInfoCommand(const HloModule & module,const std::vector<string> & tokens)270 void DoInfoCommand(const HloModule& module, const std::vector<string>& tokens) {
271 if (tokens.size() != 2) {
272 std::cerr << "Illegal info query syntax. Use "
273 << R"("info name".)";
274 return;
275 }
276 string node_name = tokens[1];
277
278 const HloInstruction* instr = FindInstruction(module, node_name);
279 const HloComputation* comp = FindComputation(module, node_name);
280 if (!instr && !comp) {
281 std::cerr << "Couldn't find HloInstruction or HloComputation named "
282 << node_name << std::endl;
283 return;
284 }
285
286 if (comp != nullptr) {
287 std::cout << "HloComputation " << comp->name() << std::endl;
288 if (comp->IsFusionComputation()) {
289 std::cout << " Fusion instruction: " << comp->FusionInstruction()->name()
290 << std::endl;
291 }
292 std::cout << " Parameters:" << std::endl;
293 for (const auto& param : comp->parameter_instructions()) {
294 std::cout << " " << param->name() << " ("
295 << ShapeUtil::HumanStringWithLayout(param->shape()) << ")"
296 << std::endl;
297 }
298 HloInstruction* root = comp->root_instruction();
299 std::cout << " Root instruction: " << root->name() << " ("
300 << ShapeUtil::HumanStringWithLayout(root->shape()) << ")"
301 << std::endl;
302
303 auto embedded_computations = comp->MakeEmbeddedComputationsList();
304 std::cout << " " << embedded_computations.size() << " embedded computation"
305 << (embedded_computations.size() != 1 ? "s" : "")
306 << (!embedded_computations.empty() ? ":" : ".") << std::endl;
307 for (const HloComputation* c : embedded_computations) {
308 std::cout << " " << c->name() << std::endl;
309 }
310
311 // Find which computations reference comp as an embedded computation.
312 std::vector<const HloComputation*> users;
313 for (const HloComputation* c : module.computations()) {
314 if (absl::c_linear_search(c->MakeEmbeddedComputationsList(), comp)) {
315 users.push_back(c);
316 }
317 }
318 std::cout << " Used by " << users.size() << " computation"
319 << (users.size() != 1 ? "s" : "") << (!users.empty() ? ":" : ".");
320 for (const HloComputation* c : users) {
321 std::cout << " " << c->name() << std::endl;
322 }
323 } else {
324 std::cout << "HloInstruction " << instr->name() << std::endl;
325 std::cout << " Parent computation: " << instr->parent()->name()
326 << std::endl;
327 std::cout << " Opcode: " << HloOpcodeString(instr->opcode()) << std::endl;
328 std::cout << " Shape: " << ShapeUtil::HumanStringWithLayout(instr->shape())
329 << std::endl;
330 std::cout << " Metadata:" << std::endl;
331 if (!instr->metadata().op_name().empty()) {
332 std::cout << " Name: " << instr->metadata().op_name() << std::endl;
333 }
334 if (!instr->metadata().op_type().empty()) {
335 std::cout << " Type: " << instr->metadata().op_type() << std::endl;
336 }
337 if (!instr->raw_backend_config_string().empty()) {
338 std::cout << " Backend configuration: "
339 << instr->raw_backend_config_string() << std::endl;
340 }
341 if (instr->opcode() == HloOpcode::kFusion) {
342 std::cout << " Fusion kind: " << xla::ToString(instr->fusion_kind())
343 << std::endl;
344 std::cout << " Fusion computation: "
345 << instr->fused_instructions_computation()->name() << std::endl;
346 std::cout << " Fused computation root: "
347 << instr->fused_expression_root()->name() << std::endl;
348 }
349 std::cout << " Operands:" << std::endl;
350 for (HloInstruction* operand : instr->operands()) {
351 std::cout << " " << operand->name() << " ("
352 << ShapeUtil::HumanStringWithLayout(operand->shape()) << ")"
353 << std::endl;
354 }
355 std::cout << " Users:" << std::endl;
356 for (HloInstruction* user : instr->users()) {
357 std::cout << " " << user->name() << std::endl;
358 }
359 if (instr->parent()->root_instruction() == instr) {
360 std::cout << " Root instruction of " << instr->parent()->name()
361 << std::endl;
362 }
363 }
364 }
365
DoExtractCommand(const HloModule & module,absl::Span<const string> tokens)366 void DoExtractCommand(const HloModule& module,
367 absl::Span<const string> tokens) {
368 if (tokens.size() > 3) {
369 std::cerr << R"(Illegal input. Enter e.g. "extract %fusion.1 2")"
370 << std::endl;
371 return;
372 }
373
374 // Find the node with the given name.
375 string node_name = tokens[1];
376 HloInstruction* instr = FindInstruction(module, node_name);
377 if (!instr) {
378 std::cerr << "Couldn't find HloInstruction named " << node_name << "."
379 << std::endl;
380 return;
381 }
382
383 int64 height = -1;
384 if (tokens.size() == 3) {
385 if (!absl::SimpleAtoi(tokens[2], &height)) {
386 std::cerr << "Can't parse '" << tokens[2] << "' as an integer."
387 << std::endl;
388 return;
389 }
390 }
391
392 auto extracted_module = ExtractModule(instr, height);
393 std::cout << extracted_module->ToString(
394 HloPrintOptions::ShortParsable().set_print_backend_config(
395 hlo_render_options.show_backend_config))
396 << std::endl;
397 }
398
399 // Checks if there is a use-def path from `from` to `to`.
ExistsPathFromTo(const HloInstruction * from,const HloInstruction * to)400 bool ExistsPathFromTo(const HloInstruction* from, const HloInstruction* to) {
401 std::unordered_set<const HloInstruction*> visited;
402 std::vector<const HloInstruction*> to_visit = {from};
403 while (!to_visit.empty()) {
404 auto* n = to_visit.back();
405 if (n == to) {
406 return true;
407 }
408 to_visit.pop_back();
409 visited.insert(n);
410 for (auto* user : n->users()) {
411 if (!visited.count(user)) {
412 to_visit.push_back(user);
413 }
414 }
415 }
416 return false;
417 }
418
OpenUrl(const Options & opts,absl::string_view url)419 void OpenUrl(const Options& opts, absl::string_view url) {
420 std::cout << url << std::endl;
421
422 // If it is a url, try to open it up in the user's browser too.
423 if (absl::StartsWithIgnoreCase(url, "http://") ||
424 absl::StartsWithIgnoreCase(url, "https://") ||
425 absl::StartsWithIgnoreCase(url, "file://")) {
426 const char* browser_bin = opts.browser.empty() ? "/usr/bin/sensible-browser"
427 : opts.browser.c_str();
428 tensorflow::SubProcess p;
429 p.SetProgram(browser_bin, {browser_bin, string(url)});
430 p.Start();
431 } else {
432 std::cerr << "\nExpected a URL, but got strange graph result (dumped "
433 "above). If this isn't what you expected, maybe file a bug?"
434 << std::endl;
435 }
436 }
437
438 // Renders a graph by calling `renderer`, and then tries to open it.
439 //
440 // `renderer` is a callback so we can try various formats. In particular, the
441 // URL format doesn't work out of the box; it requires you to register a plugin.
RenderAndDisplayGraph(const Options & opts,const std::function<StatusOr<string> (RenderedGraphFormat)> & renderer)442 void RenderAndDisplayGraph(
443 const Options& opts,
444 const std::function<StatusOr<string>(RenderedGraphFormat)>& renderer) {
445 StatusOr<string> url_result = renderer(RenderedGraphFormat::kUrl);
446 if (url_result.ok()) {
447 string url = url_result.ValueOrDie();
448 OpenUrl(opts, url);
449 return;
450 }
451
452 // Ignore UNAVAILABLE errors; these are expected when there's no URL renderer
453 // plugin registered.
454 if (url_result.status().code() != tensorflow::error::UNAVAILABLE) {
455 std::cerr << "Unable to render graph as URL: " << url_result.status()
456 << std::endl;
457 std::cerr << "Trying as HTML..." << std::endl;
458 }
459
460 auto* env = tensorflow::Env::Default();
461 StatusOr<string> html_result = renderer(RenderedGraphFormat::kHtml);
462 if (!html_result.ok()) {
463 std::cerr << "Failed to render graph as HTML: " << html_result.status()
464 << std::endl;
465 return;
466 }
467
468 std::vector<string> temp_dirs;
469 env->GetLocalTempDirectories(&temp_dirs);
470 if (temp_dirs.empty()) {
471 std::cerr << "Can't render graph as HTML because we can't find a suitable "
472 "temp directory. Try setting $TMPDIR?"
473 << std::endl;
474 return;
475 }
476
477 // Try to create a unique file inside of temp_dirs.front(). Notably, this
478 // file's name must end with ".html", otherwise web browsers will treat it as
479 // plain text, so we can't use Env::CreateUniqueFileName().
480 string temp_file_path = tensorflow::io::JoinPath(
481 temp_dirs.front(),
482 absl::StrFormat("interactive_graphviz.%d.html", env->NowMicros()));
483 auto status = tensorflow::WriteStringToFile(
484 env, temp_file_path, std::move(html_result).ValueOrDie());
485 if (status.ok()) {
486 OpenUrl(opts, absl::StrCat("file://", temp_file_path));
487 return;
488 }
489
490 std::cerr << "Failed to write rendered HTML graph to " << temp_file_path
491 << ": " << status;
492
493 // We don't bother trying kDot, because kHTML should always work (or if it
494 // doesn't, we don't have any reason to believe kDot will work better).
495 }
496
DoAllPathsCommand(const Options & opts,const HloModule & module,const std::vector<string> & tokens)497 void DoAllPathsCommand(const Options& opts, const HloModule& module,
498 const std::vector<string>& tokens) {
499 if (tokens.size() > 4) {
500 std::cerr << R"(Illegal input. Enter e.g. "allpaths %add.4 %subtract.2" or
501 "allpaths add.4 subtract.2 42.)"
502 << std::endl;
503 return;
504 }
505
506 int64 max_nodes = kDefaultMaxNumNodesInAllPaths;
507 if (tokens.size() == 4 && !absl::SimpleAtoi(tokens[3], &max_nodes)) {
508 std::cerr << "Can't parse '" << tokens[3] << "' as an integer."
509 << std::endl;
510 return;
511 }
512
513 const HloInstruction* n1 = FindInstruction(module, tokens[1]);
514 if (!n1) {
515 std::cerr << "Couldn't find HloInstruction named " << tokens[1];
516 return;
517 }
518 const HloInstruction* n2 = FindInstruction(module, tokens[2]);
519 if (!n2) {
520 std::cerr << "Couldn't find HloInstruction named " << tokens[2];
521 return;
522 }
523
524 // Is there a path from n1 to n2, or vice versa?
525 const HloInstruction* from;
526 const HloInstruction* to;
527 if (ExistsPathFromTo(n1, n2)) {
528 from = n1;
529 to = n2;
530 } else if (ExistsPathFromTo(n2, n1)) {
531 from = n2;
532 to = n1;
533 } else {
534 std::cerr << "No path from/to " << tokens[1] << " to/from " << tokens[2];
535 return;
536 }
537 RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
538 return RenderAllPathsFromTo(*from, *to, max_nodes, format,
539 hlo_render_options);
540 });
541 }
542
543 // Plot a given instruction neighborhood or computation with graphviz.
544 void DoPlotCommand(const Options& opts, const HloModule& module,
545 const std::vector<string>& tokens) {
546 string node_name = tokens[0];
547
548 // Find the node with the given name.
549 const HloInstruction* instr = FindInstruction(module, node_name);
550 const HloComputation* comp = FindComputation(module, node_name);
551 if (!instr && !comp) {
552 std::cerr << "Couldn't find HloInstruction or HloComputation named "
553 << node_name << "." << std::endl;
554 return;
555 }
556
557 uint64 graph_width = kDefaultWidth;
558 absl::flat_hash_set<const HloInstruction*> boundary;
559 if (tokens.size() >= 2) {
560 if (comp) {
561 std::cerr << "Can only use graph-size parameter with instructions, but "
562 << node_name << " is a computation." << std::endl;
563 return;
564 }
565
566 int bound_index = 1;
567 // Get the <width> if present.
568 if (absl::SimpleAtoi(tokens[bound_index], &graph_width)) {
569 bound_index++;
570 } else {
571 // <width> not found, need to reset graph_width.
572 graph_width = kDefaultWidth;
573 }
574 // Get the '/'.
575 if (bound_index < tokens.size()) {
576 // This token must be a '/'.
577 if (tokens[bound_index] != "/") {
578 std::cerr << "Expect a /, but get a '" << tokens[bound_index] << "'."
579 << std::endl;
580 return;
581 }
582 bound_index++;
583 }
584 // Get the boundary nodes.
585 while (bound_index < tokens.size()) {
586 string bnode_name = tokens[bound_index];
587 const HloInstruction* binstr = FindInstruction(module, bnode_name);
588 if (!binstr) {
589 std::cerr << "Couldn't find HloInstruction named " << bnode_name << "."
590 << std::endl;
591 return;
592 }
593 boundary.insert(binstr);
594 bound_index++;
595 }
596 }
597
598 // Generate the graph and print the resulting string, which should be a
599 // graphviz url.
600 if (comp) {
601 RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
602 return RenderGraph(*comp, /*label=*/"",
603 comp->parent()->config().debug_options(), format,
604 /*hlo_execution_profile=*/nullptr, hlo_render_options);
605 });
606 } else {
607 RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
608 return RenderNeighborhoodAround(*instr, graph_width, format,
609 hlo_render_options,
610 /*boundary=*/boundary);
611 });
612 }
613 }
614
615 // Run the main event loop, reading user commands and processing them.
616 void InteractiveDumpGraphs(const Options& opts, const HloModule& module) {
617 // This is an interactive tool, but some may use `extract` in non-tty
618 // environment anyway. Give them a clean hlo dump.
619 if (isatty(fileno(stdin))) {
620 std::cout << "\n\nLoaded module " << module.name() << "." << std::endl;
621 DoHelpCommand();
622 }
623 for (string line; ReadLine("\ncommand: ", &line);) {
624 if (line.empty()) {
625 std::cout << R"(Enter e.g. "fusion.1 3" or "add.8".)" << std::endl
626 << R"(Enter "help" for help; ^D, "quit", or "exit" to exit.)"
627 << std::endl;
628 continue;
629 }
630 std::vector<string> tokens = absl::StrSplit(line, ' ', absl::SkipEmpty());
631 if (tokens[0] == "quit" || tokens[0] == "exit") {
632 break;
633 } else if (tokens[0] == "help") {
634 DoHelpCommand();
635 } else if (tokens[0] == "backend_config") {
636 DoBackendConfigCommand(tokens);
637 } else if (tokens[0] == "show_fusion_subcomputations") {
638 DoShowFusionSubcomputationsCommand(tokens);
639 } else if (tokens[0] == "list") {
640 if (tokens.size() > 1 && tokens[1] == "computations") {
641 DoListComputationsCommand(module, tokens);
642 } else {
643 DoListCommand(module, tokens);
644 }
645 } else if (tokens[0] == "info") {
646 DoInfoCommand(module, tokens);
647 } else if (tokens[0] == "extract") {
648 DoExtractCommand(module, tokens);
649 } else if (tokens[0] == "allpaths") {
650 DoAllPathsCommand(opts, module, tokens);
651 } else {
652 DoPlotCommand(opts, module, tokens);
653 }
654 }
655 }
656
657 void CheckFlags(const Options& opts) {
658 int nonempty_flags_amount = 0;
659 if (!opts.hlo_proto.empty()) {
660 ++nonempty_flags_amount;
661 }
662 if (!opts.hlo_snapshot.empty()) {
663 ++nonempty_flags_amount;
664 }
665 if (!opts.hlo_text.empty()) {
666 ++nonempty_flags_amount;
667 }
668 if (!opts.hlo_module_proto.empty()) {
669 ++nonempty_flags_amount;
670 }
671 if (nonempty_flags_amount == 1) {
672 return;
673 }
674 LOG(FATAL) << "Can only specify one and only one of '--hlo_proto', "
675 "'--hlo_snapshot', '--hlo_text', '--hlo_module_proto' flags.";
676 }
677
678 void RealMain(const Options& opts) {
679 if (!isatty(fileno(stdin))) {
680 LOG(ERROR) << "\n\n*****************************************\n"
681 << "This is an interactive tool, but stdin is not a tty.\n"
682 << "*****************************************\n\n";
683 }
684
685 CheckFlags(opts);
686
687 std::unique_ptr<HloModule> module;
688 if (!opts.hlo_snapshot.empty()) {
689 HloSnapshot snapshot;
690 TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
691 opts.hlo_snapshot, &snapshot))
692 << "Can't open, read, or parse HloSnapshot proto at "
693 << opts.hlo_snapshot;
694 auto config =
695 HloModule::CreateModuleConfigFromProto(snapshot.hlo().hlo_module(),
696 xla::GetDebugOptionsFromFlags())
697 .ValueOrDie();
698 module = HloModule::CreateFromProto(snapshot.hlo().hlo_module(), config)
699 .ValueOrDie();
700 } else if (!opts.hlo_proto.empty()) {
701 module = HloRunner::ReadModuleFromBinaryProtoFile(
702 opts.hlo_proto, xla::GetDebugOptionsFromFlags())
703 .ValueOrDie();
704 } else if (!opts.hlo_text.empty()) {
705 module = HloRunner::ReadModuleFromHloTextFile(
706 opts.hlo_text, xla::GetDebugOptionsFromFlags())
707 .ValueOrDie();
708 } else if (!opts.hlo_module_proto.empty()) {
709 module = HloRunner::ReadModuleFromModuleBinaryProtofile(
710 opts.hlo_module_proto, xla::GetDebugOptionsFromFlags())
711 .ValueOrDie();
712 }
713
714 // If a platform was specified, compile the module for that platform.
715 if (!opts.platform.empty()) {
716 se::Platform* platform =
717 PlatformUtil::GetPlatform(opts.platform).ValueOrDie();
718 LOG(INFO) << "Compiling module for " << platform->Name();
719
720 se::StreamExecutor* executor =
721 platform->ExecutorForDevice(/*ordinal=*/0).ValueOrDie();
722 auto compiler = Compiler::GetForPlatform(platform).ValueOrDie();
723 module = compiler
724 ->RunHloPasses(std::move(module), executor,
725 /*device_allocator=*/nullptr)
726 .ValueOrDie();
727 auto executable = compiler
728 ->RunBackend(std::move(module), executor,
729 /*device_allocator=*/nullptr)
730 .ValueOrDie();
731 InteractiveDumpGraphs(opts, executable->module());
732 } else {
733 InteractiveDumpGraphs(opts, *module);
734 }
735 }
736
737 } // namespace
738 } // namespace tools
739 } // namespace xla
740
741 int main(int argc, char** argv) {
742 xla::tools::Options opts;
743 opts.browser = "/usr/bin/sensible-browser";
744 bool need_help = false;
745 const std::vector<tensorflow::Flag> flag_list = {
746 tensorflow::Flag("hlo_snapshot", &opts.hlo_snapshot,
747 "HloSnapshot proto to interactively dump to graphviz"),
748 tensorflow::Flag("hlo_proto", &opts.hlo_proto,
749 "XLA hlo proto to interactively dump to graphviz"),
750 tensorflow::Flag("hlo_module_proto", &opts.hlo_module_proto,
751 "XLA hlomodule proto to interactively dump to graphviz"),
752 tensorflow::Flag("hlo_text", &opts.hlo_text,
753 "XLA hlo proto to interactively dump to graphviz"),
754 tensorflow::Flag("platform", &opts.platform,
755 "Platform to compile for: CPU, CUDA, etc"),
756 tensorflow::Flag("browser", &opts.browser,
757 "Path to web browser used to display produced graphs."),
758 tensorflow::Flag("help", &need_help, "Prints this help message"),
759 };
760 xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
761 bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
762 tensorflow::port::InitMain(argv[0], &argc, &argv);
763 if (argc != 1 || !parse_ok || need_help) {
764 LOG(QFATAL) << usage;
765 }
766 xla::tools::RealMain(opts);
767 return 0;
768 }
769