• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Usage: replay_computation some_binary_snapshot_proto*
17 //
18 // Where some_binary_snapshot_proto is [type_prefix:]file_path. Supported
19 // type_prefixes:
20 // * recordio_hlo_proto - for a Tensorflow recordio file containing serialized
21 // xla.HloProtos.
22 //
23 // If type_prefix is omitted, the program will make several guesses.
24 //
25 // Replays computations and shows the results on the command line.
26 //
27 // some_binary_snapshot_proto is obtained by serializing the HloSnapshot from
28 // ServiceInterface::SnapshotComputation to disk.
29 //
30 // Computations that require arguments can be replayed using fake data by
31 // passing --use_fake_data on the command line.  If the real data is available
32 // in the proto and --use_fake_data is false, the real data is used.
33 //
34 // Input can be a binary HloSnapshot proto, a binary HloProto proto, or a
35 // textual HLO string.
36 //
37 // The output format is:
38 //
39 // file_path: computation_name :: type:literal_str
40 //
41 // Note: If you pass multiple modules, they will be compiled in parallel but run
42 // in series.
43 
44 #define EIGEN_USE_THREADS
45 
46 #include <stdio.h>
47 
48 #include <algorithm>
49 #include <memory>
50 #include <string>
51 #include <utility>
52 #include <vector>
53 
54 #include "absl/types/span.h"
55 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
56 #include "tensorflow/compiler/xla/client/client.h"
57 #include "tensorflow/compiler/xla/client/client_library.h"
58 #include "tensorflow/compiler/xla/client/global_data.h"
59 #include "tensorflow/compiler/xla/client/lib/testing.h"
60 #include "tensorflow/compiler/xla/client/local_client.h"
61 #include "tensorflow/compiler/xla/client/xla_computation.h"
62 #include "tensorflow/compiler/xla/debug_options_flags.h"
63 #include "tensorflow/compiler/xla/execution_options_util.h"
64 #include "tensorflow/compiler/xla/literal.h"
65 #include "tensorflow/compiler/xla/service/hlo.pb.h"
66 #include "tensorflow/compiler/xla/service/hlo_parser.h"
67 #include "tensorflow/compiler/xla/shape_util.h"
68 #include "tensorflow/compiler/xla/status_macros.h"
69 #include "tensorflow/compiler/xla/statusor.h"
70 #include "tensorflow/compiler/xla/tests/test_utils.h"
71 #include "tensorflow/compiler/xla/types.h"
72 #include "tensorflow/compiler/xla/xla_data.pb.h"
73 #include "tensorflow/core/lib/core/threadpool.h"
74 #include "tensorflow/core/lib/io/record_reader.h"
75 #include "tensorflow/core/platform/cpu_info.h"
76 #include "tensorflow/core/platform/env.h"
77 #include "tensorflow/core/platform/init_main.h"
78 #include "tensorflow/core/platform/logging.h"
79 #include "tensorflow/core/util/command_line_flags.h"
80 
81 namespace xla {
82 namespace tools {
83 namespace {
84 
// Command-line opts to this tool.  See main() for descriptions of these
// fields.
struct Options {
  // True when the tool must read real argument data from the input proto,
  // i.e. neither fake data was requested nor is this a compile-only run.
  bool NeedsRealData() const { return !use_fake_data && !compile_only; }

  // Shape strings used to fabricate xfeed data; empty means "not specified".
  std::string fake_infeed_shape;
  std::string fake_outfeed_shape;

  // generate_fake_infeed == true is a safe default: If the model has 0 or 1
  // infeeds, then it will work like normal.  If the model has more than one
  // infeed, it will be an error, but that wouldn't have worked anyway if you
  // hadn't passed generate_fake_infeed.
  //
  // Same for generate_fake_outfeed.
  bool generate_fake_infeed = true;
  bool generate_fake_outfeed = true;

  bool use_fake_data = false;
  bool print_result = true;
  int num_runs = 1;

  // Threads for the intra-op thread pool; -1 means "number of CPUs".
  int intra_op_thread_pool_size = -1;

  bool compile_only = false;
};
112 
CompileExecutable(const HloSnapshot & module,LocalClient * client,const Options & opts)113 StatusOr<std::unique_ptr<LocalExecutable>> CompileExecutable(
114     const HloSnapshot& module, LocalClient* client, const Options& opts) {
115   XlaComputation computation(module.hlo().hlo_module());
116   std::vector<Shape> argument_layouts;
117   argument_layouts.reserve(
118       computation.proto().host_program_shape().parameters_size());
119   std::vector<const Shape*> argument_layout_ptrs;
120   if (opts.use_fake_data) {
121     for (const ShapeProto& param :
122          computation.proto().host_program_shape().parameters()) {
123       argument_layouts.push_back(Shape(param));
124       argument_layout_ptrs.push_back(&argument_layouts.back());
125     }
126   } else {
127     for (const auto& proto : module.arguments()) {
128       if (!proto.has_shape()) {
129         return InvalidArgument("LiteralProto has no shape");
130       }
131       Shape shape(proto.shape());
132       argument_layouts.push_back(shape);
133       argument_layout_ptrs.push_back(&argument_layouts.back());
134     }
135   }
136   ExecutableBuildOptions exec_build_options;
137   *exec_build_options.mutable_debug_options() = GetDebugOptionsFromFlags();
138   TF_ASSIGN_OR_RETURN(
139       auto executables,
140       client->Compile(computation, argument_layout_ptrs, exec_build_options));
141   TF_RET_CHECK(executables.size() == 1);
142   return std::move(executables[0]);
143 }
144 
GetXfeedShape(bool is_infeed,const HloModuleProto & module,const Options & opts)145 std::optional<Shape> GetXfeedShape(bool is_infeed, const HloModuleProto& module,
146                                    const Options& opts) {
147   std::vector<HloInstructionProto> xfeed_instrs;
148   for (const auto& comp : module.computations()) {
149     for (const auto& instruction : comp.instructions()) {
150       if (instruction.opcode() == HloOpcodeString(is_infeed
151                                                       ? HloOpcode::kInfeed
152                                                       : HloOpcode::kOutfeed)) {
153         xfeed_instrs.push_back(instruction);
154       }
155     }
156   }
157 
158   auto log_xfeed_instrs = [&] {
159     for (const auto& infeed : xfeed_instrs) {
160       LOG(ERROR) << "  " << ShapeUtil::HumanString(Shape(infeed.shape())) << " "
161                  << infeed.name();
162     }
163   };
164 
165   auto find_instruction_from_id_or_die = [&](int64_t id) {
166     for (const auto& comp : module.computations()) {
167       for (const auto& instruction : comp.instructions()) {
168         if (instruction.id() == id) {
169           return instruction;
170         }
171       }
172     }
173     LOG(FATAL) << "No instruction with id " << id;
174   };
175 
176   std::optional<Shape> xfeed_shape;
177   std::string xfeed_name = is_infeed ? "infeed" : "outfeed";
178   std::string fake_xfeed_shape =
179       is_infeed ? opts.fake_infeed_shape : opts.fake_outfeed_shape;
180   bool generate_fake_xfeed =
181       is_infeed ? opts.generate_fake_infeed : opts.generate_fake_outfeed;
182   if (!fake_xfeed_shape.empty()) {
183     xfeed_shape = std::move(ParseShape(fake_xfeed_shape)).ValueOrDie();
184   } else if (generate_fake_xfeed) {
185     QCHECK_LT(xfeed_instrs.size(), 2)
186         << "--generate_fake_" << xfeed_name
187         << " only works if the model has 0 or 1 " << xfeed_name << " ops.";
188     if (xfeed_instrs.empty()) {
189       LOG(INFO) << "Not generating fake " << xfeed_name
190                 << " shape; model has no " << xfeed_name << "s.";
191     } else if (xfeed_instrs.size() == 1) {
192       // kInfeed instructions should have a shape (buffer, token).  kOutfeed
193       // instructions should have operand 0 of shape `buffer`. We want to xfeed
194       // just `buffer`.
195       xfeed_shape = is_infeed
196                         ? Shape(xfeed_instrs.front().shape()).tuple_shapes(0)
197                         : Shape(find_instruction_from_id_or_die(
198                                     xfeed_instrs.front().operand_ids(0))
199                                     .shape());
200       LOG(INFO) << "Generating fake " << xfeed_name << " with inferred shape: "
201                 << ShapeUtil::HumanString(*xfeed_shape);
202     } else {
203       LOG(ERROR) << "--generate_fake_" << xfeed_name
204                  << " only works if the model has 0 or 1 " << xfeed_name
205                  << " ops, but this model has " << xfeed_instrs.size()
206                  << " of them:";
207       log_xfeed_instrs();
208       LOG(QFATAL) << "Can't run model with --generate_fake_infeed.";
209     }
210   } else if (!xfeed_instrs.empty()) {
211     LOG(ERROR) << "Model contains " << xfeed_instrs.size() << " " << xfeed_name
212                << " instruction(s), but neither --generate_fake_" << xfeed_name
213                << " nor --fake_" << xfeed_name
214                << "_shape was specified.  Execution will likely hang.";
215     log_xfeed_instrs();
216   }
217 
218   return xfeed_shape;
219 }
220 
221 // Invokes the given computation passing arbitrary data for every (unbound)
222 // parameter if use_fake_data, Otherwise use recorded data if available.
223 //
224 // Similarly, infeeds fake data of shape fake_infeed_shape if it is provided.
225 // If generate_fake_infeed is true, the required infeed shape is derived from
226 // the computation and then used to provide a fake infeed shape.
227 //
228 // If neither generate_fake_infeed is true nor a fake_infeed_shape is provided,
229 // no infeed is performed.
ReplayComputation(const HloSnapshot & module,LocalExecutable * executable,LocalClient * client,const Options & opts)230 StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
231                                     LocalExecutable* executable,
232                                     LocalClient* client, const Options& opts) {
233   XlaComputation computation(module.hlo().hlo_module());
234 
235   // Build the `argument_ptrs` vector, which contains ShapedBuffer*s to our
236   // arguments.  This is a bit involved, because we may have to convert from
237   // GlobalData to ShapedBuffer*, and we have to manage the lifetime of all our
238   // objects.
239   std::vector<ScopedShapedBuffer> scoped_shaped_buffer_arguments;
240   std::vector<std::unique_ptr<GlobalData>> global_data_arguments;
241   std::vector<const ShapedBuffer*> argument_ptrs;
242   if (opts.use_fake_data) {
243     // Run fake computations with debug options ignoring XLA_FLAGS.  Users very
244     // likely want XLA_FLAGS only to apply to the "real" computation being run,
245     // not to the fake computations we use for generating arguments. There is
246     // an exception. ptxas can be called during the generation of fake
247     // data. As it is cached in the process memory, the flag affecting this call
248     // should not be ignored.
249     auto debug_opts_flags = GetDebugOptionsFromFlags();
250     auto debug_opts = DefaultDebugOptionsIgnoringFlags();
251     debug_opts.set_xla_gpu_asm_extra_flags(
252         debug_opts_flags.xla_gpu_asm_extra_flags());
253 
254     global_data_arguments =
255         MakeFakeArgumentsOrDie(computation, client, &debug_opts);
256     for (const auto& data : global_data_arguments) {
257       argument_ptrs.push_back(
258           client->GlobalDataToShapedBuffer(data->handle(), /*replica_number=*/0)
259               .ValueOrDie());
260     }
261   } else {  // use recorded data if available
262     for (const auto& proto : module.arguments()) {
263       TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(proto));
264       TF_ASSIGN_OR_RETURN(
265           ScopedShapedBuffer data,
266           client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0));
267       scoped_shaped_buffer_arguments.push_back(std::move(data));
268     }
269     for (const auto& argument : scoped_shaped_buffer_arguments) {
270       argument_ptrs.push_back(&argument);
271     }
272   }
273 
274   std::shared_ptr<Literal> infeed_data;
275   if (std::optional<Shape> infeed_shape = GetXfeedShape(
276           /*is_infeed=*/true, computation.proto(), opts)) {
277     infeed_data = std::make_shared<Literal>(
278         std::move(MakeFakeLiteral(*infeed_shape)).ValueOrDie());
279   }
280   std::optional<Shape> outfeed_shape =
281       GetXfeedShape(/*is_infeed=*/false, computation.proto(), opts);
282 
283   // Do not attempt to run the executable if num_runs is less than 1.
284   if (opts.num_runs < 1) {
285     return Cancelled("Cancelled after compilation since --num_runs < 1.");
286   }
287 
288   // Run the computation num_runs times, and return the result from the last
289   // execution.
290   const bool xla_hlo_profile = GetDebugOptionsFromFlags().xla_hlo_profile();
291   se::StreamExecutorMemoryAllocator allocator(
292       client->platform(),
293       {client->platform()->ExecutorForDevice(0).ValueOrDie()});
294   std::optional<ScopedShapedBuffer> final_result;
295   LOG(ERROR) << "Running " << opts.num_runs << " number of times\n";
296   for (int i = 0; i < opts.num_runs; ++i) {
297     // If xla_hlo_profile is enabled, print a noisy message before the last run,
298     // making it easier to separate this profile from the others in the logspam.
299     bool is_final_result = i == opts.num_runs - 1;
300     if (xla_hlo_profile && is_final_result) {
301       LOG(INFO) << "\n\n***** Final run below ******";
302     }
303     int thread_pool_size = opts.intra_op_thread_pool_size < 0
304                                ? tensorflow::port::MaxParallelism()
305                                : opts.intra_op_thread_pool_size;
306     tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
307                                         thread_pool_size);
308     Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(),
309                                         pool.NumThreads());
310 
311     ExecutionProfile profile;
312     ExecutableRunOptions run_options;
313     run_options.set_execution_profile(&profile);
314     run_options.set_allocator(&allocator);
315     run_options.set_intra_op_thread_pool(&thread_pool);
316 
317     if (infeed_data) {
318       TF_CHECK_OK(client->TransferToInfeed(*infeed_data));
319     }
320     std::unique_ptr<tensorflow::Thread> outfeed_drain_thread;
321     if (outfeed_shape) {
322       // TransferFromOutfeedLocal blocks till the outfeed is available, so do
323       // it asynchronously separate thread.
324       outfeed_drain_thread.reset(tensorflow::Env::Default()->StartThread(
325           tensorflow::ThreadOptions(), "outfeed_drain_thread", [&] {
326             Literal outfeed(*outfeed_shape);
327             TF_CHECK_OK(client->TransferFromOutfeedLocal(/*device_ordinal=*/0,
328                                                          &outfeed));
329             VLOG(1) << "Received outfeed data of shape "
330                     << ShapeUtil::HumanStringWithLayout(*outfeed_shape);
331           }));
332     }
333 
334     TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
335                         executable->Run(argument_ptrs, run_options));
336     LOG(INFO) << "Done executing in "
337               << static_cast<double>(profile.compute_time_ns()) / 1e9
338               << "s: " << module.hlo().hlo_module().name();
339 
340     // Save the result if this is for the final iteration.  Otherwise discard
341     // the result before rerunning the computation, so as to free up the
342     // relevant memory.
343     if (is_final_result) {
344       final_result = std::move(result);
345     }
346   }
347 
348   TF_ASSIGN_OR_RETURN(Literal result_literal,
349                       client->ShapedBufferToLiteral(*final_result));
350   return result_literal;
351 }
352 
ParseRecordIoFile(absl::string_view filename,const Options & opts)353 StatusOr<std::vector<HloSnapshot>> ParseRecordIoFile(absl::string_view filename,
354                                                      const Options& opts) {
355   tensorflow::Env* env = tensorflow::Env::Default();
356 
357   std::unique_ptr<tensorflow::RandomAccessFile> file;
358   TF_RETURN_IF_ERROR(env->NewRandomAccessFile(
359       std::string(filename.begin(), filename.end()), &file));
360   tensorflow::io::RecordReader reader(
361       file.get(),
362       tensorflow::io::RecordReaderOptions::CreateRecordReaderOptions("ZLIB"));
363 
364   std::vector<HloSnapshot> snapshots;
365   uint64_t offset = 0;
366   tensorflow::tstring record;
367   while (reader.ReadRecord(&offset, &record).ok()) {
368     HloSnapshot snapshot;
369     if (snapshot.mutable_hlo()->ParseFromString(record)) {
370       snapshots.push_back(std::move(snapshot));
371     } else {
372       LOG(ERROR) << "Encountered bad proto";
373     }
374   }
375   QCHECK(!snapshots.empty())
376       << "No proto is successfully parsed from the file - the file possibly "
377          "has a mismatched compression option, format, etc.";
378   QCHECK(!opts.NeedsRealData())
379       << "Without --use_fake_data or --compile_only, you must pass an "
380          "HloSnapshot -- HloProto and textual HLO don't carry real data.";
381   return snapshots;
382 }
383 
ParseSingleHloFile(const std::string & filename,const Options & opts)384 StatusOr<std::vector<HloSnapshot>> ParseSingleHloFile(
385     const std::string& filename, const Options& opts) {
386   tensorflow::Env* env = tensorflow::Env::Default();
387 
388   HloSnapshot snapshot;
389   auto s = tensorflow::ReadBinaryProto(env, filename, &snapshot);
390   if (s.ok()) {
391     return std::vector<HloSnapshot>{std::move(snapshot)};
392   }
393   if (s.code() == tensorflow::error::NOT_FOUND) {
394     return s;
395   }
396   QCHECK(!opts.NeedsRealData())
397       << "Without --use_fake_data or --compile_only, you must pass an "
398          "HloSnapshot -- HloProto and textual HLO don't carry real data.";
399   fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n",
400           filename.c_str());
401 
402   if (tensorflow::ReadBinaryProto(env, filename, snapshot.mutable_hlo()).ok()) {
403     return std::vector<HloSnapshot>{std::move(snapshot)};
404   }
405   fprintf(stderr, "%s: is not HloProto. Trying HLO text.\n", filename.c_str());
406   std::string contents;
407   TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents));
408   HloModuleConfig config;
409   config.set_debug_options(GetDebugOptionsFromFlags());
410   std::vector<std::string> hlo_module_texts =
411       absl::StrSplit(contents, "// -----");
412   std::vector<HloSnapshot> snapshots;
413   int start_line = 0;
414   for (const std::string& hlo_module_text : hlo_module_texts) {
415     StatusOr<std::unique_ptr<HloModule>> module =
416         ParseAndReturnUnverifiedModule(hlo_module_text, config);
417     if (module.ok()) {
418       HloSnapshot snapshot;
419       *snapshot.mutable_hlo()->mutable_hlo_module() =
420           module.ValueOrDie()->ToProto();
421       snapshots.push_back(snapshot);
422     } else {
423       LOG(ERROR) << module.status();
424       if (hlo_module_texts.size() > 1) {
425         LOG(ERROR)
426             << "The error below was done on the section starting at line "
427             << start_line;
428       }
429     }
430     start_line += absl::c_count(hlo_module_text, '\n');
431   }
432   if (!snapshots.empty()) {
433     return snapshots;
434   }
435   fprintf(stderr, "%s: is not HLO text.  Nothing left to try.\n",
436           filename.c_str());
437   return InvalidArgument("Could not parse %s.", filename);
438 }
439 
ParseInputFile(const std::string & filename,const Options & opts)440 StatusOr<std::vector<HloSnapshot>> ParseInputFile(const std::string& filename,
441                                                   const Options& opts) {
442   std::vector<HloSnapshot> snapshots;
443   absl::string_view filename_view = filename;
444   if (absl::ConsumePrefix(&filename_view, "recordio_hlo_proto:")) {
445     return ParseRecordIoFile(filename_view, opts);
446   }
447   return ParseSingleHloFile(filename, opts);
448 }
449 
RealMain(absl::Span<char * const> args,const Options & opts)450 int RealMain(absl::Span<char* const> args, const Options& opts) {
451   LocalClient* client = ClientLibrary::LocalClientOrDie();
452   int exit_status = EXIT_SUCCESS;
453 
454   std::vector<HloSnapshot> snapshots;
455   for (char* arg : args) {
456     StatusOr<std::vector<HloSnapshot>> maybe_snapshot =
457         ParseInputFile(arg, opts);
458     if (maybe_snapshot.ok()) {
459       auto new_snapshots = std::move(maybe_snapshot).ValueOrDie();
460       snapshots.insert(snapshots.end(),
461                        std::make_move_iterator(new_snapshots.begin()),
462                        std::make_move_iterator(new_snapshots.end()));
463     } else {
464       LOG(ERROR) << maybe_snapshot.status();
465     }
466   }
467 
468   // Compile all the modules in parallel.
469   LOG(INFO) << "Compiling " << snapshots.size() << " modules in parallel.";
470   std::vector<StatusOr<std::unique_ptr<LocalExecutable>>> executables;
471   {
472     constexpr size_t kThreadLimits = 100;
473     // ThreadPool CHECK-fails if we give it 0 threads.
474     tensorflow::thread::ThreadPool thread_pool(
475         tensorflow::Env::Default(), tensorflow::ThreadOptions(),
476         "compile_modules",
477         std::min<size_t>(std::max(kThreadLimits, snapshots.size()), 1),
478         /*low_latency_hint=*/false);
479     executables.resize(snapshots.size());
480     for (int64_t i = 0; i < snapshots.size(); ++i) {
481       thread_pool.Schedule([&snapshots, &executables, client, i, &opts] {
482         executables[i] = CompileExecutable(snapshots[i], client, opts);
483       });
484     }
485   }
486   LOG(INFO) << "Done compiling; now running the modules.";
487 
488   for (int64_t i = 0; i < executables.size(); ++i) {
489     if (!executables[i].ok()) {
490       LOG(ERROR) << "Compilation failed: " << executables[i].status() << ": "
491                  << snapshots[i].ShortDebugString();
492       exit_status = EXIT_FAILURE;
493       continue;
494     }
495 
496     if (opts.compile_only) {
497       continue;
498     }
499 
500     LocalExecutable* executable = executables[i].ValueOrDie().get();
501     LOG(ERROR) << "Running iteration " << i;
502     StatusOr<Literal> result_status =
503         ReplayComputation(snapshots[i], executable, client, opts);
504     LOG(ERROR) << "iteration complete.";
505     if (!result_status.ok()) {
506       fprintf(stderr, "%s: error: %s\n", args[i],
507               result_status.status().ToString().c_str());
508       exit_status = EXIT_FAILURE;
509       continue;
510     }
511 
512     if (opts.print_result) {
513       Literal result = std::move(result_status).ValueOrDie();
514       fprintf(stdout, "%s: %s :: %s:%s\n", args[i],
515               executable->executable()->module().name().c_str(),
516               ShapeUtil::HumanString(result.shape()).c_str(),
517               result.ToString().c_str());
518       auto& snapshot = snapshots[i];
519       if (snapshot.has_result()) {
520         Literal literal = Literal::CreateFromProto(snapshot.result()).value();
521         fprintf(
522             stdout, "was %s:%s\n",
523             ShapeUtil::HumanString(Shape(snapshot.result().shape())).c_str(),
524             literal.ToString().c_str());
525       }
526     }
527   }
528 
529   ClientLibrary::DestroyLocalInstances();
530   return exit_status;
531 }
532 
533 }  // namespace
534 }  // namespace tools
535 }  // namespace xla
536 
main(int argc,char ** argv)537 int main(int argc, char** argv) {
538   xla::tools::Options opts;
539   std::vector<tensorflow::Flag> flag_list = {
540       tensorflow::Flag("use_fake_data", &opts.use_fake_data,
541                        "Replay computation using fake data"),
542       tensorflow::Flag("print_result", &opts.print_result,
543                        "Print the result of the computation to stdout"),
544       tensorflow::Flag("num_runs", &opts.num_runs,
545                        "Number of times to run each computation"),
546       tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape,
547                        "Shape of fake data to construct for (infinite) infeed"),
548       tensorflow::Flag("fake_outfeed_shape", &opts.fake_outfeed_shape,
549                        "Shape of fake data to outfeed from computation"),
550       tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed,
551                        "Whether a fake infeed shape should be derived "
552                        "from the computation"),
553       tensorflow::Flag("generate_fake_outfeed", &opts.generate_fake_outfeed,
554                        "Whether a fake outfeed shape should be derived "
555                        "from the computation"),
556       tensorflow::Flag("intra_op_thread_pool_size",
557                        &opts.intra_op_thread_pool_size,
558                        "How many threads to use in the intra-op thread pool. "
559                        "Defaults to the number of CPUs."),
560       tensorflow::Flag("compile_only", &opts.compile_only,
561                        "Whether the input should only be compiled, as opposed "
562                        "to compiled and executed."),
563   };
564   xla::AppendDebugOptionsFlags(&flag_list);
565   std::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
566   bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
567   tensorflow::port::InitMain(argv[0], &argc, &argv);
568   if (argc < 2 || !parse_ok) {
569     LOG(QFATAL) << usage;
570   }
571   absl::Span<char* const> args(argv, argc);
572   args.remove_prefix(1);  // Pop off the binary name, argv[0]
573   if (opts.compile_only) {
574     opts.use_fake_data = true;
575   }
576   return xla::tools::RealMain(args, opts);
577 }
578