1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Usage: replay_computation some_binary_snapshot_proto*
17 //
18 // Where some_binary_snapshot_proto is [type_prefix:]file_path. Supported
19 // type_prefixes:
20 // * recordio_hlo_proto - for a Tensorflow recordio file containing serialized
21 // xla.HloProtos.
22 //
23 // If type_prefix is omitted, the program will make several guesses.
24 //
25 // Replays computations and shows the results on the command line.
26 //
27 // some_binary_snapshot_proto is obtained by serializing the HloSnapshot from
28 // ServiceInterface::SnapshotComputation to disk.
29 //
30 // Computations that require arguments can be replayed using fake data by
31 // passing --use_fake_data on the command line. If the real data is available
32 // in the proto and --use_fake_data is false, the real data is used.
33 //
34 // Input can be a binary HloSnapshot proto, a binary HloProto proto, or a
35 // textual HLO string.
36 //
37 // The output format is:
38 //
39 // file_path: computation_name :: type:literal_str
40 //
41 // Note: If you pass multiple modules, they will be compiled in parallel but run
42 // in series.
43
44 #define EIGEN_USE_THREADS
45
#include <stdio.h>

#include <algorithm>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "absl/types/span.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/testing.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
80
81 namespace xla {
82 namespace tools {
83 namespace {
84
// Aggregated command-line options for this tool; main() documents each flag
// that populates a field here.
struct Options {
  Options() {}

  // True iff the snapshot must carry recorded argument literals: we are
  // neither fabricating arguments nor stopping after compilation.
  bool NeedsRealData() const { return !(use_fake_data || compile_only); }

  std::string fake_infeed_shape;
  std::string fake_outfeed_shape;

  // Defaulting generate_fake_infeed to true is safe: with 0 or 1 infeeds the
  // behavior is as expected, and with more than one infeed we fail fast —
  // which would not have worked without the flag either.
  //
  // The same reasoning applies to generate_fake_outfeed.
  bool generate_fake_infeed = true;
  bool generate_fake_outfeed = true;

  bool use_fake_data = false;
  bool print_result = true;
  int num_runs = 1;

  int intra_op_thread_pool_size = -1;

  bool compile_only = false;
};
112
CompileExecutable(const HloSnapshot & module,LocalClient * client,const Options & opts)113 StatusOr<std::unique_ptr<LocalExecutable>> CompileExecutable(
114 const HloSnapshot& module, LocalClient* client, const Options& opts) {
115 XlaComputation computation(module.hlo().hlo_module());
116 std::vector<Shape> argument_layouts;
117 argument_layouts.reserve(
118 computation.proto().host_program_shape().parameters_size());
119 std::vector<const Shape*> argument_layout_ptrs;
120 if (opts.use_fake_data) {
121 for (const ShapeProto& param :
122 computation.proto().host_program_shape().parameters()) {
123 argument_layouts.push_back(Shape(param));
124 argument_layout_ptrs.push_back(&argument_layouts.back());
125 }
126 } else {
127 for (const auto& proto : module.arguments()) {
128 if (!proto.has_shape()) {
129 return InvalidArgument("LiteralProto has no shape");
130 }
131 Shape shape(proto.shape());
132 argument_layouts.push_back(shape);
133 argument_layout_ptrs.push_back(&argument_layouts.back());
134 }
135 }
136 ExecutableBuildOptions exec_build_options;
137 *exec_build_options.mutable_debug_options() = GetDebugOptionsFromFlags();
138 TF_ASSIGN_OR_RETURN(
139 auto executables,
140 client->Compile(computation, argument_layout_ptrs, exec_build_options));
141 TF_RET_CHECK(executables.size() == 1);
142 return std::move(executables[0]);
143 }
144
// Returns the shape to use for fake infeed (is_infeed=true) or outfeed
// (is_infeed=false) data for `module`, or nullopt if no xfeed should be
// performed. An explicit --fake_{in,out}feed_shape flag wins; otherwise, if
// --generate_fake_{in,out}feed is set, the shape is inferred from the single
// infeed/outfeed instruction in the module (2+ such ops is a fatal error).
std::optional<Shape> GetXfeedShape(bool is_infeed, const HloModuleProto& module,
                                   const Options& opts) {
  // Collect every infeed (or outfeed) instruction across all computations.
  std::vector<HloInstructionProto> xfeed_instrs;
  for (const auto& comp : module.computations()) {
    for (const auto& instruction : comp.instructions()) {
      if (instruction.opcode() == HloOpcodeString(is_infeed
                                                      ? HloOpcode::kInfeed
                                                      : HloOpcode::kOutfeed)) {
        xfeed_instrs.push_back(instruction);
      }
    }
  }

  // Logs shape and name of each collected xfeed op, for error diagnostics.
  auto log_xfeed_instrs = [&] {
    for (const auto& infeed : xfeed_instrs) {
      LOG(ERROR) << " " << ShapeUtil::HumanString(Shape(infeed.shape())) << " "
                 << infeed.name();
    }
  };

  // Looks up an instruction in `module` by unique id; LOG(FATAL)s if absent.
  auto find_instruction_from_id_or_die = [&](int64_t id) {
    for (const auto& comp : module.computations()) {
      for (const auto& instruction : comp.instructions()) {
        if (instruction.id() == id) {
          return instruction;
        }
      }
    }
    LOG(FATAL) << "No instruction with id " << id;
  };

  std::optional<Shape> xfeed_shape;
  std::string xfeed_name = is_infeed ? "infeed" : "outfeed";
  std::string fake_xfeed_shape =
      is_infeed ? opts.fake_infeed_shape : opts.fake_outfeed_shape;
  bool generate_fake_xfeed =
      is_infeed ? opts.generate_fake_infeed : opts.generate_fake_outfeed;
  if (!fake_xfeed_shape.empty()) {
    // An explicitly specified shape takes priority over inference.
    xfeed_shape = std::move(ParseShape(fake_xfeed_shape)).ValueOrDie();
  } else if (generate_fake_xfeed) {
    QCHECK_LT(xfeed_instrs.size(), 2)
        << "--generate_fake_" << xfeed_name
        << " only works if the model has 0 or 1 " << xfeed_name << " ops.";
    if (xfeed_instrs.empty()) {
      LOG(INFO) << "Not generating fake " << xfeed_name
                << " shape; model has no " << xfeed_name << "s.";
    } else if (xfeed_instrs.size() == 1) {
      // kInfeed instructions should have a shape (buffer, token). kOutfeed
      // instructions should have operand 0 of shape `buffer`. We want to xfeed
      // just `buffer`.
      xfeed_shape = is_infeed
                        ? Shape(xfeed_instrs.front().shape()).tuple_shapes(0)
                        : Shape(find_instruction_from_id_or_die(
                                    xfeed_instrs.front().operand_ids(0))
                                    .shape());
      LOG(INFO) << "Generating fake " << xfeed_name << " with inferred shape: "
                << ShapeUtil::HumanString(*xfeed_shape);
    } else {
      // NOTE(review): this branch looks unreachable — the QCHECK_LT above
      // already dies for size() >= 2. Kept for defensive error reporting.
      LOG(ERROR) << "--generate_fake_" << xfeed_name
                 << " only works if the model has 0 or 1 " << xfeed_name
                 << " ops, but this model has " << xfeed_instrs.size()
                 << " of them:";
      log_xfeed_instrs();
      LOG(QFATAL) << "Can't run model with --generate_fake_infeed.";
    }
  } else if (!xfeed_instrs.empty()) {
    // The model has xfeed ops but nothing will service them; warn loudly
    // because execution will most likely block forever on the xfeed.
    LOG(ERROR) << "Model contains " << xfeed_instrs.size() << " " << xfeed_name
               << " instruction(s), but neither --generate_fake_" << xfeed_name
               << " nor --fake_" << xfeed_name
               << "_shape was specified. Execution will likely hang.";
    log_xfeed_instrs();
  }

  return xfeed_shape;
}
220
// Invokes the given computation, passing arbitrary data for every (unbound)
// parameter if use_fake_data is set; otherwise uses the recorded data from
// the snapshot.
//
// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided.
// If generate_fake_infeed is true, the required infeed shape is derived from
// the computation and then used to provide a fake infeed shape.
//
// If neither generate_fake_infeed is true nor a fake_infeed_shape is provided,
// no infeed is performed.
//
// Runs the executable opts.num_runs times and returns the result literal of
// the final run.
StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
                                    LocalExecutable* executable,
                                    LocalClient* client, const Options& opts) {
  XlaComputation computation(module.hlo().hlo_module());

  // Build the `argument_ptrs` vector, which contains ShapedBuffer*s to our
  // arguments. This is a bit involved, because we may have to convert from
  // GlobalData to ShapedBuffer*, and we have to manage the lifetime of all our
  // objects.
  std::vector<ScopedShapedBuffer> scoped_shaped_buffer_arguments;
  std::vector<std::unique_ptr<GlobalData>> global_data_arguments;
  std::vector<const ShapedBuffer*> argument_ptrs;
  if (opts.use_fake_data) {
    // Run fake computations with debug options ignoring XLA_FLAGS. Users very
    // likely want XLA_FLAGS only to apply to the "real" computation being run,
    // not to the fake computations we use for generating arguments. There is
    // an exception. ptxas can be called during the generation of fake
    // data. As it is cached in the process memory, the flag affecting this call
    // should not be ignored.
    auto debug_opts_flags = GetDebugOptionsFromFlags();
    auto debug_opts = DefaultDebugOptionsIgnoringFlags();
    debug_opts.set_xla_gpu_asm_extra_flags(
        debug_opts_flags.xla_gpu_asm_extra_flags());

    global_data_arguments =
        MakeFakeArgumentsOrDie(computation, client, &debug_opts);
    for (const auto& data : global_data_arguments) {
      argument_ptrs.push_back(
          client->GlobalDataToShapedBuffer(data->handle(), /*replica_number=*/0)
              .ValueOrDie());
    }
  } else {  // use recorded data if available
    // Transfer each recorded argument literal to the device; the
    // ScopedShapedBuffers own the device memory for the whole replay.
    for (const auto& proto : module.arguments()) {
      TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(proto));
      TF_ASSIGN_OR_RETURN(
          ScopedShapedBuffer data,
          client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0));
      scoped_shaped_buffer_arguments.push_back(std::move(data));
    }
    for (const auto& argument : scoped_shaped_buffer_arguments) {
      argument_ptrs.push_back(&argument);
    }
  }

  // Fake infeed literal, generated once and re-transferred on every run.
  std::shared_ptr<Literal> infeed_data;
  if (std::optional<Shape> infeed_shape = GetXfeedShape(
          /*is_infeed=*/true, computation.proto(), opts)) {
    infeed_data = std::make_shared<Literal>(
        std::move(MakeFakeLiteral(*infeed_shape)).ValueOrDie());
  }
  std::optional<Shape> outfeed_shape =
      GetXfeedShape(/*is_infeed=*/false, computation.proto(), opts);

  // Do not attempt to run the executable if num_runs is less than 1.
  if (opts.num_runs < 1) {
    return Cancelled("Cancelled after compilation since --num_runs < 1.");
  }

  // Run the computation num_runs times, and return the result from the last
  // execution.
  const bool xla_hlo_profile = GetDebugOptionsFromFlags().xla_hlo_profile();
  se::StreamExecutorMemoryAllocator allocator(
      client->platform(),
      {client->platform()->ExecutorForDevice(0).ValueOrDie()});
  std::optional<ScopedShapedBuffer> final_result;
  LOG(ERROR) << "Running " << opts.num_runs << " number of times\n";
  for (int i = 0; i < opts.num_runs; ++i) {
    // If xla_hlo_profile is enabled, print a noisy message before the last run,
    // making it easier to separate this profile from the others in the logspam.
    bool is_final_result = i == opts.num_runs - 1;
    if (xla_hlo_profile && is_final_result) {
      LOG(INFO) << "\n\n***** Final run below ******";
    }
    // Per-run Eigen thread pool for intra-op parallelism; it must outlive
    // the Run() call below, which holds a pointer to it via run_options.
    int thread_pool_size = opts.intra_op_thread_pool_size < 0
                               ? tensorflow::port::MaxParallelism()
                               : opts.intra_op_thread_pool_size;
    tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
                                        thread_pool_size);
    Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(),
                                        pool.NumThreads());

    ExecutionProfile profile;
    ExecutableRunOptions run_options;
    run_options.set_execution_profile(&profile);
    run_options.set_allocator(&allocator);
    run_options.set_intra_op_thread_pool(&thread_pool);

    if (infeed_data) {
      TF_CHECK_OK(client->TransferToInfeed(*infeed_data));
    }
    // Drain thread joins when `outfeed_drain_thread` is destroyed at the end
    // of this loop iteration, i.e. after the outfeed has been consumed.
    std::unique_ptr<tensorflow::Thread> outfeed_drain_thread;
    if (outfeed_shape) {
      // TransferFromOutfeedLocal blocks till the outfeed is available, so do
      // it asynchronously separate thread.
      outfeed_drain_thread.reset(tensorflow::Env::Default()->StartThread(
          tensorflow::ThreadOptions(), "outfeed_drain_thread", [&] {
            Literal outfeed(*outfeed_shape);
            TF_CHECK_OK(client->TransferFromOutfeedLocal(/*device_ordinal=*/0,
                                                         &outfeed));
            VLOG(1) << "Received outfeed data of shape "
                    << ShapeUtil::HumanStringWithLayout(*outfeed_shape);
          }));
    }

    TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                        executable->Run(argument_ptrs, run_options));
    LOG(INFO) << "Done executing in "
              << static_cast<double>(profile.compute_time_ns()) / 1e9
              << "s: " << module.hlo().hlo_module().name();

    // Save the result if this is for the final iteration. Otherwise discard
    // the result before rerunning the computation, so as to free up the
    // relevant memory.
    if (is_final_result) {
      final_result = std::move(result);
    }
  }

  // Bring the device-side result of the last run back to the host.
  TF_ASSIGN_OR_RETURN(Literal result_literal,
                      client->ShapedBufferToLiteral(*final_result));
  return result_literal;
}
352
ParseRecordIoFile(absl::string_view filename,const Options & opts)353 StatusOr<std::vector<HloSnapshot>> ParseRecordIoFile(absl::string_view filename,
354 const Options& opts) {
355 tensorflow::Env* env = tensorflow::Env::Default();
356
357 std::unique_ptr<tensorflow::RandomAccessFile> file;
358 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(
359 std::string(filename.begin(), filename.end()), &file));
360 tensorflow::io::RecordReader reader(
361 file.get(),
362 tensorflow::io::RecordReaderOptions::CreateRecordReaderOptions("ZLIB"));
363
364 std::vector<HloSnapshot> snapshots;
365 uint64_t offset = 0;
366 tensorflow::tstring record;
367 while (reader.ReadRecord(&offset, &record).ok()) {
368 HloSnapshot snapshot;
369 if (snapshot.mutable_hlo()->ParseFromString(record)) {
370 snapshots.push_back(std::move(snapshot));
371 } else {
372 LOG(ERROR) << "Encountered bad proto";
373 }
374 }
375 QCHECK(!snapshots.empty())
376 << "No proto is successfully parsed from the file - the file possibly "
377 "has a mismatched compression option, format, etc.";
378 QCHECK(!opts.NeedsRealData())
379 << "Without --use_fake_data or --compile_only, you must pass an "
380 "HloSnapshot -- HloProto and textual HLO don't carry real data.";
381 return snapshots;
382 }
383
ParseSingleHloFile(const std::string & filename,const Options & opts)384 StatusOr<std::vector<HloSnapshot>> ParseSingleHloFile(
385 const std::string& filename, const Options& opts) {
386 tensorflow::Env* env = tensorflow::Env::Default();
387
388 HloSnapshot snapshot;
389 auto s = tensorflow::ReadBinaryProto(env, filename, &snapshot);
390 if (s.ok()) {
391 return std::vector<HloSnapshot>{std::move(snapshot)};
392 }
393 if (s.code() == tensorflow::error::NOT_FOUND) {
394 return s;
395 }
396 QCHECK(!opts.NeedsRealData())
397 << "Without --use_fake_data or --compile_only, you must pass an "
398 "HloSnapshot -- HloProto and textual HLO don't carry real data.";
399 fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n",
400 filename.c_str());
401
402 if (tensorflow::ReadBinaryProto(env, filename, snapshot.mutable_hlo()).ok()) {
403 return std::vector<HloSnapshot>{std::move(snapshot)};
404 }
405 fprintf(stderr, "%s: is not HloProto. Trying HLO text.\n", filename.c_str());
406 std::string contents;
407 TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents));
408 HloModuleConfig config;
409 config.set_debug_options(GetDebugOptionsFromFlags());
410 std::vector<std::string> hlo_module_texts =
411 absl::StrSplit(contents, "// -----");
412 std::vector<HloSnapshot> snapshots;
413 int start_line = 0;
414 for (const std::string& hlo_module_text : hlo_module_texts) {
415 StatusOr<std::unique_ptr<HloModule>> module =
416 ParseAndReturnUnverifiedModule(hlo_module_text, config);
417 if (module.ok()) {
418 HloSnapshot snapshot;
419 *snapshot.mutable_hlo()->mutable_hlo_module() =
420 module.ValueOrDie()->ToProto();
421 snapshots.push_back(snapshot);
422 } else {
423 LOG(ERROR) << module.status();
424 if (hlo_module_texts.size() > 1) {
425 LOG(ERROR)
426 << "The error below was done on the section starting at line "
427 << start_line;
428 }
429 }
430 start_line += absl::c_count(hlo_module_text, '\n');
431 }
432 if (!snapshots.empty()) {
433 return snapshots;
434 }
435 fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n",
436 filename.c_str());
437 return InvalidArgument("Could not parse %s.", filename);
438 }
439
ParseInputFile(const std::string & filename,const Options & opts)440 StatusOr<std::vector<HloSnapshot>> ParseInputFile(const std::string& filename,
441 const Options& opts) {
442 std::vector<HloSnapshot> snapshots;
443 absl::string_view filename_view = filename;
444 if (absl::ConsumePrefix(&filename_view, "recordio_hlo_proto:")) {
445 return ParseRecordIoFile(filename_view, opts);
446 }
447 return ParseSingleHloFile(filename, opts);
448 }
449
RealMain(absl::Span<char * const> args,const Options & opts)450 int RealMain(absl::Span<char* const> args, const Options& opts) {
451 LocalClient* client = ClientLibrary::LocalClientOrDie();
452 int exit_status = EXIT_SUCCESS;
453
454 std::vector<HloSnapshot> snapshots;
455 for (char* arg : args) {
456 StatusOr<std::vector<HloSnapshot>> maybe_snapshot =
457 ParseInputFile(arg, opts);
458 if (maybe_snapshot.ok()) {
459 auto new_snapshots = std::move(maybe_snapshot).ValueOrDie();
460 snapshots.insert(snapshots.end(),
461 std::make_move_iterator(new_snapshots.begin()),
462 std::make_move_iterator(new_snapshots.end()));
463 } else {
464 LOG(ERROR) << maybe_snapshot.status();
465 }
466 }
467
468 // Compile all the modules in parallel.
469 LOG(INFO) << "Compiling " << snapshots.size() << " modules in parallel.";
470 std::vector<StatusOr<std::unique_ptr<LocalExecutable>>> executables;
471 {
472 constexpr size_t kThreadLimits = 100;
473 // ThreadPool CHECK-fails if we give it 0 threads.
474 tensorflow::thread::ThreadPool thread_pool(
475 tensorflow::Env::Default(), tensorflow::ThreadOptions(),
476 "compile_modules",
477 std::min<size_t>(std::max(kThreadLimits, snapshots.size()), 1),
478 /*low_latency_hint=*/false);
479 executables.resize(snapshots.size());
480 for (int64_t i = 0; i < snapshots.size(); ++i) {
481 thread_pool.Schedule([&snapshots, &executables, client, i, &opts] {
482 executables[i] = CompileExecutable(snapshots[i], client, opts);
483 });
484 }
485 }
486 LOG(INFO) << "Done compiling; now running the modules.";
487
488 for (int64_t i = 0; i < executables.size(); ++i) {
489 if (!executables[i].ok()) {
490 LOG(ERROR) << "Compilation failed: " << executables[i].status() << ": "
491 << snapshots[i].ShortDebugString();
492 exit_status = EXIT_FAILURE;
493 continue;
494 }
495
496 if (opts.compile_only) {
497 continue;
498 }
499
500 LocalExecutable* executable = executables[i].ValueOrDie().get();
501 LOG(ERROR) << "Running iteration " << i;
502 StatusOr<Literal> result_status =
503 ReplayComputation(snapshots[i], executable, client, opts);
504 LOG(ERROR) << "iteration complete.";
505 if (!result_status.ok()) {
506 fprintf(stderr, "%s: error: %s\n", args[i],
507 result_status.status().ToString().c_str());
508 exit_status = EXIT_FAILURE;
509 continue;
510 }
511
512 if (opts.print_result) {
513 Literal result = std::move(result_status).ValueOrDie();
514 fprintf(stdout, "%s: %s :: %s:%s\n", args[i],
515 executable->executable()->module().name().c_str(),
516 ShapeUtil::HumanString(result.shape()).c_str(),
517 result.ToString().c_str());
518 auto& snapshot = snapshots[i];
519 if (snapshot.has_result()) {
520 Literal literal = Literal::CreateFromProto(snapshot.result()).value();
521 fprintf(
522 stdout, "was %s:%s\n",
523 ShapeUtil::HumanString(Shape(snapshot.result().shape())).c_str(),
524 literal.ToString().c_str());
525 }
526 }
527 }
528
529 ClientLibrary::DestroyLocalInstances();
530 return exit_status;
531 }
532
533 } // namespace
534 } // namespace tools
535 } // namespace xla
536
// Tool entry point: parses command-line flags into an Options struct, then
// hands the remaining positional arguments (input files) to RealMain.
int main(int argc, char** argv) {
  xla::tools::Options opts;
  std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("use_fake_data", &opts.use_fake_data,
                       "Replay computation using fake data"),
      tensorflow::Flag("print_result", &opts.print_result,
                       "Print the result of the computation to stdout"),
      tensorflow::Flag("num_runs", &opts.num_runs,
                       "Number of times to run each computation"),
      tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape,
                       "Shape of fake data to construct for (infinite) infeed"),
      tensorflow::Flag("fake_outfeed_shape", &opts.fake_outfeed_shape,
                       "Shape of fake data to outfeed from computation"),
      tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed,
                       "Whether a fake infeed shape should be derived "
                       "from the computation"),
      tensorflow::Flag("generate_fake_outfeed", &opts.generate_fake_outfeed,
                       "Whether a fake outfeed shape should be derived "
                       "from the computation"),
      tensorflow::Flag("intra_op_thread_pool_size",
                       &opts.intra_op_thread_pool_size,
                       "How many threads to use in the intra-op thread pool. "
                       "Defaults to the number of CPUs."),
      tensorflow::Flag("compile_only", &opts.compile_only,
                       "Whether the input should only be compiled, as opposed "
                       "to compiled and executed."),
  };
  // Also expose the XLA debug-option flags (XLA_FLAGS equivalents).
  xla::AppendDebugOptionsFlags(&flag_list);
  std::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  // At least one input file is required.
  if (argc < 2 || !parse_ok) {
    LOG(QFATAL) << usage;
  }
  absl::Span<char* const> args(argv, argc);
  args.remove_prefix(1);  // Pop off the binary name, argv[0]
  // Compile-only mode never needs real argument data, so fake data is implied.
  if (opts.compile_only) {
    opts.use_fake_data = true;
  }
  return xla::tools::RealMain(args, opts);
}
578