1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // A tool for interactively exploring graphviz dumps of HLO graphs.
17 //
18 // Input can be a binary HloSnapshot proto, a binary HloProto proto, or a
19 // textual HLO string.
20 //
21 // Generated visualization is opened in a new default browser window using
22 // /usr/bin/sensible-browser.
23
24 #include <stdio.h>
25 #include <unistd.h>
26
27 #include <functional>
28 #include <string>
29 #include <utility>
30
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
38 #include "tensorflow/compiler/xla/client/client_library.h"
39 #include "tensorflow/compiler/xla/client/local_client.h"
40 #include "tensorflow/compiler/xla/service/compiler.h"
41 #include "tensorflow/compiler/xla/service/hlo.pb.h"
42 #include "tensorflow/compiler/xla/service/hlo_runner.h"
43 #include "tensorflow/compiler/xla/service/local_service.h"
44 #include "tensorflow/compiler/xla/service/platform_util.h"
45 #include "tensorflow/compiler/xla/tools/hlo_extractor.h"
46 #include "tensorflow/core/lib/io/path.h"
47 #include "tensorflow/core/platform/init_main.h"
48 #include "tensorflow/core/platform/logging.h"
49 #include "tensorflow/core/platform/subprocess.h"
50 #include "tensorflow/core/protobuf/error_codes.pb.h"
51 #include "tensorflow/core/util/command_line_flags.h"
52 #if defined(PLATFORM_GOOGLE)
53 #include "util/readline/readline.h"
54 #endif
55
56 #if defined(PLATFORM_WINDOWS)
57 #include <io.h>
58 #define isatty _isatty
59 #endif
60
61 namespace xla {
62 namespace tools {
63 namespace {
64
// Reads one line of user input into *line after showing `prompt`.
// Returns false on EOF or stream error so the caller's loop can stop.
bool ReadLine(const char* prompt, std::string* line) {
#if defined(PLATFORM_GOOGLE)
  // Use the readline-backed implementation (history, editing) when available.
  return util::ReadLine(prompt, line);
#else
  std::cout << prompt;
  std::getline(std::cin, *line);
  const bool stream_ok = std::cin.good();
  return stream_ok;
#endif
}
74
75 // Command-line opts to this tool. See main() for descriptions of these
76 // fields.
struct Options {
  // Paths to the input to visualize; exactly one of these four should be set
  // (enforced by CheckFlags).
  std::string hlo_snapshot;      // binary HloSnapshot proto
  std::string hlo_proto;         // binary HloProto proto
  std::string hlo_module_proto;  // binary HloModuleProto proto
  std::string hlo_text;          // textual HLO module
  // Optional platform name (e.g. "CPU", "CUDA") to compile the module for.
  std::string platform;
  // Browser binary used to open rendered graphs; empty means the default
  // /usr/bin/sensible-browser.
  std::string browser;
};
85
86 const char* const kUsage = R"(
87 This tool lets you load an XLA dump and then interactively explore its graphical
88 representation.
89
90 Most models are too large to visualize in their entirety using graphviz, but
91 it's still useful to be able to look at the nodes "near" a particular node of
92 interest.
93
94 If you pass --platform, this tool will compile the HloModule for the given
95 platform. This means that if you acquired your proto from a binary running at a
96 particular CL, the HLO graph it ran isn't necessarily the same as the one shown
97 here, unless this program was built at the same CL (and our compiler is
98 deterministic :).
99
100 Be patient when starting this program if you give it a large input; it has to
101 compile the whole thing.
102
103 Usage:
104
105 interactive_graphviz -- \
106 --{hlo_snapshot,hlo_proto,hlo_text}=path/to/binary_proto
107 --platform={CUDA,CPU,...}
108 )";
109
110 // Unless an explicit width is specified, we will render a neighborhood of
111 // kDefaultWidth nodes around the requested instruction.
112 constexpr int64_t kDefaultWidth = 2;
113
114 // When printing all paths between two nodes, we print out only this many nodes
115 // by default, truncating the graph if there are more nodes than this in the
116 // all-paths set.
117 constexpr int64_t kDefaultMaxNumNodesInAllPaths = 100;
118
119 using absl::EqualsIgnoreCase;
120
121 HloRenderOptions hlo_render_options;
122
FindInstruction(const HloModule & module,std::string node_name)123 HloInstruction* FindInstruction(const HloModule& module,
124 std::string node_name) {
125 if (absl::StartsWith(node_name, "%")) {
126 node_name.erase(node_name.begin());
127 }
128 for (const auto& computation : module.computations()) {
129 auto instrs = computation->instructions();
130 auto it = absl::c_find_if(instrs, [&](const HloInstruction* instr) {
131 // Try with and without "%" at the beginning of the node name.
132 return EqualsIgnoreCase(instr->name(), node_name) ||
133 EqualsIgnoreCase(instr->name(), absl::StrCat("%", node_name));
134 });
135 if (it != instrs.end()) {
136 return *it;
137 }
138 }
139 return nullptr;
140 }
141
FindComputation(const HloModule & module,const std::string & comp_name)142 HloComputation* FindComputation(const HloModule& module,
143 const std::string& comp_name) {
144 for (auto* computation : module.computations()) {
145 if (EqualsIgnoreCase(computation->name(), comp_name)) {
146 return computation;
147 }
148 }
149 return nullptr;
150 }
151
// Print a help message describing the various available commands.
// The default width and the default all-paths node count are interpolated
// from kDefaultWidth and kDefaultMaxNumNodesInAllPaths so the help text
// stays in sync with the actual defaults.
void DoHelpCommand() {
  std::cout << R"(Commands:
  <instruction> [<width>] [/ <boundary_instruction>+]
    Renders a neighborhood of <width> nodes around <instruction>, without going
    beyond the optional boundary instructions.  If <width> is not provided,
    the default value is )"
            << kDefaultWidth << R"(.
  allpaths <instruction> <instruction> [<n>]
    Renders a subset of all paths from one instruction to the other.  Either
    order of nodes is accepted.  Shows the <n> nodes in the all-paths set on the
    shortest paths; default is )"
            << kDefaultMaxNumNodesInAllPaths << R"(.
  <computation>
    Renders all nodes in <computation>.
  backend_config [on|off]
    Controls whether backend operation configuration information is printed.
  show_fusion_subcomputations [on|off]
    Controls whether fusion subcomputations are shown.
  list [name|op_name|op_type] <pattern>
    Lists all instructions whose name, metadata op_name, or metadata op_type
    contains <pattern> as a substring.
  list computations
    Lists all computations in the module.
  info <instruction>
  info <computation>
    Prints information about <instruction> or <computation>.
  extract <instruction> <height>
    Creates a new HLO module with <instruction> as entry computation root. If
    <height> is specified, the new computation contains nodes up to <height>
    nodes above the root.
  help
    Prints this usage information.
  quit
    Exit the application.)"
            << std::endl;
}
189
190 // Turn metadata-printing on or off.
DoBackendConfigCommand(const std::vector<std::string> & tokens)191 void DoBackendConfigCommand(const std::vector<std::string>& tokens) {
192 if (tokens.size() == 2 && tokens[1] == "on") {
193 hlo_render_options.show_backend_config = true;
194 } else if (tokens.size() == 2 && tokens[1] == "off") {
195 hlo_render_options.show_backend_config = false;
196 } else if (tokens.size() != 1) {
197 std::cerr << "(Illegal backend_config value. Use either 'on' or 'off'.)"
198 << std::endl;
199 }
200 std::cout << "Backend configuration display "
201 << (hlo_render_options.show_backend_config ? "ON" : "OFF")
202 << std::endl;
203 }
204
205 // Turn fusion computation display on or off.
DoShowFusionSubcomputationsCommand(const std::vector<std::string> & tokens)206 void DoShowFusionSubcomputationsCommand(
207 const std::vector<std::string>& tokens) {
208 if (tokens.size() == 2 && tokens[1] == "on") {
209 hlo_render_options.show_fusion_subcomputations = true;
210 } else if (tokens.size() == 2 && tokens[1] == "off") {
211 hlo_render_options.show_fusion_subcomputations = false;
212 } else if (tokens.size() != 1) {
213 std::cerr << "(Illegal show_fusion_subcomputations value. Use either "
214 "'on' or 'off'.)"
215 << std::endl;
216 }
217 std::cout << "Fusion subcomputations display "
218 << (hlo_render_options.show_fusion_subcomputations ? "ON" : "OFF")
219 << std::endl;
220 }
221
222 // List all computations in the module.
DoListComputationsCommand(const HloModule & module,const std::vector<std::string> & tokens)223 void DoListComputationsCommand(const HloModule& module,
224 const std::vector<std::string>& tokens) {
225 if (tokens.size() > 2) {
226 std::cout << R"(Illegal syntax; "list computations" takes no arguments.)";
227 return;
228 }
229 if (module.entry_computation() != nullptr) {
230 std::cout << "Entry computation:" << std::endl;
231 std::cout << " " << module.entry_computation()->name() << std::endl
232 << std::endl;
233 }
234 std::cout << "Subcomputations:" << std::endl;
235 std::vector<std::string> names;
236 for (const auto& computation : module.computations()) {
237 if (computation == module.entry_computation()) {
238 continue;
239 }
240 std::cout << " " << computation->name() << std::endl;
241 }
242 }
243
244 // List all instructions matching a pattern.
DoListCommand(const HloModule & module,const std::vector<std::string> & tokens)245 void DoListCommand(const HloModule& module,
246 const std::vector<std::string>& tokens) {
247 std::string pattern = "";
248 std::string type = "name";
249 if (tokens.size() == 2) {
250 pattern = tokens[1];
251 } else if (tokens.size() == 3) {
252 type = tokens[1];
253 pattern = tokens[2];
254 } else {
255 std::cout << "Illegal list query syntax. Use "
256 << R"("list [name|op_name|op_type] pattern".)" << std::endl;
257 return;
258 }
259
260 std::cout << "Query results:" << std::endl;
261 for (const auto& computation : module.computations()) {
262 for (const auto& instr : computation->instructions()) {
263 if ((type == "name" &&
264 instr->name().find(pattern) != std::string::npos) ||
265 (type == "op_name" &&
266 instr->metadata().op_name().find(pattern) != std::string::npos) ||
267 (type == "op_type" &&
268 instr->metadata().op_type().find(pattern) != std::string::npos)) {
269 std::cout << " " << instr->name();
270 std::cout << ", op_name '" << instr->metadata().op_name() << "'";
271 std::cout << ", op_type '" << instr->metadata().op_type() << "'";
272 std::cout << std::endl;
273 }
274 }
275 }
276 }
277
// Print info about an instruction or computation.
//
// Syntax: "info <name>".  If <name> resolves to a computation, prints its
// fusion parent (if any), parameters, root, embedded computations, and the
// computations that embed it.  If it resolves only to an instruction, prints
// the instruction's opcode, shape, metadata, fusion details, operands, and
// users.  When both an instruction and a computation share the name, the
// computation branch wins (checked first below).
void DoInfoCommand(const HloModule& module,
                   const std::vector<std::string>& tokens) {
  if (tokens.size() != 2) {
    std::cerr << "Illegal info query syntax. Use "
              << R"("info name".)";
    return;
  }
  std::string node_name = tokens[1];

  const HloInstruction* instr = FindInstruction(module, node_name);
  const HloComputation* comp = FindComputation(module, node_name);
  if (!instr && !comp) {
    std::cerr << "Couldn't find HloInstruction or HloComputation named "
              << node_name << std::endl;
    return;
  }

  if (comp != nullptr) {
    std::cout << "HloComputation " << comp->name() << std::endl;
    if (comp->IsFusionComputation()) {
      std::cout << " Fusion instruction: " << comp->FusionInstruction()->name()
                << std::endl;
    }
    std::cout << " Parameters:" << std::endl;
    for (const auto& param : comp->parameter_instructions()) {
      std::cout << " " << param->name() << " ("
                << ShapeUtil::HumanStringWithLayout(param->shape()) << ")"
                << std::endl;
    }
    HloInstruction* root = comp->root_instruction();
    std::cout << " Root instruction: " << root->name() << " ("
              << ShapeUtil::HumanStringWithLayout(root->shape()) << ")"
              << std::endl;

    auto embedded_computations = comp->MakeEmbeddedComputationsList();
    std::cout << " " << embedded_computations.size() << " embedded computation"
              << (embedded_computations.size() != 1 ? "s" : "")
              << (!embedded_computations.empty() ? ":" : ".") << std::endl;
    for (const HloComputation* c : embedded_computations) {
      std::cout << " " << c->name() << std::endl;
    }

    // Find which computations reference comp as an embedded computation.
    std::vector<const HloComputation*> users;
    for (const HloComputation* c : module.computations()) {
      if (absl::c_linear_search(c->MakeEmbeddedComputationsList(), comp)) {
        users.push_back(c);
      }
    }
    // NOTE(review): unlike the embedded-computations header above, this
    // stream statement ends without std::endl, so the first user name prints
    // on the same line -- looks like a formatting bug; confirm before fixing.
    std::cout << " Used by " << users.size() << " computation"
              << (users.size() != 1 ? "s" : "") << (!users.empty() ? ":" : ".");
    for (const HloComputation* c : users) {
      std::cout << " " << c->name() << std::endl;
    }
  } else {
    std::cout << "HloInstruction " << instr->name() << std::endl;
    std::cout << " Parent computation: " << instr->parent()->name()
              << std::endl;
    std::cout << " Opcode: " << HloOpcodeString(instr->opcode()) << std::endl;
    std::cout << " Shape: " << ShapeUtil::HumanStringWithLayout(instr->shape())
              << std::endl;
    // Metadata fields are printed only when non-empty.
    std::cout << " Metadata:" << std::endl;
    if (!instr->metadata().op_name().empty()) {
      std::cout << " Name: " << instr->metadata().op_name() << std::endl;
    }
    if (!instr->metadata().op_type().empty()) {
      std::cout << " Type: " << instr->metadata().op_type() << std::endl;
    }
    if (!instr->raw_backend_config_string().empty()) {
      std::cout << " Backend configuration: "
                << instr->raw_backend_config_string() << std::endl;
    }
    // Fusions get extra detail about their wrapped computation.
    if (instr->opcode() == HloOpcode::kFusion) {
      std::cout << " Fusion kind: " << xla::ToString(instr->fusion_kind())
                << std::endl;
      std::cout << " Fusion computation: "
                << instr->fused_instructions_computation()->name() << std::endl;
      std::cout << " Fused computation root: "
                << instr->fused_expression_root()->name() << std::endl;
    }
    std::cout << " Operands:" << std::endl;
    for (HloInstruction* operand : instr->operands()) {
      std::cout << " " << operand->name() << " ("
                << ShapeUtil::HumanStringWithLayout(operand->shape()) << ")"
                << std::endl;
    }
    std::cout << " Users:" << std::endl;
    for (HloInstruction* user : instr->users()) {
      std::cout << " " << user->name() << std::endl;
    }
    if (instr->parent()->root_instruction() == instr) {
      std::cout << " Root instruction of " << instr->parent()->name()
                << std::endl;
    }
  }
}
375
DoExtractCommand(const HloModule & module,absl::Span<const std::string> tokens)376 void DoExtractCommand(const HloModule& module,
377 absl::Span<const std::string> tokens) {
378 if (tokens.size() > 3) {
379 std::cerr << R"(Illegal input. Enter e.g. "extract %fusion.1 2")"
380 << std::endl;
381 return;
382 }
383
384 // Find the node with the given name.
385 std::string node_name = tokens[1];
386 HloInstruction* instr = FindInstruction(module, node_name);
387 if (!instr) {
388 std::cerr << "Couldn't find HloInstruction named " << node_name << "."
389 << std::endl;
390 return;
391 }
392
393 int64_t height = -1;
394 if (tokens.size() == 3) {
395 if (!absl::SimpleAtoi(tokens[2], &height)) {
396 std::cerr << "Can't parse '" << tokens[2] << "' as an integer."
397 << std::endl;
398 return;
399 }
400 }
401
402 auto extracted_module = ExtractModule(instr, height);
403 std::cout << extracted_module->ToString(
404 HloPrintOptions::ShortParsable().set_print_backend_config(
405 hlo_render_options.show_backend_config))
406 << std::endl;
407 }
408
409 // Checks if there is a use-def path from `from` to `to`.
ExistsPathFromTo(const HloInstruction * from,const HloInstruction * to)410 bool ExistsPathFromTo(const HloInstruction* from, const HloInstruction* to) {
411 absl::flat_hash_set<const HloInstruction*> visited;
412 std::vector<const HloInstruction*> to_visit = {from};
413 while (!to_visit.empty()) {
414 auto* n = to_visit.back();
415 if (n == to) {
416 return true;
417 }
418 to_visit.pop_back();
419 visited.insert(n);
420 for (auto* user : n->users()) {
421 if (!visited.count(user)) {
422 to_visit.push_back(user);
423 }
424 }
425 }
426 return false;
427 }
428
OpenUrl(const Options & opts,absl::string_view url)429 void OpenUrl(const Options& opts, absl::string_view url) {
430 std::cout << url << std::endl;
431
432 // If it is a url, try to open it up in the user's browser too.
433 if (absl::StartsWithIgnoreCase(url, "http://") ||
434 absl::StartsWithIgnoreCase(url, "https://") ||
435 absl::StartsWithIgnoreCase(url, "file://")) {
436 const char* browser_bin = opts.browser.empty() ? "/usr/bin/sensible-browser"
437 : opts.browser.c_str();
438 tensorflow::SubProcess p;
439 p.SetProgram(browser_bin, {browser_bin, std::string(url)});
440 p.Start();
441 } else {
442 std::cerr << "\nExpected a URL, but got strange graph result (dumped "
443 "above). If this isn't what you expected, maybe file a bug?"
444 << std::endl;
445 }
446 }
447
// Renders a graph by calling `renderer`, and then tries to open it.
//
// `renderer` is a callback so we can try various formats. In particular, the
// URL format doesn't work out of the box; it requires you to register a
// plugin.  Fallback order: URL -> HTML written to a temp file (kDot is
// deliberately never tried; see the comment at the bottom).
void RenderAndDisplayGraph(
    const Options& opts,
    const std::function<StatusOr<std::string>(RenderedGraphFormat)>& renderer) {
  // First preference: a URL we can hand straight to the browser.
  StatusOr<std::string> url_result = renderer(RenderedGraphFormat::kUrl);
  if (url_result.ok()) {
    std::string url = url_result.ValueOrDie();
    OpenUrl(opts, url);
    return;
  }

  // Ignore UNAVAILABLE errors; these are expected when there's no URL renderer
  // plugin registered.
  if (url_result.status().code() != tensorflow::error::UNAVAILABLE) {
    std::cerr << "Unable to render graph as URL: " << url_result.status()
              << std::endl;
    std::cerr << "Trying as HTML..." << std::endl;
  }

  auto* env = tensorflow::Env::Default();
  StatusOr<std::string> html_result = renderer(RenderedGraphFormat::kHtml);
  if (!html_result.ok()) {
    std::cerr << "Failed to render graph as HTML: " << html_result.status()
              << std::endl;
    return;
  }

  std::vector<std::string> temp_dirs;
  env->GetLocalTempDirectories(&temp_dirs);
  if (temp_dirs.empty()) {
    std::cerr << "Can't render graph as HTML because we can't find a suitable "
                 "temp directory. Try setting $TMPDIR?"
              << std::endl;
    return;
  }

  // Try to create a unique file inside of temp_dirs.front().  Notably, this
  // file's name must end with ".html", otherwise web browsers will treat it as
  // plain text, so we can't use Env::CreateUniqueFileName().
  std::string temp_file_path = tensorflow::io::JoinPath(
      temp_dirs.front(),
      absl::StrFormat("interactive_graphviz.%d.html", env->NowMicros()));
  // Move the HTML out of the StatusOr to avoid copying a potentially large
  // string.
  auto status = tensorflow::WriteStringToFile(
      env, temp_file_path, std::move(html_result).ValueOrDie());
  if (status.ok()) {
    OpenUrl(opts, absl::StrCat("file://", temp_file_path));
    return;
  }

  std::cerr << "Failed to write rendered HTML graph to " << temp_file_path
            << ": " << status;

  // We don't bother trying kDot, because kHTML should always work (or if it
  // doesn't, we don't have any reason to believe kDot will work better).
}
506
DoAllPathsCommand(const Options & opts,const HloModule & module,const std::vector<std::string> & tokens)507 void DoAllPathsCommand(const Options& opts, const HloModule& module,
508 const std::vector<std::string>& tokens) {
509 if (tokens.size() > 4) {
510 std::cerr << R"(Illegal input. Enter e.g. "allpaths %add.4 %subtract.2" or
511 "allpaths add.4 subtract.2 42.)"
512 << std::endl;
513 return;
514 }
515
516 int64_t max_nodes = kDefaultMaxNumNodesInAllPaths;
517 if (tokens.size() == 4 && !absl::SimpleAtoi(tokens[3], &max_nodes)) {
518 std::cerr << "Can't parse '" << tokens[3] << "' as an integer."
519 << std::endl;
520 return;
521 }
522
523 const HloInstruction* n1 = FindInstruction(module, tokens[1]);
524 if (!n1) {
525 std::cerr << "Couldn't find HloInstruction named " << tokens[1];
526 return;
527 }
528 const HloInstruction* n2 = FindInstruction(module, tokens[2]);
529 if (!n2) {
530 std::cerr << "Couldn't find HloInstruction named " << tokens[2];
531 return;
532 }
533
534 // Is there a path from n1 to n2, or vice versa?
535 const HloInstruction* from;
536 const HloInstruction* to;
537 if (ExistsPathFromTo(n1, n2)) {
538 from = n1;
539 to = n2;
540 } else if (ExistsPathFromTo(n2, n1)) {
541 from = n2;
542 to = n1;
543 } else {
544 std::cerr << "No path from/to " << tokens[1] << " to/from " << tokens[2];
545 return;
546 }
547 RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
548 return RenderAllPathsFromTo(*from, *to, max_nodes, format,
549 hlo_render_options);
550 });
551 }
552
553 // Plot a given instruction neighborhood or computation with graphviz.
554 void DoPlotCommand(const Options& opts, const HloModule& module,
555 const std::vector<std::string>& tokens) {
556 std::string node_name = tokens[0];
557
558 // Find the node with the given name.
559 const HloInstruction* instr = FindInstruction(module, node_name);
560 const HloComputation* comp = FindComputation(module, node_name);
561 if (!instr && !comp) {
562 std::cerr << "Couldn't find HloInstruction or HloComputation named "
563 << node_name << "." << std::endl;
564 return;
565 }
566
567 uint64_t graph_width = kDefaultWidth;
568 absl::flat_hash_set<const HloInstruction*> boundary;
569 if (tokens.size() >= 2) {
570 if (comp) {
571 std::cerr << "Can only use graph-size parameter with instructions, but "
572 << node_name << " is a computation." << std::endl;
573 return;
574 }
575
576 int bound_index = 1;
577 // Get the <width> if present.
578 if (absl::SimpleAtoi(tokens[bound_index], &graph_width)) {
579 bound_index++;
580 } else {
581 // <width> not found, need to reset graph_width.
582 graph_width = kDefaultWidth;
583 }
584 // Get the '/'.
585 if (bound_index < tokens.size()) {
586 // This token must be a '/'.
587 if (tokens[bound_index] != "/") {
588 std::cerr << "Expect a /, but get a '" << tokens[bound_index] << "'."
589 << std::endl;
590 return;
591 }
592 bound_index++;
593 }
594 // Get the boundary nodes.
595 while (bound_index < tokens.size()) {
596 std::string bnode_name = tokens[bound_index];
597 const HloInstruction* binstr = FindInstruction(module, bnode_name);
598 if (!binstr) {
599 std::cerr << "Couldn't find HloInstruction named " << bnode_name << "."
600 << std::endl;
601 return;
602 }
603 boundary.insert(binstr);
604 bound_index++;
605 }
606 }
607
608 // Generate the graph and print the resulting string, which should be a
609 // graphviz url.
610 if (comp) {
611 RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
612 return RenderGraph(*comp, /*label=*/"",
613 comp->parent()->config().debug_options(), format,
614 /*hlo_execution_profile=*/nullptr, hlo_render_options);
615 });
616 } else {
617 RenderAndDisplayGraph(opts, [&](RenderedGraphFormat format) {
618 return RenderNeighborhoodAround(*instr, graph_width, format,
619 hlo_render_options,
620 /*boundary=*/boundary);
621 });
622 }
623 }
624
625 // Run the main event loop, reading user commands and processing them.
626 void InteractiveDumpGraphs(const Options& opts, const HloModule& module) {
627 // This is an interactive tool, but some may use `extract` in non-tty
628 // environment anyway. Give them a clean hlo dump.
629 if (isatty(fileno(stdin))) {
630 std::cout << "\n\nLoaded module " << module.name() << "." << std::endl;
631 DoHelpCommand();
632 }
633 for (std::string line; ReadLine("\ncommand: ", &line);) {
634 if (line.empty()) {
635 std::cout << R"(Enter e.g. "fusion.1 3" or "add.8".)" << std::endl
636 << R"(Enter "help" for help; ^D, "quit", or "exit" to exit.)"
637 << std::endl;
638 continue;
639 }
640 std::vector<std::string> tokens =
641 absl::StrSplit(line, ' ', absl::SkipEmpty());
642 if (tokens[0] == "quit" || tokens[0] == "exit") {
643 break;
644 } else if (tokens[0] == "help") {
645 DoHelpCommand();
646 } else if (tokens[0] == "backend_config") {
647 DoBackendConfigCommand(tokens);
648 } else if (tokens[0] == "show_fusion_subcomputations") {
649 DoShowFusionSubcomputationsCommand(tokens);
650 } else if (tokens[0] == "list") {
651 if (tokens.size() > 1 && tokens[1] == "computations") {
652 DoListComputationsCommand(module, tokens);
653 } else {
654 DoListCommand(module, tokens);
655 }
656 } else if (tokens[0] == "info") {
657 DoInfoCommand(module, tokens);
658 } else if (tokens[0] == "extract") {
659 DoExtractCommand(module, tokens);
660 } else if (tokens[0] == "allpaths") {
661 DoAllPathsCommand(opts, module, tokens);
662 } else {
663 DoPlotCommand(opts, module, tokens);
664 }
665 }
666 }
667
668 void CheckFlags(const Options& opts) {
669 int nonempty_flags_amount = 0;
670 if (!opts.hlo_proto.empty()) {
671 ++nonempty_flags_amount;
672 }
673 if (!opts.hlo_snapshot.empty()) {
674 ++nonempty_flags_amount;
675 }
676 if (!opts.hlo_text.empty()) {
677 ++nonempty_flags_amount;
678 }
679 if (!opts.hlo_module_proto.empty()) {
680 ++nonempty_flags_amount;
681 }
682 if (nonempty_flags_amount == 1) {
683 return;
684 }
685 LOG(FATAL) << "Can only specify one and only one of '--hlo_proto', "
686 "'--hlo_snapshot', '--hlo_text', '--hlo_module_proto' flags.";
687 }
688
// Loads the HLO module selected by the flags, optionally compiles it for the
// requested platform, then enters the interactive command loop.
void RealMain(const Options& opts) {
  if (!isatty(fileno(stdin))) {
    LOG(ERROR) << "\n\n*****************************************\n"
               << "This is an interactive tool, but stdin is not a tty.\n"
               << "*****************************************\n\n";
  }

  // Aborts unless exactly one input flag was given.
  CheckFlags(opts);

  // Load the module from whichever input format was specified.  Each loader
  // CHECK-fails (ValueOrDie / TF_CHECK_OK) on unreadable or unparsable input.
  std::unique_ptr<HloModule> module;
  if (!opts.hlo_snapshot.empty()) {
    HloSnapshot snapshot;
    TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
                                            opts.hlo_snapshot, &snapshot))
        << "Can't open, read, or parse HloSnapshot proto at "
        << opts.hlo_snapshot;
    auto config =
        HloModule::CreateModuleConfigFromProto(snapshot.hlo().hlo_module(),
                                               xla::GetDebugOptionsFromFlags())
            .ValueOrDie();
    module = HloModule::CreateFromProto(snapshot.hlo().hlo_module(), config)
                 .ValueOrDie();
  } else if (!opts.hlo_proto.empty()) {
    module = HloRunner::ReadModuleFromBinaryProtoFile(
                 opts.hlo_proto, xla::GetDebugOptionsFromFlags())
                 .ValueOrDie();
  } else if (!opts.hlo_text.empty()) {
    module = HloRunner::ReadModuleFromHloTextFile(
                 opts.hlo_text, xla::GetDebugOptionsFromFlags())
                 .ValueOrDie();
  } else if (!opts.hlo_module_proto.empty()) {
    module = HloRunner::ReadModuleFromModuleBinaryProtofile(
                 opts.hlo_module_proto, xla::GetDebugOptionsFromFlags())
                 .ValueOrDie();
  }

  // If a platform was specified, compile the module for that platform.
  if (!opts.platform.empty()) {
    se::Platform* platform =
        PlatformUtil::GetPlatform(opts.platform).ValueOrDie();
    LOG(INFO) << "Compiling module for " << platform->Name();

    se::StreamExecutor* executor =
        platform->ExecutorForDevice(/*ordinal=*/0).ValueOrDie();
    auto compiler = Compiler::GetForPlatform(platform).ValueOrDie();
    // Run the backend's HLO passes so the graphs shown match what would
    // actually execute on that platform.
    module = compiler
                 ->RunHloPasses(std::move(module), executor,
                                /*device_allocator=*/nullptr)
                 .ValueOrDie();
    auto executable = compiler
                          ->RunBackend(std::move(module), executor,
                                       /*device_allocator=*/nullptr)
                          .ValueOrDie();
    // `module` was consumed by RunBackend above; browse the executable's
    // module instead.
    InteractiveDumpGraphs(opts, executable->module());
  } else {
    InteractiveDumpGraphs(opts, *module);
  }
}
747
748 } // namespace
749 } // namespace tools
750 } // namespace xla
751
// Parses command-line flags, initializes the TensorFlow runtime, and hands
// control to xla::tools::RealMain.
int main(int argc, char** argv) {
  xla::tools::Options opts;
  opts.browser = "/usr/bin/sensible-browser";
  bool need_help = false;
  const std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("hlo_snapshot", &opts.hlo_snapshot,
                       "HloSnapshot proto to interactively dump to graphviz"),
      tensorflow::Flag("hlo_proto", &opts.hlo_proto,
                       "XLA hlo proto to interactively dump to graphviz"),
      tensorflow::Flag("hlo_module_proto", &opts.hlo_module_proto,
                       "XLA hlomodule proto to interactively dump to graphviz"),
      tensorflow::Flag("hlo_text", &opts.hlo_text,
                       "XLA hlo proto to interactively dump to graphviz"),
      tensorflow::Flag("platform", &opts.platform,
                       "Platform to compile for: CPU, CUDA, etc"),
      tensorflow::Flag("browser", &opts.browser,
                       "Path to web browser used to display produced graphs."),
      tensorflow::Flag("help", &need_help, "Prints this help message"),
  };
  std::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  // Leftover positional arguments, a parse failure, or --help all print
  // usage and exit via LOG(QFATAL).
  if (argc != 1 || !parse_ok || need_help) {
    LOG(QFATAL) << usage;
  }
  xla::tools::RealMain(opts);
  return 0;
}
780