/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/debug_options_flags.h"

#include <vector>

#include "absl/base/call_once.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/debug_options_parsers.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"

namespace xla {

DebugOptions DefaultDebugOptionsIgnoringFlags() {
  DebugOptions opts;
  opts.set_xla_llvm_enable_alias_scope_metadata(true);
  opts.set_xla_llvm_enable_noalias_metadata(true);
  opts.set_xla_llvm_enable_invariant_load_metadata(true);
  opts.set_xla_llvm_disable_expensive_passes(false);
  opts.set_xla_backend_optimization_level(3);
  opts.set_xla_gpu_autotune_level(4);
  opts.set_xla_cpu_multi_thread_eigen(true);
  opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
  opts.set_xla_eliminate_hlo_implicit_broadcast(true);
  opts.set_xla_dump_hlo_as_html(false);
  opts.set_xla_dump_include_timestamp(true);
  opts.set_xla_dump_max_hlo_modules(-1);
#ifdef INTEL_MKL
  opts.set_xla_cpu_use_mkl_dnn(true);
#endif  // INTEL_MKL
  opts.set_xla_gpu_max_kernel_unroll_factor(4);
  // Set cudnn batchnorm off by default; it does not provide a performance win
  // on average.
  opts.set_xla_gpu_use_cudnn_batchnorm(false);

  // Run all GPU work on one stream by default. Using multiple streams
  // increases memory usage and we lack strong motivating benchmarks for tuning
  // the heuristics needed to decide when to run on multiple streams. See
  // b/77879207.
  opts.set_xla_gpu_disable_multi_streaming(true);

  // TODO(jlebar): Disable fastmath once doing so is not a performance
  // regression.
  opts.set_xla_cpu_enable_fast_math(true);
  opts.set_xla_gpu_enable_fast_min_max(true);

  opts.set_xla_allow_excess_precision(true);
  opts.set_xla_force_host_platform_device_count(1);
  opts.set_xla_gpu_deterministic_reductions(false);
  return opts;
}

static absl::once_flag flags_init;
static DebugOptions* flag_values;
static std::vector<tensorflow::Flag>* flag_objects;

// Maps pass -> initial fuel values (parsed when AllocateFlags was run).
static absl::flat_hash_map<string, int64>* initial_fuel;

// Maps pass -> whether fuel was ever consumed for that pass.
static absl::node_hash_map<string, std::atomic<bool>>* fuel_ever_consumed;

// Maps pass -> remaining fuel.
//
// All threads start off using this global fuel pool, but
// ResetThreadLocalFuel() switches them to a thread-local fuel pool.
static absl::node_hash_map<string, std::atomic<int64>>* global_fuel;

// If we're using thread-local fuel, this stores it.
static thread_local std::unique_ptr<
    absl::node_hash_map<string, std::atomic<int64>>>
    thread_fuel;  // NOLINT (global variable with nontrivial destructor)
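
// Illustrative sketch (hypothetical pass code, not part of this file): a
// compiler pass typically guards each rewrite with ConsumeFuel, e.g.
//
//   if (xla::ConsumeFuel("my-pass")) {
//     RewriteInstruction(instr);  // "my-pass" and RewriteInstruction are
//                                 // placeholders for illustration only.
//   }
//
// Running with --xla_fuel=my-pass=N then permits only the first N rewrites,
// which is useful for bisecting a miscompile down to a single rewrite.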

// Logs a warning if a pass's fuel was never consumed, on the theory that this
// may be a typo in the flag value. Called atexit.
static void WarnIfFuelWasNeverConsumed() {
  CHECK(fuel_ever_consumed != nullptr);
  for (const auto& kv : *fuel_ever_consumed) {
    absl::string_view pass = kv.first;
    bool was_consumed = kv.second;
    if (!was_consumed) {
      LOG(ERROR) << absl::StreamFormat(
          "Compiler fuel for \"%s\" was never consumed. This may be a typo in "
          "the --xla_fuel flag you passed.",
          pass);
    }
  }
}

// Allocates flag_values and flag_objects. This function must not be called
// more than once; it is invoked via absl::call_once.
static void AllocateFlags() {
  flag_values = new DebugOptions(DefaultDebugOptionsIgnoringFlags());

  // Returns a lambda that calls "member_setter" on "flag_values" with the
  // argument passed in to the lambda.
  auto bool_setter_for = [](void (DebugOptions::*member_setter)(bool)) {
    return [member_setter](bool value) {
      (flag_values->*member_setter)(value);
      return true;
    };
  };

  // Returns a lambda that calls "member_setter" on "flag_values" with the
  // argument passed in to the lambda.
  auto int32_setter_for = [](void (DebugOptions::*member_setter)(int32)) {
    return [member_setter](int32 value) {
      (flag_values->*member_setter)(value);
      return true;
    };
  };

  auto string_setter_for =
      [](void (DebugOptions::*member_setter)(const string& value)) {
        return [member_setter](const string& value) {
          (flag_values->*member_setter)(value);
          return true;
        };
      };

  // Custom "sub-parser" lambda for xla_disable_hlo_passes.
  auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) {
    for (const auto& passname :
         std::vector<string>(absl::StrSplit(comma_separated_values, ','))) {
      flag_values->add_xla_disable_hlo_passes(passname);
    }
    return true;
  };

  // Custom "sub-parser" lambda for xla_enable_hlo_passes_only.
  auto setter_for_xla_enable_hlo_passes_only =
      [](string comma_separated_values) {
        for (const auto& passname : std::vector<string>(
                 absl::StrSplit(comma_separated_values, ','))) {
          flag_values->add_xla_enable_hlo_passes_only(passname);
        }
        return true;
      };

  // Custom "sub-parser" lambda for xla_gpu_ptx_file.
  auto setter_for_xla_gpu_ptx_file = [](string value) {
    flag_values->add_xla_gpu_ptx_file(value);
    return true;
  };

  // Custom "sub-parser" lambda for xla_backend_extra_options.
  auto setter_for_xla_backend_extra_options =
      [](string comma_separated_values) {
        auto* extra_options_map =
            flag_values->mutable_xla_backend_extra_options();
        parse_xla_backend_extra_options(extra_options_map,
                                        comma_separated_values);
        return true;
      };

  // Custom "sub-parser" for xla_fuel. Note that ConsumeFuel does not do any
  // locking on the fuel global variables. This means that it's
  // illegal/undefined behavior to modify this flag value while the compiler is
  // running.
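  //
  // For example, --xla_fuel=algsimp=10,dce=100 would give 10 units of fuel to
  // a pass named "algsimp" and 100 units to one named "dce". (The pass names
  // in this example are purely illustrative.)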
  initial_fuel = new absl::flat_hash_map<string, int64>();
  fuel_ever_consumed = new absl::node_hash_map<string, std::atomic<bool>>();
  global_fuel = new absl::node_hash_map<string, std::atomic<int64>>();
  auto setter_for_xla_fuel = [](string xla_fuel_value) {
    initial_fuel->clear();
    global_fuel->clear();
    fuel_ever_consumed->clear();

    for (const auto& kv : absl::StrSplit(xla_fuel_value, ',')) {
      std::vector<string> pass_and_fuel = absl::StrSplit(kv, '=');
      if (pass_and_fuel.size() != 2) {
        LOG(ERROR) << absl::StreamFormat(
            "Illegal value for --xla_fuel. Saw %s, but expected token %s to "
            "have format X=INTEGER.",
            xla_fuel_value, kv);
        return false;
      }
      const auto& pass = pass_and_fuel[0];
      const auto& fuel_str = pass_and_fuel[1];
      int64 fuel;
      if (!absl::SimpleAtoi(fuel_str, &fuel)) {
        LOG(ERROR) << absl::StreamFormat(
            "Illegal value for --xla_fuel. Saw %s, but expected token %s to "
            "be an integer.",
            xla_fuel_value, fuel_str);
        return false;
      }
      initial_fuel->emplace(pass, fuel);
      global_fuel->emplace(pass, fuel);
      fuel_ever_consumed->emplace(pass, false);
    }

    // If --xla_fuel was specified, register an atexit handler which logs a
    // warning if a pass was specified but never consumed any fuel, on the
    // theory that this may be a typo.
    if (!initial_fuel->empty()) {
      static absl::once_flag register_atexit_once;
      absl::call_once(
          register_atexit_once,
          +[] { std::atexit(WarnIfFuelWasNeverConsumed); });
    }
    return true;
  };

  flag_objects = new std::vector<tensorflow::Flag>({
      tensorflow::Flag(
          "xla_cpu_enable_fast_math",
          bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math),
          flag_values->xla_cpu_enable_fast_math(),
          "Enable unsafe fast-math optimizations in the CPU compiler; "
          "this may produce faster code at the expense of some accuracy."),
      tensorflow::Flag(
          "xla_cpu_fast_math_honor_nans",
          bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_nans),
          flag_values->xla_cpu_fast_math_honor_nans(),
          "When xla_cpu_enable_fast_math is true then this controls whether "
          "we allow operations to produce NaNs. Ignored when "
          "xla_cpu_enable_fast_math is false."),
      tensorflow::Flag(
          "xla_cpu_fast_math_honor_infs",
          bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_infs),
          flag_values->xla_cpu_fast_math_honor_infs(),
          "When xla_cpu_enable_fast_math is true then this controls whether "
          "we allow operations to produce infinities. Ignored when "
          "xla_cpu_enable_fast_math is false."),
      tensorflow::Flag(
          "xla_cpu_fast_math_honor_division",
          bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_division),
          flag_values->xla_cpu_fast_math_honor_division(),
          "When xla_cpu_enable_fast_math is true then this controls whether "
          "we forbid replacing division with multiplication by the "
          "reciprocal. Ignored when xla_cpu_enable_fast_math is false."),
      tensorflow::Flag(
          "xla_cpu_fast_math_honor_functions",
          bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_functions),
          flag_values->xla_cpu_fast_math_honor_functions(),
          "When xla_cpu_enable_fast_math is true then this controls whether "
          "we forbid approximate calculations for functions. Ignored when "
          "xla_cpu_enable_fast_math is false."),
      tensorflow::Flag(
          "xla_gpu_enable_fast_min_max",
          bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max),
          flag_values->xla_gpu_enable_fast_min_max(),
          "Enable fast floating point min/max lowering that does not propagate "
          "NaNs."),
      tensorflow::Flag(
          "xla_llvm_enable_alias_scope_metadata",
          bool_setter_for(
              &DebugOptions::set_xla_llvm_enable_alias_scope_metadata),
          flag_values->xla_llvm_enable_alias_scope_metadata(),
          "In LLVM-based backends, enable the emission of "
          "!alias.scope metadata in the generated IR."),
      tensorflow::Flag(
          "xla_llvm_enable_noalias_metadata",
          bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata),
          flag_values->xla_llvm_enable_noalias_metadata(),
          "In LLVM-based backends, enable the emission of "
          "!noalias metadata in the generated IR."),
      tensorflow::Flag(
          "xla_llvm_enable_invariant_load_metadata",
          bool_setter_for(
              &DebugOptions::set_xla_llvm_enable_invariant_load_metadata),
          flag_values->xla_llvm_enable_invariant_load_metadata(),
          "In LLVM-based backends, enable the emission of "
          "!invariant.load metadata in the generated IR."),
      tensorflow::Flag(
          "xla_llvm_disable_expensive_passes",
          bool_setter_for(&DebugOptions::set_xla_llvm_disable_expensive_passes),
          flag_values->xla_llvm_disable_expensive_passes(),
          "In LLVM-based backends, disable a custom set of "
          "expensive optimization passes."),
      tensorflow::Flag(
          "xla_backend_optimization_level",
          int32_setter_for(&DebugOptions::set_xla_backend_optimization_level),
          flag_values->xla_backend_optimization_level(),
          "Numerical optimization level for the XLA compiler backend."),
      tensorflow::Flag(
          "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "",
          "Comma-separated list of HLO passes to be disabled. These names "
          "must exactly match the passes' names; no whitespace around "
          "commas."),
      tensorflow::Flag(
          "xla_enable_hlo_passes_only", setter_for_xla_enable_hlo_passes_only,
          "",
          "Comma-separated list of HLO passes to be enabled. These names "
          "must exactly match the passes' names; no whitespace around "
          "commas. All unspecified passes are disabled."),
      tensorflow::Flag(
          "xla_disable_all_hlo_passes",
          bool_setter_for(&DebugOptions::set_xla_disable_all_hlo_passes), false,
          "Disables all HLO passes. Note that some passes are necessary for "
          "correctness, and the invariants that must be satisfied by 'fully "
          "optimized' HLO are different for different devices and may change "
          "over time. The only 'guarantee', such as it is, is that if you "
          "compile XLA and dump the optimized HLO for some graph, you should "
          "be able to run it again on the same device with the same build of "
          "XLA."),
      tensorflow::Flag(
          "xla_embed_ir_in_executable",
          bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable),
          flag_values->xla_embed_ir_in_executable(),
          "Embed the compiler IR as a string in the executable."),
      tensorflow::Flag(
          "xla_eliminate_hlo_implicit_broadcast",
          bool_setter_for(
              &DebugOptions::set_xla_eliminate_hlo_implicit_broadcast),
          flag_values->xla_eliminate_hlo_implicit_broadcast(),
          "Eliminate implicit broadcasts when lowering user "
          "computations to HLO instructions; use explicit "
          "broadcast instead."),
      tensorflow::Flag(
          "xla_cpu_multi_thread_eigen",
          bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen),
          flag_values->xla_cpu_multi_thread_eigen(),
          "When generating calls to Eigen in the CPU backend, "
          "use multi-threaded Eigen mode."),
      tensorflow::Flag("xla_gpu_cuda_data_dir",
                       flag_values->mutable_xla_gpu_cuda_data_dir(),
                       "If non-empty, specifies a local directory containing "
                       "ptxas and nvvm libdevice files; otherwise we use "
                       "those from runfile directories."),
      tensorflow::Flag("xla_gpu_ftz",
                       bool_setter_for(&DebugOptions::set_xla_gpu_ftz),
                       flag_values->xla_gpu_ftz(),
                       "If true, flush-to-zero semantics are enabled in the "
                       "code generated for GPUs."),
      tensorflow::Flag(
          "xla_gpu_disable_multi_streaming",
          bool_setter_for(&DebugOptions::set_xla_gpu_disable_multi_streaming),
          flag_values->xla_gpu_disable_multi_streaming(),
          "If true, multi-streaming in the GPU backend is disabled."),
      tensorflow::Flag(
          "xla_gpu_max_kernel_unroll_factor",
          int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor),
          flag_values->xla_gpu_max_kernel_unroll_factor(),
          "Specify the maximum kernel unroll factor for the GPU backend."),
      tensorflow::Flag("xla_gpu_ptx_file", setter_for_xla_gpu_ptx_file, "",
                       "If non-empty, specifies a file containing PTX to use. "
                       "The filename prefix must have the same pattern as PTX "
                       "dumped by XLA, which allows matching one specific "
                       "module. General workflow: dump the generated module's "
                       "PTX from XLA, modify it, then pass it back via this "
                       "option."),
      tensorflow::Flag(
          "xla_test_all_output_layouts",
          bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),
          flag_values->xla_test_all_output_layouts(),
          "Let ClientLibraryTestBase::ComputeAndCompare* test "
          "all permutations of output layouts. For example, with "
          "a 3D shape, all permutations of the set {0, 1, 2} are "
          "tried."),
      tensorflow::Flag(
          "xla_test_all_input_layouts",
          bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts),
          flag_values->xla_test_all_input_layouts(),
          "Let ClientLibraryTestBase::ComputeAndCompare* test "
          "all permutations of *input* layouts. For example, for "
          "2 input arguments with 2D shape and 4D shape, the "
          "computation will run 2! * 4! times, once for every "
          "possible combination of layouts."),
      tensorflow::Flag(
          "xla_hlo_profile",
          bool_setter_for(&DebugOptions::set_xla_hlo_profile),
          flag_values->xla_hlo_profile(),
          "Instrument the computation to collect per-HLO cycle counts."),
      tensorflow::Flag("xla_backend_extra_options",
                       setter_for_xla_backend_extra_options, "",
                       "Extra options to pass to a backend; "
                       "comma-separated list of 'key=val' strings (=val "
                       "may be omitted); no whitespace around commas."),
      tensorflow::Flag(
          "xla_gpu_use_cudnn_batchnorm",
          bool_setter_for(&DebugOptions::set_xla_gpu_use_cudnn_batchnorm),
          flag_values->xla_gpu_use_cudnn_batchnorm(),
          "Allows the GPU backend to implement batchnorm HLOs using cudnn, "
          "rather than expanding them to a soup of HLOs."),
      tensorflow::Flag("xla_cpu_use_mkl_dnn",
                       bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn),
                       flag_values->xla_cpu_use_mkl_dnn(),
                       "Generate calls to MKL-DNN in the CPU backend."),
      tensorflow::Flag(
          "xla_gpu_crash_on_verification_failures",
          bool_setter_for(
              &DebugOptions::set_xla_gpu_crash_on_verification_failures),
          flag_values->xla_gpu_crash_on_verification_failures(),
          "Crashes the program on extra verification failures, e.g. cuDNN "
          "cross-checking failures."),
      tensorflow::Flag(
          "xla_gpu_autotune_level",
          int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level),
          flag_values->xla_gpu_autotune_level(),
          "Set GEMM and convolution auto-tuning level. "
          "0 = off; 1 = on; 2 = on+init; 3 = on+init+reinit; 4 = "
          "on+init+reinit+check."),
      tensorflow::Flag(
          "xla_force_host_platform_device_count",
          int32_setter_for(
              &DebugOptions::set_xla_force_host_platform_device_count),
          flag_values->xla_force_host_platform_device_count(),
          "Force the host platform to pretend that there are this many "
          "host \"devices\". All of these host devices are backed by the "
          "same threadpool. Setting this to anything other than 1 can "
          "increase overhead from context switching, but we let the user "
          "override this behavior to help run tests on the host that run "
          "models in parallel across multiple devices."),
      tensorflow::Flag(
          "xla_gpu_disable_gpuasm_optimizations",
          bool_setter_for(
              &DebugOptions::set_xla_gpu_disable_gpuasm_optimizations),
          flag_values->xla_gpu_disable_gpuasm_optimizations(),
          "In XLA:GPU, run ptxas with -O0 (default is -O3)."),
      tensorflow::Flag(
          "xla_fuel", setter_for_xla_fuel, /*default_value_for_display=*/"",
          "Sets compiler fuel, useful for bisecting bugs in passes. Format: "
          "--xla_fuel=PASS1=NUM1,PASS2=NUM2,..."),

      tensorflow::Flag(
          "xla_dump_to", string_setter_for(&DebugOptions::set_xla_dump_to),
          flag_values->xla_dump_to(),
          "Directory into which debugging data is written. If not specified "
          "but another dumping flag is passed, data will be written to "
          "stdout. To explicitly write to stdout, set this to \"-\". The "
          "values \"sponge\" and \"test_undeclared_outputs_dir\" have a "
          "special meaning: They cause us to dump into the directory "
          "specified by the environment variable TEST_UNDECLARED_OUTPUTS_DIR."),
      tensorflow::Flag(
          "xla_dump_hlo_as_text",
          bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_text),
          flag_values->xla_dump_hlo_as_text(),
          "Dumps HLO modules as text before and after optimizations. Results "
          "are written to the --xla_dump_to dir, or, if no dir is specified, "
          "to stdout."),
      tensorflow::Flag(
          "xla_dump_hlo_as_proto",
          bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_proto),
          flag_values->xla_dump_hlo_as_proto(),
          "Dumps HLO modules as HloProtos to the directory specified by "
          "--xla_dump_to."),
      tensorflow::Flag(
          "xla_dump_hlo_as_dot",
          bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_dot),
          flag_values->xla_dump_hlo_as_dot(),
          "Dumps HLO modules rendered as dot files to the directory "
          "specified by --xla_dump_to."),
      tensorflow::Flag("xla_dump_hlo_as_html",
                       bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_html),
                       flag_values->xla_dump_hlo_as_html(),
                       "Dumps HLO modules rendered as HTML files to the "
                       "directory specified by --xla_dump_to."),
      tensorflow::Flag(
          "xla_dump_hlo_as_url",
          bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_url),
          flag_values->xla_dump_hlo_as_url(),
          "Tries to dump HLO modules rendered as URLs to stdout (and also to "
          "the directory specified by --xla_dump_to). This is not implemented "
          "by default; you need to add a plugin which calls "
          "RegisterGraphToURLRenderer()."),
      tensorflow::Flag(
          "xla_dump_hlo_snapshots",
          bool_setter_for(&DebugOptions::set_xla_dump_hlo_snapshots),
          flag_values->xla_dump_hlo_snapshots(),
          "Every time an HLO module is run, dumps an HloSnapshot to the "
          "directory specified by --xla_dump_to."),
      tensorflow::Flag(
          "xla_dump_hlo_module_re",
          string_setter_for(&DebugOptions::set_xla_dump_hlo_module_re),
          flag_values->xla_dump_hlo_module_re(),
          "Limits dumping to modules which match this regular expression. "
          "Default is to dump all modules."),
      tensorflow::Flag(
          "xla_dump_hlo_pass_re",
          string_setter_for(&DebugOptions::set_xla_dump_hlo_pass_re),
          flag_values->xla_dump_hlo_pass_re(),
          "If specified, dumps HLO before and after optimization passes which "
          "match this regular expression, in addition to dumping at the very "
          "beginning and end of compilation."),
      tensorflow::Flag(
          "xla_dump_include_timestamp",
          bool_setter_for(&DebugOptions::set_xla_dump_include_timestamp),
          flag_values->xla_dump_include_timestamp(),
          "If specified, includes a timestamp in the dumped filenames."),
      tensorflow::Flag(
          "xla_dump_max_hlo_modules",
          int32_setter_for(&DebugOptions::set_xla_dump_max_hlo_modules),
          flag_values->xla_dump_max_hlo_modules(),
          "Max number of HLO module dumps in a directory. Set to < 0 for "
          "unbounded."),
      tensorflow::Flag(
          "xla_hlo_graph_addresses",
          bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses),
          flag_values->xla_hlo_graph_addresses(),
          "When rendering graphs (--xla_dump_hlo_as_{dot,html,url}), displays "
          "the address in memory of each HloInstruction object."),
      tensorflow::Flag(
          "xla_hlo_graph_sharding_color",
          bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color),
          flag_values->xla_hlo_graph_sharding_color(),
          "Assign colors based on sharding assignments when generating the "
          "HLO graphs."),
      tensorflow::Flag(
          "xla_allow_excess_precision",
          bool_setter_for(&DebugOptions::set_xla_allow_excess_precision),
          flag_values->xla_allow_excess_precision(),
          "Allow XLA to increase the output precision of an instruction."),
      tensorflow::Flag(
          "xla_gpu_force_conv_nchw",
          bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nchw),
          flag_values->xla_gpu_force_conv_nchw(),
          "For cuDNN convolutions, always use NCHW layouts."),
      tensorflow::Flag(
          "xla_gpu_algorithm_blacklist_path",
          string_setter_for(
              &DebugOptions::set_xla_gpu_algorithm_blacklist_path),
          flag_values->xla_gpu_algorithm_blacklist_path(),
          "An AlgorithmBlacklist text proto file as a blacklist of "
          "convolution algorithms to avoid."),

      tensorflow::Flag(
          "xla_gpu_deterministic_reductions",
          bool_setter_for(&DebugOptions::set_xla_gpu_deterministic_reductions),
          flag_values->xla_gpu_deterministic_reductions(),
          "Always run deterministic reductions on GPU."),
  });
  ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects);
}

void AppendDebugOptionsFlags(std::vector<tensorflow::Flag>* flag_list) {
  absl::call_once(flags_init, &AllocateFlags);
  flag_list->insert(flag_list->end(), flag_objects->begin(),
                    flag_objects->end());
}
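
// A minimal usage sketch (hypothetical caller, not part of this file): a
// binary that wants these flags on its command line might do
//
//   std::vector<tensorflow::Flag> flag_list;
//   xla::AppendDebugOptionsFlags(&flag_list);
//   tensorflow::Flags::Parse(&argc, argv, flag_list);
//
// after which GetDebugOptionsFromFlags() returns the parsed values.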

xla::DebugOptions GetDebugOptionsFromFlags() {
  absl::call_once(flags_init, &AllocateFlags);
  return *flag_values;
}

void ResetThreadLocalFuel() {
  absl::call_once(flags_init, &AllocateFlags);

  thread_fuel.reset(new absl::node_hash_map<string, std::atomic<int64>>());
  CHECK(initial_fuel != nullptr);
  for (const auto& kv : *initial_fuel) {
    thread_fuel->emplace(kv.first, kv.second);
  }
}
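
// A minimal usage sketch (hypothetical caller, not part of this file): a test
// driver that compiles on several threads might call
//
//   xla::ResetThreadLocalFuel();
//
// at the start of each thread's work, so fuel consumed on that thread is
// drawn from a private pool rather than from global_fuel.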

bool ConsumeFuel(absl::string_view pass, bool* just_ran_out) {
  absl::call_once(flags_init, &AllocateFlags);
  if (just_ran_out != nullptr) {
    *just_ran_out = false;
  }
  auto* fuel_pool = thread_fuel ? thread_fuel.get() : global_fuel;
  if (fuel_pool->empty()) {
    return true;
  }
  auto it = fuel_pool->find(pass);
  if (it == fuel_pool->end()) {
    return true;
  }
  std::atomic<int64>& remaining_fuel = it->second;
  std::atomic<bool>& fuel_has_been_consumed = fuel_ever_consumed->at(pass);
  fuel_has_been_consumed = true;

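  // fetch_sub returns the value held *before* the decrement, so `remaining`
  // is the fuel available on entry. The first call that finds the pool empty
  // sees remaining == 0; it reports *just_ran_out and returns false, and all
  // later calls simply return false.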
  int64 remaining = remaining_fuel.fetch_sub(1);
  if (just_ran_out != nullptr) {
    *just_ran_out = remaining == 0;
  }
  return remaining > 0;
}

}  // namespace xla