1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/debug_options_flags.h"
17
18 #include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
19 #include <vector>
20 #include "absl/strings/str_split.h"
21 #include "tensorflow/compiler/xla/debug_options_parsers.h"
22 #include "tensorflow/compiler/xla/parse_flags_from_env.h"
23
24 namespace xla {
25
DefaultDebugOptionsIgnoringFlags()26 DebugOptions DefaultDebugOptionsIgnoringFlags() {
27 DebugOptions opts;
28 opts.set_xla_llvm_enable_alias_scope_metadata(true);
29 opts.set_xla_llvm_enable_noalias_metadata(true);
30 opts.set_xla_llvm_enable_invariant_load_metadata(true);
31 opts.set_xla_llvm_disable_expensive_passes(false);
32 opts.set_xla_backend_optimization_level(3);
33 opts.set_xla_cpu_multi_thread_eigen(true);
34 opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
35 opts.set_xla_eliminate_hlo_implicit_broadcast(true);
36 opts.set_xla_dump_hlo_as_html(false);
37 #ifdef INTEL_MKL
38 opts.set_xla_cpu_use_mkl_dnn(true);
39 #endif // INTEL_MKL
40 opts.set_xla_gpu_max_kernel_unroll_factor(4);
41 // Set cudnn batchnorm off by default; it does not provide a performance win
42 // on average.
43 opts.set_xla_gpu_use_cudnn_batchnorm(false);
44
45 // Run all GPU work on one stream by default. Using multiple streams
46 // increases memory usage and we lack strong motivating benchmarks for tuning
47 // the heuristics needed to decide when to run on multiple streams. See
48 // b/77879207.
49 opts.set_xla_gpu_disable_multi_streaming(true);
50
51 // TODO(jlebar): Disable fastmath once doing so is not a performance
52 // regression.
53 opts.set_xla_cpu_enable_fast_math(true);
54 opts.set_xla_gpu_enable_fast_min_max(true);
55
56 opts.set_xla_force_host_platform_device_count(1);
57 return opts;
58 }
59
// Process-wide storage for the parsed debug options and the flag descriptors.
// Both are allocated exactly once by AllocateFlags() (guarded by flags_init)
// and never freed, since they must outlive all readers of the debug options.
static DebugOptions* flag_values;
static std::vector<tensorflow::Flag>* flag_objects;
static std::once_flag flags_init;
63
64 // Allocates flag_values and flag_objects; this function must not be called more
65 // than once - its call done via call_once.
AllocateFlags()66 static void AllocateFlags() {
67 flag_values = new DebugOptions(DefaultDebugOptionsIgnoringFlags());
68
69 // Returns a lambda that calls "member_setter" on "flag_values" with the
70 // argument passed in to the lambda.
71 auto bool_setter_for = [](void (DebugOptions::*member_setter)(bool)) {
72 return [member_setter](bool value) {
73 (flag_values->*member_setter)(value);
74 return true;
75 };
76 };
77
78 // Returns a lambda that calls "member_setter" on "flag_values" with the
79 // argument passed in to the lambda.
80 auto int32_setter_for = [](void (DebugOptions::*member_setter)(int32)) {
81 return [member_setter](int32 value) {
82 (flag_values->*member_setter)(value);
83 return true;
84 };
85 };
86
87 auto string_setter_for =
88 [](void (DebugOptions::*member_setter)(const string& value)) {
89 return [member_setter](const string& value) {
90 (flag_values->*member_setter)(value);
91 return true;
92 };
93 };
94
95 // Custom "sub-parser" lambda for xla_disable_hlo_passes.
96 auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) {
97 std::vector<string> disabled_passes =
98 absl::StrSplit(comma_separated_values, ',');
99 for (const auto& passname : disabled_passes) {
100 flag_values->add_xla_disable_hlo_passes(passname);
101 }
102 return true;
103 };
104
105 // Custom "sub-parser" lambda for xla_backend_extra_options.
106 auto setter_for_xla_backend_extra_options =
107 [](string comma_separated_values) {
108 auto* extra_options_map =
109 flag_values->mutable_xla_backend_extra_options();
110 parse_xla_backend_extra_options(extra_options_map,
111 comma_separated_values);
112 return true;
113 };
114
115 // Custom "sub-parser" lambda for xla_reduce_precision.
116 auto setter_for_xla_reduce_precision =
117 [](string reduce_precision_option_value) {
118 HloReducePrecisionOptions* option_proto =
119 flag_values->add_hlo_reduce_precision_options();
120 return parse_xla_reduce_precision_option(option_proto,
121 reduce_precision_option_value);
122 };
123
124 flag_objects = new std::vector<tensorflow::Flag>({
125 tensorflow::Flag(
126 "xla_cpu_enable_fast_math",
127 bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math),
128 flag_values->xla_cpu_enable_fast_math(),
129 "Enable unsafe fast-math optimizations in the CPU compiler; "
130 "this may produce faster code at the expense of some accuracy."),
131 tensorflow::Flag(
132 "xla_cpu_fast_math_honor_nans",
133 bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_nans),
134 flag_values->xla_cpu_fast_math_honor_nans(),
135 "When xla_cpu_enable_fast_math is true then this controls whether we "
136 "allow operations to produce NaNs. Ignored when "
137 "xla_cpu_enable_fast_math is false."),
138 tensorflow::Flag(
139 "xla_cpu_fast_math_honor_infs",
140 bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_infs),
141 flag_values->xla_cpu_fast_math_honor_infs(),
142 "When xla_cpu_enable_fast_math is true then this controls whether we "
143 "allow operations to produce infinites. Ignored when "
144 "xla_cpu_enable_fast_math is false."),
145 tensorflow::Flag(
146 "xla_gpu_enable_fast_min_max",
147 bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max),
148 flag_values->xla_gpu_enable_fast_min_max(),
149 "Enable fast floating point min/max lowering that does not propagate "
150 "NaNs."),
151 tensorflow::Flag(
152 "xla_llvm_enable_alias_scope_metadata",
153 bool_setter_for(
154 &DebugOptions::set_xla_llvm_enable_alias_scope_metadata),
155 flag_values->xla_llvm_enable_alias_scope_metadata(),
156 "In LLVM-based backends, enable the emission of "
157 "!alias.scope metadata in the generated IR."),
158 tensorflow::Flag(
159 "xla_llvm_enable_noalias_metadata",
160 bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata),
161 flag_values->xla_llvm_enable_noalias_metadata(),
162 "In LLVM-based backends, enable the emission of "
163 "!noalias metadata in the generated IR."),
164 tensorflow::Flag(
165 "xla_llvm_enable_invariant_load_metadata",
166 bool_setter_for(
167 &DebugOptions::set_xla_llvm_enable_invariant_load_metadata),
168 flag_values->xla_llvm_enable_invariant_load_metadata(),
169 "In LLVM-based backends, enable the emission of "
170 "!invariant.load metadata in "
171 "the generated IR."),
172 tensorflow::Flag(
173 "xla_llvm_disable_expensive_passes",
174 bool_setter_for(&DebugOptions::set_xla_llvm_disable_expensive_passes),
175 flag_values->xla_llvm_disable_expensive_passes(),
176 "In LLVM-based backends, disable a custom set of "
177 "expensive optimization passes."),
178 tensorflow::Flag(
179 "xla_backend_optimization_level",
180 int32_setter_for(&DebugOptions::set_xla_backend_optimization_level),
181 flag_values->xla_backend_optimization_level(),
182 "Numerical optimization level for the XLA compiler backend."),
183 tensorflow::Flag(
184 "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "",
185 "Comma-separated list of hlo passes to be disabled. These names "
186 "must exactly match the passes' names; no whitespace around "
187 "commas."),
188 tensorflow::Flag(
189 "xla_disable_all_hlo_passes",
190 bool_setter_for(&DebugOptions::set_xla_disable_all_hlo_passes), false,
191 "Disables all HLO passes. Notes that some passes are necessary for "
192 "correctness and the invariants that must be satisfied by 'fully "
193 "optimized' HLO are different for different devices and may change "
194 "over time. The only 'guarantee', such as it is, is that if you "
195 "compile XLA and dump the optimized HLO for some graph, you should "
196 "be able to run it again on the same device with the same build of "
197 "XLA."),
198 tensorflow::Flag(
199 "xla_embed_ir_in_executable",
200 bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable),
201 flag_values->xla_embed_ir_in_executable(),
202 "Embed the compiler IR as a string in the executable."),
203 tensorflow::Flag(
204 "xla_eliminate_hlo_implicit_broadcast",
205 bool_setter_for(
206 &DebugOptions::set_xla_eliminate_hlo_implicit_broadcast),
207 flag_values->xla_eliminate_hlo_implicit_broadcast(),
208 "Eliminate implicit broadcasts when lowering user "
209 "computations to HLO instructions; use explicit "
210 "broadcast instead."),
211 tensorflow::Flag(
212 "xla_cpu_multi_thread_eigen",
213 bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen),
214 flag_values->xla_cpu_multi_thread_eigen(),
215 "When generating calls to Eigen in the CPU backend, "
216 "use multi-threaded Eigen mode."),
217 tensorflow::Flag("xla_gpu_cuda_data_dir",
218 flag_values->mutable_xla_gpu_cuda_data_dir(),
219 "If non-empty, speficies a local directory containing "
220 "ptxas and nvvm libdevice files; otherwise we use "
221 "those from runfile directories."),
222 tensorflow::Flag("xla_gpu_ftz",
223 bool_setter_for(&DebugOptions::set_xla_gpu_ftz),
224 flag_values->xla_gpu_ftz(),
225 "If true, flush-to-zero semantics are enabled in the "
226 "code generated for GPUs."),
227 tensorflow::Flag(
228 "xla_gpu_disable_multi_streaming",
229 bool_setter_for(&DebugOptions::set_xla_gpu_disable_multi_streaming),
230 flag_values->xla_gpu_disable_multi_streaming(),
231 "If true, multi-streaming in the GPU backend is disabled."),
232 tensorflow::Flag(
233 "xla_gpu_max_kernel_unroll_factor",
234 int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor),
235 flag_values->xla_gpu_max_kernel_unroll_factor(),
236 "Specify the maximum kernel unroll factor for the GPU backend."),
237 tensorflow::Flag(
238 "xla_test_all_output_layouts",
239 bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),
240 flag_values->xla_test_all_output_layouts(),
241 "Let ClientLibraryTestBase::ComputeAndCompare* test "
242 "all permutations of output layouts. For example, with "
243 "a 3D shape, all permutations of the set {0, 1, 2} are "
244 "tried."),
245 tensorflow::Flag(
246 "xla_test_all_input_layouts",
247 bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts),
248 flag_values->xla_test_all_input_layouts(),
249 "Let ClientLibraryTestBase::ComputeAndCompare* test "
250 "all permutations of *input* layouts. For example, for "
251 "2 input arguments with 2D shape and 4D shape, the "
252 "computation will run 2! * 4! times for every possible "
253 "layouts"),
254 tensorflow::Flag(
255 "xla_hlo_profile",
256 bool_setter_for(&DebugOptions::set_xla_hlo_profile),
257 flag_values->xla_hlo_profile(),
258 "Instrument the computation to collect per-HLO cycle counts"),
259 tensorflow::Flag("xla_backend_extra_options",
260 setter_for_xla_backend_extra_options, "",
261 "Extra options to pass to a backend; "
262 "comma-separated list of 'key=val' strings (=val "
263 "may be omitted); no whitespace around commas."),
264 tensorflow::Flag("xla_reduce_precision", setter_for_xla_reduce_precision,
265 "",
266 "Directions for adding reduce-precision operations. "
267 "Format is 'LOCATION=E,M:OPS;NAMES' where LOCATION is "
268 "the class of locations in which to insert the "
269 "operations (e.g., 'OP_OUTPUTS'), E and M are the "
270 "exponent and matissa bit counts respectively, and "
271 "OPS and NAMES are comma-separated (no spaces) lists "
272 "of the operation types and names to which to attach "
273 "the reduce-precision operations. The NAMES string "
274 "and its preceding ';' may be omitted. This option "
275 "may be repeated to define multiple sets of added "
276 "reduce-precision operations."),
277 tensorflow::Flag(
278 "xla_gpu_use_cudnn_batchnorm",
279 bool_setter_for(&DebugOptions::set_xla_gpu_use_cudnn_batchnorm),
280 flag_values->xla_gpu_use_cudnn_batchnorm(),
281 "Allows the GPU backend to implement batchnorm HLOs using cudnn, "
282 "rather than expanding them to a soup of HLOs."),
283 tensorflow::Flag("xla_cpu_use_mkl_dnn",
284 bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn),
285 flag_values->xla_cpu_use_mkl_dnn(),
286 "Generate calls to MKL-DNN in the CPU backend."),
287 tensorflow::Flag(
288 "xla_gpu_crash_on_verification_failures",
289 bool_setter_for(
290 &DebugOptions::set_xla_gpu_crash_on_verification_failures),
291 flag_values->xla_gpu_crash_on_verification_failures(),
292 "Crashes the program on extra verification failures, e.g. cuDNN "
293 "cross checking failures"),
294 tensorflow::Flag(
295 "xla_force_host_platform_device_count",
296 int32_setter_for(
297 &DebugOptions::set_xla_force_host_platform_device_count),
298 flag_values->xla_force_host_platform_device_count(),
299 "Force the host platform to pretend that there are these many "
300 "host \"devices\". All of these host devices are backed by the same"
301 "threadpool. Setting this to anything other than 1 can increase "
302 "overhead from context switching but we let the user override this "
303 "behavior to help run tests on the host that run models in parallel "
304 "across multiple devices."),
305 tensorflow::Flag(
306 "xla_gpu_disable_ptxas_optimizations",
307 bool_setter_for(
308 &DebugOptions::set_xla_gpu_disable_ptxas_optimizations),
309 flag_values->xla_gpu_disable_ptxas_optimizations(),
310 "In XLA:GPU run ptxas in -O0 (default is -O3)."),
311
312 tensorflow::Flag(
313 "xla_dump_to", string_setter_for(&DebugOptions::set_xla_dump_to),
314 flag_values->xla_dump_to(),
315 "Directory into which debugging data is written. If not specified "
316 "but another dumping flag is passed, data will be written to stdout. "
317 " To explicitly write to stdout, set this to \"-\". The values "
318 "\"sponge\" and \"test_undeclared_outputs_dir\" have a special "
319 "meaning: They cause us to dump into the directory specified by the "
320 "environment variable TEST_UNDECLARED_OUTPUTS_DIR."),
321 tensorflow::Flag(
322 "xla_dump_hlo_as_text",
323 bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_text),
324 flag_values->xla_dump_hlo_as_text(),
325 "Dumps HLO modules as text before and after optimizations. Results "
326 "are written to the --xla_dump_to dir, or, if no dir is specified, "
327 "to stdout."),
328 tensorflow::Flag(
329 "xla_dump_hlo_as_proto",
330 bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_proto),
331 flag_values->xla_dump_hlo_as_proto(),
332 "Dumps HLO modules as HloProtos to the directory specified by "
333 "--xla_dump_to."),
334 tensorflow::Flag(
335 "xla_dump_hlo_as_dot",
336 bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_dot),
337 flag_values->xla_dump_hlo_as_dot(),
338 "Dumps HLO modules rendered as dot files to the directory "
339 "specified by --xla_dump_to."),
340 tensorflow::Flag("xla_dump_hlo_as_html",
341 bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_html),
342 flag_values->xla_dump_hlo_as_html(),
343 "Dumps HLO modules rendered as HTML files to the "
344 "directory specified by --xla_dump_to."),
345 tensorflow::Flag(
346 "xla_dump_hlo_as_url",
347 bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_url),
348 flag_values->xla_dump_hlo_as_url(),
349 "Tries to dump HLO modules rendered as URLs to stdout (and also to "
350 "the directory specified by --xla_dump_to). This is not implemented "
351 "by default; you need to add a plugin which calls "
352 "RegisterGraphToURLRenderer()."),
353 tensorflow::Flag(
354 "xla_dump_hlo_snapshots",
355 bool_setter_for(&DebugOptions::set_xla_dump_hlo_snapshots),
356 flag_values->xla_dump_hlo_snapshots(),
357 "Every time an HLO module is run, dumps an HloSnapshot to the "
358 "directory specified by --xla_dump_to."),
359 tensorflow::Flag(
360 "xla_dump_hlo_module_re",
361 string_setter_for(&DebugOptions::set_xla_dump_hlo_module_re),
362 flag_values->xla_dump_hlo_module_re(),
363 "Limits dumping only to modules which match this regular expression. "
364 " Default is to dump all modules."),
365 tensorflow::Flag(
366 "xla_dump_hlo_pass_re",
367 string_setter_for(&DebugOptions::set_xla_dump_hlo_pass_re),
368 flag_values->xla_dump_hlo_pass_re(),
369 "If specified, dumps HLO before and after optimization passes which "
370 "match this regular expression, in addition to dumping at the very "
371 "beginning and end of compilation."),
372 tensorflow::Flag(
373 "xla_hlo_graph_addresses",
374 bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses),
375 flag_values->xla_hlo_graph_addresses(),
376 "When rendering graphs (--xla_dump_hlo_as_{dot,html,url}), displays "
377 "the address in memory of each HloInstruction object."),
378 tensorflow::Flag(
379 "xla_hlo_graph_sharding_color",
380 bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color),
381 flag_values->xla_hlo_graph_sharding_color(),
382 "Assign colors based on sharding assignments when generating the "
383 "HLO graphs."),
384 });
385 ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects);
386 }
387
AppendDebugOptionsFlags(std::vector<tensorflow::Flag> * flag_list)388 void AppendDebugOptionsFlags(std::vector<tensorflow::Flag>* flag_list) {
389 std::call_once(flags_init, &AllocateFlags);
390 flag_list->insert(flag_list->end(), flag_objects->begin(),
391 flag_objects->end());
392 }
393
GetDebugOptionsFromFlags()394 xla::DebugOptions GetDebugOptionsFromFlags() {
395 std::call_once(flags_init, &AllocateFlags);
396 return *flag_values;
397 }
398
399 } // namespace xla
400