1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Usage: replay_computation some_binary_snapshot_proto*
17 //
18 // Where some_binary_snapshot_proto is [type_prefix:]file_path. Supported
19 // type_prefixes:
// * recordio_hlo_proto - for a ZLIB-compressed TensorFlow recordio file
//   containing serialized xla.HloProtos.
22 //
23 // If type_prefix is omitted, the program will make several guesses.
24 //
25 // Replays computations and shows the results on the command line.
26 //
27 // some_binary_snapshot_proto is obtained by serializing the HloSnapshot from
28 // ServiceInterface::SnapshotComputation to disk.
29 //
30 // Computations that require arguments can be replayed using fake data by
31 // passing --use_fake_data on the command line. If the real data is available
32 // in the proto and --use_fake_data is false, the real data is used.
33 //
34 // Input can be a binary HloSnapshot proto, a binary HloProto proto, or a
35 // textual HLO string.
36 //
37 // The output format is:
38 //
39 // file_path: computation_name :: type:literal_str
40 //
41 // Note: If you pass multiple modules, they will be compiled in parallel but run
42 // in series.
43
44 #define EIGEN_USE_THREADS
45
46 #include <stdio.h>
47
48 #include <memory>
49 #include <string>
50 #include <utility>
51 #include <vector>
52
53 #include "absl/types/span.h"
54 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
55 #include "tensorflow/compiler/xla/client/client.h"
56 #include "tensorflow/compiler/xla/client/client_library.h"
57 #include "tensorflow/compiler/xla/client/global_data.h"
58 #include "tensorflow/compiler/xla/client/lib/testing.h"
59 #include "tensorflow/compiler/xla/client/local_client.h"
60 #include "tensorflow/compiler/xla/client/xla_computation.h"
61 #include "tensorflow/compiler/xla/debug_options_flags.h"
62 #include "tensorflow/compiler/xla/execution_options_util.h"
63 #include "tensorflow/compiler/xla/literal.h"
64 #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
65 #include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h"
66 #include "tensorflow/compiler/xla/service/hlo.pb.h"
67 #include "tensorflow/compiler/xla/service/hlo_parser.h"
68 #include "tensorflow/compiler/xla/shape_util.h"
69 #include "tensorflow/compiler/xla/status_macros.h"
70 #include "tensorflow/compiler/xla/statusor.h"
71 #include "tensorflow/compiler/xla/tests/test_utils.h"
72 #include "tensorflow/compiler/xla/types.h"
73 #include "tensorflow/compiler/xla/xla_data.pb.h"
74 #include "tensorflow/core/lib/core/threadpool.h"
75 #include "tensorflow/core/lib/io/record_reader.h"
76 #include "tensorflow/core/platform/cpu_info.h"
77 #include "tensorflow/core/platform/env.h"
78 #include "tensorflow/core/platform/init_main.h"
79 #include "tensorflow/core/platform/logging.h"
80 #include "tensorflow/core/util/command_line_flags.h"
81
82 namespace xla {
83 namespace tools {
84 namespace {
85
86 // Command-line opts to this tool. See main() for descriptions of these
87 // fields.
88 struct Options {
Optionsxla::tools::__anon38317c710111::Options89 Options() : intra_op_thread_pool_size(tensorflow::port::MaxParallelism()) {}
90
NeedsRealDataxla::tools::__anon38317c710111::Options91 bool NeedsRealData() const { return !use_fake_data && !compile_only; }
92
93 string fake_infeed_shape;
94 string fake_outfeed_shape;
95
96 // generate_fake_infeed == true is a safe default: If the model has 0 or 1
97 // infeeds, then it will work like normal. If the model has more than one
98 // infeed, it will be an error, but that wouldn't have worked anyway if you
99 // hadn't passed generate_fake_infeed.
100 //
101 // Same for generate_fake_outfeed.
102 bool generate_fake_infeed = true;
103 bool generate_fake_outfeed = true;
104
105 bool use_fake_data = false;
106 bool print_result = true;
107 int num_runs = 1;
108
109 int intra_op_thread_pool_size;
110
111 bool compile_only = false;
112 };
113
CompileExecutable(const HloSnapshot & module,LocalClient * client)114 StatusOr<std::unique_ptr<LocalExecutable>> CompileExecutable(
115 const HloSnapshot& module, LocalClient* client) {
116 XlaComputation computation(module.hlo().hlo_module());
117 std::vector<Shape> argument_layouts;
118 argument_layouts.reserve(
119 computation.proto().host_program_shape().parameters_size());
120 std::vector<const Shape*> argument_layout_ptrs;
121 for (const ShapeProto& param :
122 computation.proto().host_program_shape().parameters()) {
123 argument_layouts.push_back(Shape(param));
124 argument_layout_ptrs.push_back(&argument_layouts.back());
125 }
126 ExecutableBuildOptions exec_build_options;
127 *exec_build_options.mutable_debug_options() = GetDebugOptionsFromFlags();
128 TF_ASSIGN_OR_RETURN(
129 auto executables,
130 client->Compile(computation, argument_layout_ptrs, exec_build_options));
131 TF_RET_CHECK(executables.size() == 1);
132 return std::move(executables[0]);
133 }
134
GetXfeedShape(bool is_infeed,const HloModuleProto & module,const Options & opts)135 absl::optional<Shape> GetXfeedShape(bool is_infeed,
136 const HloModuleProto& module,
137 const Options& opts) {
138 std::vector<HloInstructionProto> xfeed_instrs;
139 for (const auto& comp : module.computations()) {
140 for (const auto& instruction : comp.instructions()) {
141 if (instruction.opcode() == HloOpcodeString(is_infeed
142 ? HloOpcode::kInfeed
143 : HloOpcode::kOutfeed)) {
144 xfeed_instrs.push_back(instruction);
145 }
146 }
147 }
148
149 auto log_xfeed_instrs = [&] {
150 for (const auto& infeed : xfeed_instrs) {
151 LOG(ERROR) << " " << ShapeUtil::HumanString(Shape(infeed.shape())) << " "
152 << infeed.name();
153 }
154 };
155
156 auto find_instruction_from_id_or_die = [&](int64 id) {
157 for (const auto& comp : module.computations()) {
158 for (const auto& instruction : comp.instructions()) {
159 if (instruction.id() == id) {
160 return instruction;
161 }
162 }
163 }
164 LOG(FATAL) << "No instruction with id " << id;
165 };
166
167 absl::optional<Shape> xfeed_shape;
168 string xfeed_name = is_infeed ? "infeed" : "outfeed";
169 string fake_xfeed_shape =
170 is_infeed ? opts.fake_infeed_shape : opts.fake_outfeed_shape;
171 bool generate_fake_xfeed =
172 is_infeed ? opts.generate_fake_infeed : opts.generate_fake_outfeed;
173 if (!fake_xfeed_shape.empty()) {
174 xfeed_shape = std::move(ParseShape(fake_xfeed_shape)).ValueOrDie();
175 } else if (generate_fake_xfeed) {
176 CHECK_LT(xfeed_instrs.size(), 2)
177 << "--generate_fake_" << xfeed_name
178 << " only works if the model has 0 or 1 " << xfeed_name << " ops.";
179 if (xfeed_instrs.empty()) {
180 LOG(INFO) << "Not generating fake " << xfeed_name
181 << " shape; model has no " << xfeed_name << "s.";
182 } else if (xfeed_instrs.size() == 1) {
183 // kInfeed instructions should have a shape (buffer, token). kOutfeed
184 // instructions should have operand 0 of shape `buffer`. We want to xfeed
185 // just `buffer`.
186 xfeed_shape = is_infeed
187 ? Shape(xfeed_instrs.front().shape()).tuple_shapes(0)
188 : Shape(find_instruction_from_id_or_die(
189 xfeed_instrs.front().operand_ids(0))
190 .shape());
191 LOG(INFO) << "Generating fake " << xfeed_name << " with inferred shape: "
192 << ShapeUtil::HumanString(*xfeed_shape);
193 } else {
194 LOG(ERROR) << "--generate_fake_" << xfeed_name
195 << " only works if the model has 0 or 1 " << xfeed_name
196 << " ops, but this model has " << xfeed_instrs.size()
197 << " of them:";
198 log_xfeed_instrs();
199 LOG(FATAL) << "Can't run model with --generate_fake_infeed.";
200 }
201 } else if (!xfeed_instrs.empty()) {
202 LOG(ERROR) << "Model contains " << xfeed_instrs.size() << " " << xfeed_name
203 << " instruction(s), but neither --generate_fake_" << xfeed_name
204 << " nor --fake_" << xfeed_name
205 << "_shape was specified. Execution will likely hang.";
206 log_xfeed_instrs();
207 }
208
209 return xfeed_shape;
210 }
211
212 // Invokes the given computation passing arbitrary data for every (unbound)
213 // parameter if use_fake_data, Otherwise use recorded data if available.
214 //
215 // Similarly, infeeds fake data of shape fake_infeed_shape if it is provided.
216 // If generate_fake_infeed is true, the required infeed shape is derived from
217 // the computation and then used to provide a fake infeed shape.
218 //
219 // If neither generate_fake_infeed is true nor a fake_infeed_shape is provided,
220 // no infeed is performed.
ReplayComputation(const HloSnapshot & module,LocalExecutable * executable,LocalClient * client,const Options & opts)221 StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
222 LocalExecutable* executable,
223 LocalClient* client, const Options& opts) {
224 XlaComputation computation(module.hlo().hlo_module());
225
226 // Build the `argument_ptrs` vector, which contains ShapedBuffer*s to our
227 // arguments. This is a bit involved, because we may have to convert from
228 // GlobalData to ShapedBuffer*, and we have to manage the lifetime of all our
229 // objects.
230 std::vector<ScopedShapedBuffer> scoped_shaped_buffer_arguments;
231 std::vector<std::unique_ptr<GlobalData>> global_data_arguments;
232 std::vector<const ShapedBuffer*> argument_ptrs;
233 if (opts.use_fake_data) {
234 // Run fake computations with debug options ignoring XLA_FLAGS. Users very
235 // likely want XLA_FLAGS only to apply to the "real" computation being run,
236 // not to the fake computations we use for generating arguments.
237 auto debug_opts = DefaultDebugOptionsIgnoringFlags();
238 global_data_arguments =
239 MakeFakeArgumentsOrDie(computation, client, &debug_opts);
240 for (const auto& data : global_data_arguments) {
241 argument_ptrs.push_back(
242 client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0)
243 .ValueOrDie());
244 }
245 } else { // use recorded data if available
246 for (const auto& proto : module.arguments()) {
247 TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(proto));
248 TF_ASSIGN_OR_RETURN(
249 ScopedShapedBuffer data,
250 client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0));
251 scoped_shaped_buffer_arguments.push_back(std::move(data));
252 }
253 for (const auto& argument : scoped_shaped_buffer_arguments) {
254 argument_ptrs.push_back(&argument);
255 }
256 }
257
258 if (absl::optional<Shape> infeed_shape = GetXfeedShape(
259 /*is_infeed=*/true, computation.proto(), opts)) {
260 auto infeed_data = std::make_shared<Literal>(
261 std::move(MakeFakeLiteral(*infeed_shape)).ValueOrDie());
262 xla::gpu::GetOrCreateInfeedManager()
263 ->RegisterBeforeGetNextDestinationCallback([infeed_data, client] {
264 TF_CHECK_OK(client->TransferToInfeed(*infeed_data));
265 });
266 }
267
268 absl::optional<tensorflow::thread::ThreadPool> outfeed_thread_pool;
269 if (absl::optional<Shape> outfeed_shape = GetXfeedShape(
270 /*is_infeed=*/false, computation.proto(), opts)) {
271 // For each an outfeed that runs, enqueue a task that will consume it. We
272 // need a thread pool because the act of running an outfeed blocks on there
273 // being a destination available, and the act of making a destination
274 // available blocks on there being outfeed data available.
275 outfeed_thread_pool.emplace(tensorflow::Env::Default(), "infeed",
276 /*num_threads=*/1);
277 auto consume_outfeed = [client, outfeed_shape] {
278 TF_CHECK_OK(
279 client->TransferFromOutfeedLocal(*outfeed_shape, /*device_ordinal=*/0)
280 .status());
281 VLOG(1) << "Received outfeed data of shape "
282 << ShapeUtil::HumanStringWithLayout(*outfeed_shape);
283 };
284 xla::gpu::GetOrCreateOutfeedManager()
285 ->RegisterBeforeGetNextDestinationCallback(
286 [consume_outfeed, &outfeed_thread_pool] {
287 outfeed_thread_pool->Schedule(consume_outfeed);
288 });
289 }
290
291 // Do not attempt to run the executable if num_runs is less than 1.
292 if (opts.num_runs < 1) {
293 return Cancelled("Cancelled after compilation since --num_runs < 1.");
294 }
295
296 // Run the computation num_runs times, and return the result from the last
297 // execution.
298 const bool xla_hlo_profile = GetDebugOptionsFromFlags().xla_hlo_profile();
299 se::StreamExecutorMemoryAllocator allocator(
300 client->platform(),
301 {client->platform()->ExecutorForDevice(0).ValueOrDie()});
302 absl::optional<ScopedShapedBuffer> final_result;
303 for (int i = 0; i < opts.num_runs; ++i) {
304 // If xla_hlo_profile is enabled, print a noisy message before the last run,
305 // making it easier to separate this profile from the others in the logspam.
306 bool is_final_result = i == opts.num_runs - 1;
307 if (xla_hlo_profile && is_final_result) {
308 LOG(INFO) << "\n\n***** Final run below ******";
309 }
310 tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
311 opts.intra_op_thread_pool_size);
312 Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(),
313 pool.NumThreads());
314
315 ExecutionProfile profile;
316 ExecutableRunOptions run_options;
317 run_options.set_execution_profile(&profile);
318 run_options.set_allocator(&allocator);
319 run_options.set_intra_op_thread_pool(&thread_pool);
320
321 TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
322 executable->Run(argument_ptrs, run_options));
323 LOG(INFO) << "Done executing in "
324 << static_cast<double>(profile.compute_time_ns()) / 1e9
325 << "s: " << module.hlo().hlo_module().name();
326
327 // Save the result if this is for the final iteration. Otherwise discard
328 // the result before rerunning the computation, so as to free up the
329 // relevant memory.
330 if (is_final_result) {
331 final_result = std::move(result);
332 }
333 }
334
335 TF_ASSIGN_OR_RETURN(Literal result_literal,
336 client->ShapedBufferToLiteral(*final_result));
337 return result_literal;
338 }
339
ParseRecordIoFile(absl::string_view filename,const Options & opts)340 StatusOr<std::vector<HloSnapshot>> ParseRecordIoFile(absl::string_view filename,
341 const Options& opts) {
342 tensorflow::Env* env = tensorflow::Env::Default();
343
344 std::unique_ptr<tensorflow::RandomAccessFile> file;
345 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(
346 string(filename.begin(), filename.end()), &file));
347 tensorflow::io::RecordReader reader(
348 file.get(),
349 tensorflow::io::RecordReaderOptions::CreateRecordReaderOptions("ZLIB"));
350
351 std::vector<HloSnapshot> snapshots;
352 uint64 offset = 0;
353 tensorflow::tstring record;
354 while (reader.ReadRecord(&offset, &record).ok()) {
355 HloSnapshot snapshot;
356 if (snapshot.mutable_hlo()->ParseFromString(record)) {
357 snapshots.push_back(std::move(snapshot));
358 } else {
359 LOG(ERROR) << "Encountered bad proto";
360 }
361 }
362 CHECK(!snapshots.empty())
363 << "No proto is successfully parsed from the file - the file possibly "
364 "has a mismatched compression option, format, etc.";
365 CHECK(!opts.NeedsRealData())
366 << "Without --use_fake_data or --compile_only, you must pass an "
367 "HloSnapshot -- HloProto and textual HLO don't carry real data.";
368 return snapshots;
369 }
370
ParseSingleHloFile(const string & filename,const Options & opts)371 StatusOr<HloSnapshot> ParseSingleHloFile(const string& filename,
372 const Options& opts) {
373 tensorflow::Env* env = tensorflow::Env::Default();
374
375 HloSnapshot snapshot;
376 auto s = tensorflow::ReadBinaryProto(env, filename, &snapshot);
377 if (s.ok()) {
378 return snapshot;
379 }
380 if (s.code() == tensorflow::error::NOT_FOUND) {
381 return s;
382 }
383 CHECK(!opts.NeedsRealData())
384 << "Without --use_fake_data or --compile_only, you must pass an "
385 "HloSnapshot -- HloProto and textual HLO don't carry real data.";
386 fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n",
387 filename.c_str());
388
389 if (tensorflow::ReadBinaryProto(env, filename, snapshot.mutable_hlo()).ok()) {
390 return snapshot;
391 }
392 fprintf(stderr, "%s: is not HloProto. Trying HLO text.\n", filename.c_str());
393 string contents;
394 TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents));
395 HloModuleConfig config;
396 config.set_debug_options(GetDebugOptionsFromFlags());
397 StatusOr<std::unique_ptr<HloModule>> module =
398 ParseAndReturnUnverifiedModule(contents, config);
399 if (module.ok()) {
400 *snapshot.mutable_hlo()->mutable_hlo_module() =
401 module.ValueOrDie()->ToProto();
402 return snapshot;
403 } else {
404 LOG(ERROR) << module.status();
405 }
406 fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n",
407 filename.c_str());
408 return InvalidArgument("Could not parse %s.", filename);
409 }
410
ParseInputFile(const string & filename,const Options & opts)411 StatusOr<std::vector<HloSnapshot>> ParseInputFile(const string& filename,
412 const Options& opts) {
413 std::vector<HloSnapshot> snapshots;
414 absl::string_view filename_view = filename;
415 if (absl::ConsumePrefix(&filename_view, "recordio_hlo_proto:")) {
416 return ParseRecordIoFile(filename_view, opts);
417 }
418 TF_ASSIGN_OR_RETURN(auto snapshot, ParseSingleHloFile(filename, opts));
419 return std::vector<HloSnapshot>{std::move(snapshot)};
420 }
421
RealMain(absl::Span<char * const> args,const Options & opts)422 int RealMain(absl::Span<char* const> args, const Options& opts) {
423 LocalClient* client = ClientLibrary::LocalClientOrDie();
424 int exit_status = EXIT_SUCCESS;
425
426 std::vector<HloSnapshot> snapshots;
427 for (char* arg : args) {
428 StatusOr<std::vector<HloSnapshot>> maybe_snapshot =
429 ParseInputFile(arg, opts);
430 if (maybe_snapshot.ok()) {
431 auto new_snapshots = std::move(maybe_snapshot).ValueOrDie();
432 snapshots.insert(snapshots.end(),
433 std::make_move_iterator(new_snapshots.begin()),
434 std::make_move_iterator(new_snapshots.end()));
435 } else {
436 LOG(ERROR) << maybe_snapshot.status();
437 }
438 }
439
440 // Compile all the modules in parallel.
441 LOG(INFO) << "Compiling " << snapshots.size() << " modules in parallel.";
442 std::vector<StatusOr<std::unique_ptr<LocalExecutable>>> executables;
443 {
444 constexpr size_t kThreadLimits = 100;
445 // ThreadPool CHECK-fails if we give it 0 threads.
446 tensorflow::thread::ThreadPool thread_pool(
447 tensorflow::Env::Default(), tensorflow::ThreadOptions(),
448 "compile_modules",
449 std::min<size_t>(std::max(kThreadLimits, snapshots.size()), 1),
450 /*low_latency_hint=*/false);
451 executables.resize(snapshots.size());
452 for (int64 i = 0; i < snapshots.size(); ++i) {
453 thread_pool.Schedule([&snapshots, &executables, client, i] {
454 executables[i] = CompileExecutable(snapshots[i], client);
455 });
456 }
457 }
458 LOG(INFO) << "Done compiling; now running the modules.";
459
460 for (int64 i = 0; i < executables.size(); ++i) {
461 if (!executables[i].ok()) {
462 LOG(ERROR) << "Compilation failed: " << executables[i].status() << ": "
463 << snapshots[i].ShortDebugString();
464 exit_status = EXIT_FAILURE;
465 continue;
466 }
467
468 if (opts.compile_only) {
469 continue;
470 }
471
472 LocalExecutable* executable = executables[i].ValueOrDie().get();
473 LOG(ERROR) << "Running iteration " << i;
474 StatusOr<Literal> result_status =
475 ReplayComputation(snapshots[i], executable, client, opts);
476 LOG(ERROR) << "iteration complete.";
477 if (!result_status.ok()) {
478 fprintf(stderr, "%s: error: %s\n", args[i],
479 result_status.status().ToString().c_str());
480 exit_status = EXIT_FAILURE;
481 continue;
482 }
483
484 if (opts.print_result) {
485 Literal result = std::move(result_status).ValueOrDie();
486 fprintf(stdout, "%s: %s :: %s:%s\n", args[i],
487 executable->executable()->module().name().c_str(),
488 ShapeUtil::HumanString(result.shape()).c_str(),
489 result.ToString().c_str());
490 auto& snapshot = snapshots[i];
491 if (snapshot.has_result()) {
492 Literal literal =
493 Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
494 fprintf(
495 stdout, "was %s:%s\n",
496 ShapeUtil::HumanString(Shape(snapshot.result().shape())).c_str(),
497 literal.ToString().c_str());
498 }
499 }
500 }
501
502 ClientLibrary::DestroyLocalInstances();
503 return exit_status;
504 }
505
506 } // namespace
507 } // namespace tools
508 } // namespace xla
509
main(int argc,char ** argv)510 int main(int argc, char** argv) {
511 xla::tools::Options opts;
512 const std::vector<tensorflow::Flag> flag_list = {
513 tensorflow::Flag("use_fake_data", &opts.use_fake_data,
514 "Replay computation using fake data"),
515 tensorflow::Flag("print_result", &opts.print_result,
516 "Print the result of the computation to stdout"),
517 tensorflow::Flag("num_runs", &opts.num_runs,
518 "Number of times to run each computation"),
519 tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape,
520 "Shape of fake data to construct for (infinite) infeed"),
521 tensorflow::Flag("fake_outfeed_shape", &opts.fake_outfeed_shape,
522 "Shape of fake data to outfeed from computation"),
523 tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed,
524 "Whether a fake infeed shape should be derived "
525 "from the computation"),
526 tensorflow::Flag("generate_fake_outfeed", &opts.generate_fake_outfeed,
527 "Whether a fake outfeed shape should be derived "
528 "from the computation"),
529 tensorflow::Flag("intra_op_thread_pool_size",
530 &opts.intra_op_thread_pool_size,
531 "How many threads to use in the intra-op thread pool. "
532 "Defaults to the number of CPUs."),
533 tensorflow::Flag("compile_only", &opts.compile_only,
534 "Whether the input should only be compiled, as opposed "
535 "to compiled and executed."),
536 };
537 xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
538 bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
539 tensorflow::port::InitMain(argv[0], &argc, &argv);
540 if (argc < 2 || !parse_ok) {
541 LOG(QFATAL) << usage;
542 }
543
544 absl::Span<char* const> args(argv, argc);
545 args.remove_prefix(1); // Pop off the binary name, argv[0]
546 return xla::tools::RealMain(args, opts);
547 }
548