/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/debug_options_flags.h"

#include <vector>

#include "absl/base/call_once.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/debug_options_parsers.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"

namespace xla {

DebugOptions DefaultDebugOptionsIgnoringFlags() {
  DebugOptions opts;
  opts.set_xla_llvm_enable_alias_scope_metadata(true);
  opts.set_xla_llvm_enable_noalias_metadata(true);
  opts.set_xla_llvm_enable_invariant_load_metadata(true);
  opts.set_xla_llvm_disable_expensive_passes(false);
  opts.set_xla_backend_optimization_level(3);
  opts.set_xla_gpu_autotune_level(4);
  opts.set_xla_cpu_multi_thread_eigen(true);
  opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
  opts.set_xla_gpu_asm_extra_flags("");
  opts.set_xla_eliminate_hlo_implicit_broadcast(true);
  opts.set_xla_dump_hlo_as_html(false);
  opts.set_xla_dump_fusion_visualization(false);
  opts.set_xla_dump_include_timestamp(true);
  opts.set_xla_dump_max_hlo_modules(-1);
  opts.set_xla_dump_module_metadata(false);
#ifdef ENABLE_MKL
  opts.set_xla_cpu_use_mkl_dnn(true);
#endif  // ENABLE_MKL
  opts.set_xla_gpu_max_kernel_unroll_factor(4);
  // Set cudnn batchnorm off by default; it does not provide a performance win
  // on average.
  opts.set_xla_gpu_use_cudnn_batchnorm(false);

  // Run all GPU work on one stream by default. Using multiple streams
  // increases memory usage and we lack strong motivating benchmarks for tuning
  // the heuristics needed to decide when to run on multiple streams. See
  // b/77879207.
  opts.set_xla_gpu_disable_multi_streaming(true);

  // Disable forms of fast math that have caused users problems in the past.
  opts.set_xla_cpu_enable_fast_math(true);
  opts.set_xla_cpu_fast_math_honor_nans(true);
  opts.set_xla_cpu_fast_math_honor_infs(true);
  opts.set_xla_cpu_fast_math_honor_functions(true);
  opts.set_xla_cpu_fast_math_honor_division(true);

  // By default, copy TF's Eigen style min_max behavior with nans.
  opts.set_xla_cpu_enable_fast_min_max(true);

  opts.set_xla_gpu_enable_fast_min_max(true);

  opts.set_xla_allow_excess_precision(true);
  opts.set_xla_force_host_platform_device_count(1);
  opts.set_xla_gpu_deterministic_reductions(false);
  opts.set_xla_cpu_enable_xprof_traceme(false);
  opts.set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found(false);
  opts.set_xla_multiheap_size_constraint_per_heap(-1);
  opts.set_xla_detailed_logging_and_dumping(true);
  return opts;
}

static absl::once_flag flags_init;
static DebugOptions* flag_values;
static std::vector<tensorflow::Flag>* flag_objects;

// Maps pass -> initial fuel values (parsed when AllocateFlags was run).
static absl::flat_hash_map<string, int64>* initial_fuel;

// Maps pass -> whether fuel was ever consumed for that pass.
static absl::node_hash_map<string, std::atomic<bool>>* fuel_ever_consumed;

// Maps pass -> remaining fuel.
//
// All threads start off using this global fuel pool, but ResetThreadLocalFuel()
// switches them to a thread-local fuel pool.
static absl::node_hash_map<string, std::atomic<int64>>* global_fuel;

// If we're using thread-local fuel, this stores it.
static thread_local std::unique_ptr<
    absl::node_hash_map<string, std::atomic<int64>>>
    thread_fuel;  // NOLINT (global variable with nontrivial destructor)

// Logs a warning if a pass's fuel was never consumed, on the theory that this
// may be a typo in the flag value. Called atexit.
static void WarnIfFuelWasNeverConsumed() {
  CHECK(fuel_ever_consumed != nullptr);
  for (const auto& kv : *fuel_ever_consumed) {
    absl::string_view pass = kv.first;
    bool was_consumed = kv.second;
    if (!was_consumed) {
      LOG(ERROR) << absl::StreamFormat(
          "Compiler fuel for \"%s\" was never consumed. This may be a typo in "
          "the --xla_fuel flag you passed.",
          pass);
    }
  }
}

// Allocates flag_values and flag_objects; this function must not be called
// more than once - it is invoked via call_once.
static void AllocateFlags() {
  flag_values = new DebugOptions(DefaultDebugOptionsIgnoringFlags());

  // Returns a lambda that calls "member_setter" on "flag_values" with the
  // argument passed in to the lambda.
  auto bool_setter_for = [](void (DebugOptions::*member_setter)(bool)) {
    return [member_setter](bool value) {
      (flag_values->*member_setter)(value);
      return true;
    };
  };

  // Returns a lambda that calls "member_setter" on "flag_values" with the
  // argument passed in to the lambda.
  auto int32_setter_for = [](void (DebugOptions::*member_setter)(int32_t)) {
    return [member_setter](int32_t value) {
      (flag_values->*member_setter)(value);
      return true;
    };
  };

  auto string_setter_for =
      [](void (DebugOptions::*member_setter)(const string& value)) {
        return [member_setter](const string& value) {
          (flag_values->*member_setter)(value);
          return true;
        };
      };

  // Custom "sub-parser" lambda for xla_disable_hlo_passes.
  auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) {
    for (const auto& passname :
         std::vector<string>(absl::StrSplit(comma_separated_values, ','))) {
      flag_values->add_xla_disable_hlo_passes(passname);
    }
    return true;
  };

  // Custom "sub-parser" lambda for xla_enable_hlo_passes_only.
  auto setter_for_xla_enable_hlo_passes_only =
      [](string comma_separated_values) {
        for (const auto& passname :
             std::vector<string>(absl::StrSplit(comma_separated_values, ','))) {
          flag_values->add_xla_enable_hlo_passes_only(passname);
        }
        return true;
      };

  // Custom "sub-parser" lambda for xla_gpu_ptx_file.
  auto setter_for_xla_gpu_ptx_file = [](string value) {
    flag_values->add_xla_gpu_ptx_file(value);
    return true;
  };

  // Custom "sub-parser" lambda for xla_gpu_llvm_ir_file.
  auto setter_for_xla_gpu_llvm_ir_file = [](const string& value) {
    flag_values->add_xla_gpu_llvm_ir_file(value);
    return true;
  };

  // Custom "sub-parser" lambda for xla_backend_extra_options.
  auto setter_for_xla_backend_extra_options =
      [](string comma_separated_values) {
        auto* extra_options_map =
            flag_values->mutable_xla_backend_extra_options();
        parse_xla_backend_extra_options(extra_options_map,
                                        comma_separated_values);
        return true;
      };

  // Custom "sub-parser" for xla_fuel. Note that ConsumeFuel does not do any
  // locking on the fuel global variables. This means that it's
  // illegal/undefined behavior to modify this flag value while the compiler is
  // running.
  initial_fuel = new absl::flat_hash_map<string, int64>();
  fuel_ever_consumed = new absl::node_hash_map<string, std::atomic<bool>>();
  global_fuel = new absl::node_hash_map<string, std::atomic<int64>>();
  auto setter_for_xla_fuel = [](string xla_fuel_value) {
    initial_fuel->clear();
    global_fuel->clear();
    fuel_ever_consumed->clear();

    for (const auto& kv : absl::StrSplit(xla_fuel_value, ',')) {
      std::vector<string> pass_and_fuel = absl::StrSplit(kv, '=');
      if (pass_and_fuel.size() != 2) {
        LOG(ERROR) << absl::StreamFormat(
            "Illegal value for --xla_fuel. Saw %s, but expected token %s to "
            "have format X=INTEGER.",
            xla_fuel_value, kv);
        return false;
      }
      const auto& pass = pass_and_fuel[0];
      const auto& fuel_str = pass_and_fuel[1];
      int64_t fuel;
      if (!absl::SimpleAtoi(fuel_str, &fuel)) {
        LOG(ERROR) << absl::StreamFormat(
            "Illegal value for --xla_fuel. Saw %s, but expected token %s to "
            "be an integer.",
            xla_fuel_value, fuel_str);
        return false;
      }
      initial_fuel->emplace(pass, fuel);
      global_fuel->emplace(pass, fuel);
      fuel_ever_consumed->emplace(pass, false);
    }

    // If --xla_fuel was specified, register an atexit handler which logs a
    // warning if a pass was specified but never consumed any fuel, on the
    // theory that this may be a typo.
    if (!initial_fuel->empty()) {
      static absl::once_flag register_atexit_once;
      absl::call_once(
          register_atexit_once,
          +[] { std::atexit(WarnIfFuelWasNeverConsumed); });
    }
    return true;
  };
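
  // A minimal illustration of the expected --xla_fuel format (the pass names
  // below are illustrative only; any HLO pass name may be used):
  //   --xla_fuel=algsimp=10,some-other-pass=5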

  flag_objects = new std::vector<tensorflow::Flag>();
  // Don't use an initializer list for initializing the vector; this would
  // create a temporary copy, and exceeds the stack space when compiling with
  // certain configurations.
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_enable_fast_math",
      bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math),
      flag_values->xla_cpu_enable_fast_math(),
      "Enable unsafe fast-math optimizations in the CPU compiler; this may "
      "produce faster code at the expense of some accuracy."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_fast_math_honor_nans",
      bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_nans),
      flag_values->xla_cpu_fast_math_honor_nans(),
      "When xla_cpu_enable_fast_math is true then this controls whether we "
      "allow operations to produce NaNs. Ignored when "
      "xla_cpu_enable_fast_math is false."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_fast_math_honor_infs",
      bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_infs),
      flag_values->xla_cpu_fast_math_honor_infs(),
      "When xla_cpu_enable_fast_math is true then this controls whether we "
      "allow operations to produce infinities. Ignored when "
      "xla_cpu_enable_fast_math is false."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_fast_math_honor_division",
      bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_division),
      flag_values->xla_cpu_fast_math_honor_division(),
      "When xla_cpu_enable_fast_math is true then this controls whether we "
      "forbid using multiplication by the reciprocal instead of division. "
      "Ignored when xla_cpu_enable_fast_math is false."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_fast_math_honor_functions",
      bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_functions),
      flag_values->xla_cpu_fast_math_honor_functions(),
      "When xla_cpu_enable_fast_math is true then this controls whether we "
      "forbid approximate implementations of functions. Ignored when "
      "xla_cpu_enable_fast_math is false."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_enable_fast_min_max",
      bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_min_max),
      flag_values->xla_cpu_enable_fast_min_max(),
      "Enable fast floating point min/max lowering that always propagates "
      "NaNs."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_enable_fast_min_max",
      bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max),
      flag_values->xla_gpu_enable_fast_min_max(),
      "Enable fast floating point min/max lowering that does not propagate "
      "NaNs."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_llvm_enable_alias_scope_metadata",
      bool_setter_for(&DebugOptions::set_xla_llvm_enable_alias_scope_metadata),
      flag_values->xla_llvm_enable_alias_scope_metadata(),
      "In LLVM-based backends, enable the emission of !alias.scope metadata in "
      "the generated IR."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_llvm_enable_noalias_metadata",
      bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata),
      flag_values->xla_llvm_enable_noalias_metadata(),
      "In LLVM-based backends, enable the emission of !noalias metadata in the "
      "generated IR."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_llvm_enable_invariant_load_metadata",
      bool_setter_for(
          &DebugOptions::set_xla_llvm_enable_invariant_load_metadata),
      flag_values->xla_llvm_enable_invariant_load_metadata(),
      "In LLVM-based backends, enable the emission of !invariant.load metadata "
      "in the generated IR."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_llvm_disable_expensive_passes",
      bool_setter_for(&DebugOptions::set_xla_llvm_disable_expensive_passes),
      flag_values->xla_llvm_disable_expensive_passes(),
      "In LLVM-based backends, disable a custom set of expensive optimization "
      "passes."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_backend_optimization_level",
      int32_setter_for(&DebugOptions::set_xla_backend_optimization_level),
      flag_values->xla_backend_optimization_level(),
      "Numerical optimization level for the XLA compiler backend."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "",
      "Comma-separated list of hlo passes to be disabled. These names must "
      "exactly match the passes' names; no whitespace around commas."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_enable_hlo_passes_only", setter_for_xla_enable_hlo_passes_only, "",
      "Comma-separated list of hlo passes to be enabled. These names must "
      "exactly match the passes' names; no whitespace around commas. The "
      "unspecified passes are all disabled."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_disable_all_hlo_passes",
      bool_setter_for(&DebugOptions::set_xla_disable_all_hlo_passes), false,
      "Disables all HLO passes. Note that some passes are necessary for "
      "correctness and the invariants that must be satisfied by 'fully "
      "optimized' HLO are different for different devices and may change "
      "over time. The only 'guarantee', such as it is, is that if you compile "
      "XLA and dump the optimized HLO for some graph, you should be able to "
      "run it again on the same device with the same build of XLA."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_embed_ir_in_executable",
      bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable),
      flag_values->xla_embed_ir_in_executable(),
      "Embed the compiler IR as a string in the executable."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_eliminate_hlo_implicit_broadcast",
      bool_setter_for(&DebugOptions::set_xla_eliminate_hlo_implicit_broadcast),
      flag_values->xla_eliminate_hlo_implicit_broadcast(),
      "Eliminate implicit broadcasts when lowering user computations to HLO "
      "instructions; use explicit broadcast instead."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_multi_thread_eigen",
      bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen),
      flag_values->xla_cpu_multi_thread_eigen(),
      "When generating calls to Eigen in the CPU backend, use multi-threaded "
      "Eigen mode."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_cuda_data_dir", flag_values->mutable_xla_gpu_cuda_data_dir(),
      "If non-empty, specifies a local directory containing ptxas and nvvm "
      "libdevice files; otherwise we use those from runfile directories."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_ftz", bool_setter_for(&DebugOptions::set_xla_gpu_ftz),
      flag_values->xla_gpu_ftz(),
      "If true, flush-to-zero semantics are enabled in the code generated for "
      "GPUs."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_disable_multi_streaming",
      bool_setter_for(&DebugOptions::set_xla_gpu_disable_multi_streaming),
      flag_values->xla_gpu_disable_multi_streaming(),
      "If true, multi-streaming in the GPU backend is disabled."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_max_kernel_unroll_factor",
      int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor),
      flag_values->xla_gpu_max_kernel_unroll_factor(),
      "Specify the maximum kernel unroll factor for the GPU backend."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_ptx_file", setter_for_xla_gpu_ptx_file, "",
      "If non-empty, specifies a file containing ptx to use. The filename "
      "prefix must have the same pattern as PTX dumped by XLA. This allows "
      "matching one specific module. General workflow: get the generated "
      "module PTX from XLA, modify it, then pass it back via this option."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_llvm_ir_file", setter_for_xla_gpu_llvm_ir_file, "",
      "If non-empty, specifies a file containing textual LLVM IR to use. The "
      "filename prefix must have the same pattern as LLVM dumped by XLA "
      "(i.e. module_0001.ir-no-opt.ll -> module_0001.MY_NEW_FILE.ll). This "
      "allows matching one specific module. General workflow: get the "
      "unoptimized LLVM IR from XLA, modify it, then pass it back via this "
      "option."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_test_all_output_layouts",
      bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),
      flag_values->xla_test_all_output_layouts(),
      "Let ClientLibraryTestBase::ComputeAndCompare* test all permutations of "
      "output layouts. For example, with a 3D shape, all permutations of the "
      "set {0, 1, 2} are tried."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_test_all_input_layouts",
      bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts),
      flag_values->xla_test_all_input_layouts(),
      "Let ClientLibraryTestBase::ComputeAndCompare* test all permutations of "
      "*input* layouts. For example, for 2 input arguments with 2D shape and "
      "4D shape, the computation will run 2! * 4! times for all possible "
      "layouts."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_hlo_profile", bool_setter_for(&DebugOptions::set_xla_hlo_profile),
      flag_values->xla_hlo_profile(),
      "Instrument the computation to collect per-HLO cycle counts"));
  flag_objects->push_back(tensorflow::Flag(
      "xla_backend_extra_options", setter_for_xla_backend_extra_options, "",
      "Extra options to pass to a backend; comma-separated list of 'key=val' "
      "strings (=val may be omitted); no whitespace around commas."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_use_cudnn_batchnorm",
      bool_setter_for(&DebugOptions::set_xla_gpu_use_cudnn_batchnorm),
      flag_values->xla_gpu_use_cudnn_batchnorm(),
      "Allows the GPU backend to implement batchnorm HLOs using cudnn, rather "
      "than expanding them to a soup of HLOs."));
  flag_objects->push_back(
      tensorflow::Flag("xla_cpu_use_mkl_dnn",
                       bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn),
                       flag_values->xla_cpu_use_mkl_dnn(),
                       "Generate calls to MKL-DNN in the CPU backend."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_crash_on_verification_failures",
      bool_setter_for(
          &DebugOptions::set_xla_gpu_crash_on_verification_failures),
      flag_values->xla_gpu_crash_on_verification_failures(),
      "Crashes the program on extra verification failures, e.g. cuDNN cross "
      "checking failures"));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_autotune_level",
      int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level),
      flag_values->xla_gpu_autotune_level(),
      "Set GEMM and Convolution auto-tuning level. 0 = off; 1 = on; 2 = "
      "on+init; 3 = on+init+reinit; 4 = on+init+reinit+check."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_force_host_platform_device_count",
      int32_setter_for(&DebugOptions::set_xla_force_host_platform_device_count),
      flag_values->xla_force_host_platform_device_count(),
      "Force the host platform to pretend that there are this many host "
      "\"devices\". All of these host devices are backed by the same "
      "threadpool. Setting this to anything other than 1 can increase overhead "
      "from context switching but we let the user override this behavior to "
      "help run tests on the host that run models in parallel across multiple "
      "devices."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_disable_gpuasm_optimizations",
      bool_setter_for(&DebugOptions::set_xla_gpu_disable_gpuasm_optimizations),
      flag_values->xla_gpu_disable_gpuasm_optimizations(),
      "In XLA:GPU run ptxas in -O0 (default is -O3)."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_asm_extra_flags",
      string_setter_for(&DebugOptions::set_xla_gpu_asm_extra_flags), "",
      "Pass extra parameters to the GPU assembler tool (i.e., ptxas for CUDA). "
      "If passing multiple parameters, separate them with commas."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_fuel", setter_for_xla_fuel, /*default_value_for_display=*/"",
      "Sets compiler fuel, useful for bisecting bugs in passes. Format "
      "--xla_fuel=PASS1=NUM1,PASS2=NUM2,..."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_to", string_setter_for(&DebugOptions::set_xla_dump_to),
      flag_values->xla_dump_to(),
      "Directory into which debugging data is written. If not specified but "
      "another dumping flag is passed, data will be written to stdout. To "
      "explicitly write to stdout, set this to \"-\". The values \"sponge\" "
      "and \"test_undeclared_outputs_dir\" have a special meaning: They cause "
      "us to dump into the directory specified by the environment variable "
      "TEST_UNDECLARED_OUTPUTS_DIR."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_as_text",
      bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_text),
      flag_values->xla_dump_hlo_as_text(),
      "Dumps HLO modules as text before and after optimizations. Results are "
      "written to the --xla_dump_to dir, or, if no dir is specified, to "
      "stdout."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_as_proto",
      bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_proto),
      flag_values->xla_dump_hlo_as_proto(),
      "Dumps HLO modules as HloProtos to the directory specified by "
      "--xla_dump_to."));
  flag_objects->push_back(
      tensorflow::Flag("xla_dump_hlo_as_dot",
                       bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_dot),
                       flag_values->xla_dump_hlo_as_dot(),
                       "Dumps HLO modules rendered as dot files to the "
                       "directory specified by --xla_dump_to."));
  flag_objects->push_back(
      tensorflow::Flag("xla_dump_hlo_as_html",
                       bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_html),
                       flag_values->xla_dump_hlo_as_html(),
                       "Dumps HLO modules rendered as HTML files to the "
                       "directory specified by --xla_dump_to."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_as_url",
      bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_url),
      flag_values->xla_dump_hlo_as_url(),
      "Tries to dump HLO modules rendered as URLs to stdout (and also to the "
      "directory specified by --xla_dump_to). This is not implemented by "
      "default; you need to add a plugin which calls "
      "RegisterGraphToURLRenderer()."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_fusion_visualization",
      bool_setter_for(&DebugOptions::set_xla_dump_fusion_visualization),
      flag_values->xla_dump_fusion_visualization(),
      "Tries to generate HLO fusion visualization as an HTML page written to "
      "the directory specified by --xla_dump_to. This is not implemented by "
      "default; you need to add a plugin which calls "
      "RegisterGraphToURLRenderer(). Generates a file per computation. "
      "Currently only implemented for the GPU backend."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_snapshots",
      bool_setter_for(&DebugOptions::set_xla_dump_hlo_snapshots),
      flag_values->xla_dump_hlo_snapshots(),
      "Every time an HLO module is run, dumps an HloSnapshot to the directory "
      "specified by --xla_dump_to."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_module_re",
      string_setter_for(&DebugOptions::set_xla_dump_hlo_module_re),
      flag_values->xla_dump_hlo_module_re(),
      "Limits dumping only to modules which match this regular expression. "
      "Default is to dump all modules."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_pass_re",
      string_setter_for(&DebugOptions::set_xla_dump_hlo_pass_re),
      flag_values->xla_dump_hlo_pass_re(),
      "If specified, dumps HLO before and after optimization passes which "
      "match this regular expression, in addition to dumping at the very "
      "beginning and end of compilation."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_include_timestamp",
      bool_setter_for(&DebugOptions::set_xla_dump_include_timestamp),
      flag_values->xla_dump_include_timestamp(),
      "If specified, includes a timestamp in the dumped filenames."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_max_hlo_modules",
      int32_setter_for(&DebugOptions::set_xla_dump_max_hlo_modules),
      flag_values->xla_dump_max_hlo_modules(),
      "Max number of hlo module dumps in a directory. Set to < 0 for "
      "unbounded."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_module_metadata",
      bool_setter_for(&DebugOptions::set_xla_dump_module_metadata),
      flag_values->xla_dump_module_metadata(),
      "Dumps HloModuleMetadata as text protos to the directory specified "
      "by --xla_dump_to."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_compress_protos",
      bool_setter_for(&DebugOptions::set_xla_dump_compress_protos),
      flag_values->xla_dump_compress_protos(),
      "Gzip-compress protos dumped by --xla_dump_hlo_as_proto."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_hlo_graph_addresses",
      bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses),
      flag_values->xla_hlo_graph_addresses(),
      "When rendering graphs (--xla_dump_hlo_as_{dot,html,url}), displays "
      "the address in memory of each HloInstruction object."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_hlo_graph_sharding_color",
      bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color),
      flag_values->xla_hlo_graph_sharding_color(),
      "Assign colors based on sharding assignments when generating the HLO "
      "graphs."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_allow_excess_precision",
      bool_setter_for(&DebugOptions::set_xla_allow_excess_precision),
      flag_values->xla_allow_excess_precision(),
      "Allow xla to increase the output precision of an instruction."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_force_conv_nchw",
      bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nchw),
      flag_values->xla_gpu_force_conv_nchw(),
      "For cuDNN convolutions, always use NCHW layouts."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_force_conv_nhwc",
      bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nhwc),
      flag_values->xla_gpu_force_conv_nhwc(),
      "For cuDNN convolutions, always use NHWC layouts."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_algorithm_denylist_path",
      string_setter_for(&DebugOptions::set_xla_gpu_algorithm_denylist_path),
      flag_values->xla_gpu_algorithm_denylist_path(),
      "An AlgorithmDenylist text proto file, used as a denylist of "
      "convolution algorithms to avoid."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_deterministic_reductions",
      bool_setter_for(&DebugOptions::set_xla_gpu_deterministic_reductions),
      flag_values->xla_gpu_deterministic_reductions(),
      "Always run deterministic reductions on GPU"));
  flag_objects->push_back(tensorflow::Flag(
      "xla_tpu_detect_nan",
      bool_setter_for(&DebugOptions::set_xla_tpu_detect_nan),
      flag_values->xla_tpu_detect_nan(),
      "Trigger an error during execution on TPU if a NaN value is detected."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_tpu_detect_inf",
      bool_setter_for(&DebugOptions::set_xla_tpu_detect_inf),
      flag_values->xla_tpu_detect_inf(),
      "Trigger an error during execution on TPU if an Inf value is "
      "detected."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_enable_xprof_traceme",
      bool_setter_for(&DebugOptions::set_xla_cpu_enable_xprof_traceme),
      flag_values->xla_cpu_enable_xprof_traceme(),
      "If true, XLA CPU generates code to call "
      "TraceMe::Activity{Start|End} around HLO operations."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found",
      bool_setter_for(
          &DebugOptions::
              set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found),
      flag_values->xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found(),
      "If true, XLA GPU falls back to the driver if ptxas is not found. Note "
      "that falling back to the driver can have drawbacks like using more "
      "memory and/or other bugs during compilation, so we recommend setting "
      "this flag to false."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_multiheap_size_constraint_per_heap",
      int32_setter_for(
          &DebugOptions::set_xla_multiheap_size_constraint_per_heap),
      flag_values->xla_multiheap_size_constraint_per_heap(),
      "Generates multiple heaps (i.e., temp buffers) with a size "
      "constraint on each heap to avoid Out-of-Memory due to memory "
      "fragmentation. The constraint is soft, so it works with tensors "
      "larger than the given constraint size. -1 corresponds to no "
      "constraints."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_force_compilation_parallelism",
      int32_setter_for(
          &DebugOptions::set_xla_gpu_force_compilation_parallelism),
      flag_values->xla_gpu_force_compilation_parallelism(),
      "Overrides the normal multi-threaded compilation setting to use this "
      "many threads. Setting to 0 (the default value) means no enforcement."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_deterministic_ops",
      bool_setter_for(&DebugOptions::set_xla_gpu_deterministic_ops),
      flag_values->xla_gpu_deterministic_ops(),
      "Guarantees run-to-run determinism on GPU."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_enable_async_all_reduce",
      bool_setter_for(&DebugOptions::set_xla_gpu_enable_async_all_reduce),
      flag_values->xla_gpu_enable_async_all_reduce(),
      "Converts synchronous all-reduce ops into asynchronous ones."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_disable_metadata",
      bool_setter_for(&DebugOptions::set_xla_dump_disable_metadata),
      flag_values->xla_dump_disable_metadata(),
      "Disable dumping HLO metadata in HLO dumps."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_pipeline_re",
      string_setter_for(&DebugOptions::set_xla_dump_hlo_pipeline_re),
      flag_values->xla_dump_hlo_pipeline_re(),
      "If specified, dumps HLO before and after optimization passes in the "
      "pass pipelines that match this regular expression."));
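
  // These flags are read from the XLA_FLAGS environment variable, e.g. (the
  // flag values shown are purely illustrative):
  //   XLA_FLAGS="--xla_dump_to=/tmp/xla_dump --xla_gpu_autotune_level=2"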
  ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects);
}  // NOLINT(readability/fn_size)

void AppendDebugOptionsFlags(std::vector<tensorflow::Flag>* flag_list) {
  absl::call_once(flags_init, &AllocateFlags);
  flag_list->insert(flag_list->end(), flag_objects->begin(),
                    flag_objects->end());
}

xla::DebugOptions GetDebugOptionsFromFlags() {
  absl::call_once(flags_init, &AllocateFlags);
  return *flag_values;
}

void ResetThreadLocalFuel() {
  absl::call_once(flags_init, &AllocateFlags);

  thread_fuel.reset(new absl::node_hash_map<string, std::atomic<int64>>());
  CHECK(initial_fuel != nullptr);
  for (const auto& kv : *initial_fuel) {
    thread_fuel->emplace(kv.first, kv.second);
  }
}
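
// Consumes one unit of compiler fuel for `pass` and returns true if any fuel
// remained before this call. Uses the thread-local pool if
// ResetThreadLocalFuel() has been called on the current thread, otherwise the
// global pool; passes without an --xla_fuel entry effectively have unlimited
// fuel. If `just_ran_out` is non-null, it is set to true only on the first
// call that fails because the pass's fuel has just been exhausted.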
bool ConsumeFuel(absl::string_view pass, bool* just_ran_out) {
  absl::call_once(flags_init, &AllocateFlags);
  if (just_ran_out != nullptr) {
    *just_ran_out = false;
  }
  auto* fuel_pool = thread_fuel ? thread_fuel.get() : global_fuel;
  if (fuel_pool->empty()) {
    return true;
  }
  auto it = fuel_pool->find(pass);
  if (it == fuel_pool->end()) {
    return true;
  }
  std::atomic<int64>& remaining_fuel = it->second;
  std::atomic<bool>& fuel_has_been_consumed = fuel_ever_consumed->at(pass);
  fuel_has_been_consumed = true;

  int64_t remaining = remaining_fuel.fetch_sub(1);
  if (just_ran_out != nullptr) {
    *just_ran_out = remaining == 0;
  }
  return remaining > 0;
}
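
// A minimal sketch of a hypothetical call site: an HLO pass can guard each
// rewrite with ConsumeFuel so that its rewrites can be bisected via --xla_fuel
// (the pass name "my-pass" below is illustrative only):
//
//   bool just_ran_out = false;
//   if (ConsumeFuel("my-pass", &just_ran_out)) {
//     // ... apply one rewrite ...
//   } else if (just_ran_out) {
//     VLOG(1) << "Fuel for my-pass exhausted; skipping remaining rewrites.";
//   }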

}  // namespace xla