1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/flags.h"
17
18 #include <mutex> // NOLINT
19 #include <vector>
20
21 #include "absl/base/call_once.h"
22 #include "absl/strings/numbers.h"
23 #include "absl/strings/str_split.h"
24 #include "absl/strings/strip.h"
25 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h"
26 #include "tensorflow/compiler/xla/parse_flags_from_env.h"
27 #include "tensorflow/core/platform/macros.h"
28 #include "tensorflow/core/util/command_line_flags.h"
29
30 namespace tensorflow {
31 namespace {
32
33 BuildXlaOpsPassFlags* build_ops_flags;
34 MarkForCompilationPassFlags* mark_for_compilation_flags;
35 XlaDeviceFlags* device_flags;
36 XlaOpsCommonFlags* ops_flags;
37 IntroduceFloatingPointJitterPassFlags* jitter_flags;
38 MlirCommonFlags* mlir_flags;
39 JitRtFlags* jitrt_flags;
40 std::vector<Flag>* jitrt_flag_list;
41
42 std::vector<Flag>* flag_list;
43 absl::once_flag flags_init;
44
SetterForXlaAutoJitFlag(const string & value)45 bool SetterForXlaAutoJitFlag(const string& value) {
46 int32_t opt_level;
47 // We need to use the mark_for_compilation_flags directly here instead of
48 // going via GetMarkForCompilationPassFlags() to avoid infinite recursion. The
49 // latter will try to setup and parse flags, which would bring us back to this
50 // setter.
51 if (absl::SimpleAtoi(value, &opt_level)) {
52 mark_for_compilation_flags->xla_auto_jit_flag
53 .optimization_level_single_gpu = opt_level;
54 mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general =
55 opt_level;
56 return true;
57 }
58
59 if (value == "fusible") {
60 mark_for_compilation_flags->xla_auto_jit_flag
61 .optimization_level_single_gpu = 1;
62 mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general =
63 1;
64 mark_for_compilation_flags->tf_xla_ops_to_cluster = "FUSIBLE";
65 return true;
66 }
67
68 absl::string_view value_sv(value);
69 if (!absl::ConsumePrefix(&value_sv, "single-gpu(") ||
70 !absl::ConsumeSuffix(&value_sv, ")") ||
71 !absl::SimpleAtoi(value_sv, &opt_level)) {
72 return false;
73 }
74
75 mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu =
76 opt_level;
77 return true;
78 }
79
AppendMarkForCompilationPassFlagsInternal(std::vector<Flag> * flag_list)80 void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
81 std::vector<Flag> new_flags = {
82 Flag("tf_xla_auto_jit", SetterForXlaAutoJitFlag, "0",
83 "Control compilation of operators into XLA computations on CPU and "
84 "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
85 "things very likely to be improved; 2 = on for everything; "
86 "(experimental) fusible = only for Tensorflow operations that XLA "
87 "knows how to fuse. "
88 "If set to single-gpu(<N>) then this resolves to <N> for single-GPU "
89 "graphs (graphs that have at least one node placed on a GPU and no "
90 "more than one GPU is in use through the entire graph) and 0 "
91 "otherwise. Experimental."),
92 Flag("tf_xla_min_cluster_size",
93 &mark_for_compilation_flags->tf_xla_min_cluster_size,
94 "Minimum number of operators in an XLA compilation. Ignored for "
95 "operators placed on an XLA device or operators explicitly marked "
96 "for compilation."),
97 Flag("tf_xla_max_cluster_size",
98 &mark_for_compilation_flags->tf_xla_max_cluster_size,
99 "Maximum number of operators in an XLA compilation."),
100 Flag(
101 "tf_xla_ops_to_cluster",
102 &mark_for_compilation_flags->tf_xla_ops_to_cluster,
103 "(experimental) "
104 "Limit the operations clustered by XLA to these operations. "
105 "If multiple, separate them with commas. Shortcuts: "
106 " PW: All point-wise operations."
107 " RED: All reduction operations."
108 " MISC: Mixed operations."
109 " PWRED: TF operations that get converted to PW+RED operation in XLA."
110 " REDUCEWINDOW: TF operations like MaxPool/AvgPool that get "
111 "converted to ReduceWindow in XLA."
112 " REDUCEWINDOWPW: Operation that get converted to ReduceWindow + PW "
113 "(LRN, LRNGrad)."
114 " BN: TF FusedBatchNorm* operations."
115 " FUSIBLE: All TF operations that XLA can fuse (All the above). "
116 "You can also put any TF operation name, e.g. 'FUSIBLE,MatMul'."),
117 Flag("tf_xla_cluster_exclude_ops",
118 &mark_for_compilation_flags->tf_xla_cluster_exclude_ops,
119 "(experimental) "
120 "Exclude the operations from auto-clustering. "
121 "If multiple, separate them with commas."
122 " Where, Some_other_ops"),
123 Flag("tf_xla_clustering_debug",
124 &mark_for_compilation_flags->tf_xla_clustering_debug,
125 "Dump graphs during XLA compilation."),
126 Flag("tf_xla_cpu_global_jit",
127 &mark_for_compilation_flags->tf_xla_cpu_global_jit,
128 "Enables global JIT compilation for CPU via SessionOptions."),
129 Flag("tf_xla_clustering_fuel",
130 &mark_for_compilation_flags->tf_xla_clustering_fuel,
131 "Places an artificial limit on the number of ops marked as "
132 "eligible for clustering."),
133 Flag("tf_xla_disable_deadness_safety_checks_for_debugging",
134 &mark_for_compilation_flags
135 ->tf_xla_disable_deadness_safety_checks_for_debugging,
136 "Disable deadness related safety checks when clustering (this is "
137 "unsound)."),
138 Flag("tf_xla_disable_resource_variable_safety_checks_for_debugging",
139 &mark_for_compilation_flags
140 ->tf_xla_disable_resource_variable_safety_checks_for_debugging,
141 "Disable resource variables related safety checks when clustering "
142 "(this is unsound)."),
143 Flag("tf_xla_deterministic_cluster_names",
144 &mark_for_compilation_flags->tf_xla_deterministic_cluster_names,
145 "Causes the function names assigned by auto clustering to be "
146 "deterministic from run to run."),
147 Flag("tf_xla_persistent_cache_directory",
148 &mark_for_compilation_flags->tf_xla_persistent_cache_directory,
149 "If non-empty, JIT-compiled executables are saved to and loaded "
150 "from the specified file system directory path. Empty by default."),
151 Flag("tf_xla_disable_strict_signature_checks",
152 &mark_for_compilation_flags->tf_xla_disable_strict_signature_checks,
153 "If true, entires loaded into the XLA compile cache will not have "
154 "their signatures checked strictly. Defaults to false."),
155 Flag("tf_xla_persistent_cache_prefix",
156 &mark_for_compilation_flags->tf_xla_persistent_cache_prefix,
157 "Specifies the persistance cache prefix. Default is "
158 "\"xla_compile_cache\"")};
159 flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end());
160 }
161
AllocateAndParseJitRtFlags()162 void AllocateAndParseJitRtFlags() {
163 jitrt_flags = new JitRtFlags;
164 jitrt_flags->always_specialize = false;
165 jitrt_flags->cost_driven_async_parallel_for = false;
166 jitrt_flags->log_query_of_death = false;
167 jitrt_flags->vectorize = false;
168 jitrt_flags->enable_crash_reproducer = false;
169 jitrt_flag_list = new std::vector<Flag>({
170 Flag("always_specialize", &jitrt_flags->always_specialize, ""),
171 Flag("cost_driven_async_parallel_for",
172 &jitrt_flags->cost_driven_async_parallel_for, ""),
173 Flag("log_query_of_death", &jitrt_flags->log_query_of_death, ""),
174 Flag("vectorize", &jitrt_flags->vectorize, ""),
175 Flag("enable_crash_reproducer", &jitrt_flags->enable_crash_reproducer,
176 ""),
177 });
178 xla::ParseFlagsFromEnvAndDieIfUnknown("TF_JITRT_FLAGS", *jitrt_flag_list);
179 }
180
AllocateAndParseFlags()181 void AllocateAndParseFlags() {
182 build_ops_flags = new BuildXlaOpsPassFlags;
183 build_ops_flags->tf_xla_enable_lazy_compilation = true;
184 build_ops_flags->tf_xla_print_cluster_outputs = false;
185 build_ops_flags->tf_xla_check_cluster_input_numerics = false;
186 build_ops_flags->tf_xla_check_cluster_output_numerics = false;
187 build_ops_flags->tf_xla_disable_constant_folding = false;
188
189 mark_for_compilation_flags = new MarkForCompilationPassFlags;
190 mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu =
191 0;
192 mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general = 0;
193 mark_for_compilation_flags->tf_xla_min_cluster_size = 4;
194 mark_for_compilation_flags->tf_xla_max_cluster_size =
195 std::numeric_limits<int32>::max();
196 mark_for_compilation_flags->tf_xla_clustering_debug = false;
197 mark_for_compilation_flags->tf_xla_cpu_global_jit = false;
198 mark_for_compilation_flags->tf_xla_clustering_fuel =
199 std::numeric_limits<int64_t>::max();
200 mark_for_compilation_flags
201 ->tf_xla_disable_deadness_safety_checks_for_debugging = false;
202 mark_for_compilation_flags
203 ->tf_xla_disable_resource_variable_safety_checks_for_debugging = false;
204 mark_for_compilation_flags->tf_xla_deterministic_cluster_names = false;
205 mark_for_compilation_flags->tf_xla_persistent_cache_directory = "";
206 mark_for_compilation_flags->tf_xla_disable_strict_signature_checks = false;
207 mark_for_compilation_flags->tf_xla_persistent_cache_prefix =
208 "xla_compile_cache";
209
210 device_flags = new XlaDeviceFlags;
211 device_flags->tf_xla_compile_on_demand = false;
212 device_flags->tf_xla_enable_xla_devices = false;
213
214 ops_flags = new XlaOpsCommonFlags;
215 ops_flags->tf_xla_always_defer_compilation = false;
216 ops_flags->tf_xla_async_compilation = false;
217
218 jitter_flags = new IntroduceFloatingPointJitterPassFlags;
219 jitter_flags->jitter_amount = 1e-5;
220
221 // The `enable_mlir_bridge` flag allows the user to explicitly request that
222 // their program is (or isn't) compiled using the MLIR-based TF-to-XLA bridge.
223 //
224 // The `enable_mlir_bridge_is_explicit` variable tracks whether or not the
225 // user has made an explicit request. That is, if this variable is set to
226 // true, the program honors the user's request as per `enable_mlir_bridge`; if
227 // it's set to false, the default behavior is used (which may run either
228 // bridge, on a per-graph basis).
229 bool enable_mlir_bridge = false;
230 bool enable_mlir_bridge_is_explicit = false;
231 bool mlir_bridge_safe_mode = false;
232 bool enable_mlir_merge_control_flow_pass = true;
233 bool enable_mlir_convert_control_to_data_outputs_pass = false;
234 auto setter_for_jitter_tensor_names = [](string sequence) {
235 jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
236 return true;
237 };
238 // Dump graphs in TFG dialect.
239 bool use_tfg_graph_dumper = false;
240
241 flag_list = new std::vector<Flag>(
242 {Flag("tf_xla_enable_lazy_compilation",
243 &build_ops_flags->tf_xla_enable_lazy_compilation, ""),
244 Flag("tf_xla_print_cluster_outputs",
245 &build_ops_flags->tf_xla_print_cluster_outputs,
246 "If true then insert Print nodes to print out values produced by "
247 "XLA clusters."),
248 Flag("tf_xla_check_cluster_input_numerics",
249 &build_ops_flags->tf_xla_check_cluster_input_numerics,
250 "If true then insert CheckNumerics nodes to check all cluster "
251 "inputs."),
252 Flag("tf_xla_check_cluster_output_numerics",
253 &build_ops_flags->tf_xla_check_cluster_output_numerics,
254 "If true then insert CheckNumerics nodes to check all cluster "
255 "outputs."),
256 Flag("tf_xla_disable_constant_folding",
257 &build_ops_flags->tf_xla_disable_constant_folding,
258 "If true then disables constant folding on TF graph before XLA "
259 "compilation."),
260
261 Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
262 "Switch a device into 'on-demand' mode, where instead of "
263 "autoclustering ops are compiled one by one just-in-time."),
264
265 Flag("tf_xla_enable_xla_devices",
266 &device_flags->tf_xla_enable_xla_devices,
267 "Generate XLA_* devices, where placing a computation on such a "
268 "device"
269 "forces compilation by XLA. Deprecated."),
270
271 Flag("tf_xla_always_defer_compilation",
272 &ops_flags->tf_xla_always_defer_compilation, ""),
273 Flag("tf_xla_async_compilation", &ops_flags->tf_xla_async_compilation,
274 "When lazy compilation is enabled, asynchronous compilation starts "
275 "the cluster compilation in the background, and the fallback path "
276 "is executed until the compilation has finished."),
277
278 Flag("tf_introduce_floating_point_jitter_to_tensors",
279 setter_for_jitter_tensor_names, "",
280 "The Tensors to add the jitter to. The tensors are named in the "
281 "TensorId format of <node name>:<output idx>."),
282 Flag("tf_introduce_floating_point_jitter_amount",
283 &jitter_flags->jitter_amount,
284 "The amount of jitter to introduce. This amount is added to each "
285 "element in the tensors named in `tensor_names."),
286
287 Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
288 "Enables experimental MLIR-Based TensorFlow Compiler Bridge.",
289 &enable_mlir_bridge_is_explicit),
290 Flag("tf_mlir_enable_merge_control_flow_pass",
291 &enable_mlir_merge_control_flow_pass,
292 "Enables MergeControlFlow pass for MLIR-Based TensorFlow Compiler "
293 "Bridge."),
294 Flag("tf_mlir_enable_convert_control_to_data_outputs_pass",
295 &enable_mlir_convert_control_to_data_outputs_pass,
296 "Enables `tf-executor-convert-control-to-data-outputs` pass for "
297 "MLIR-Based TensorFlow Compiler Bridge."),
298 Flag(
299 "tf_mlir_bridge_safe_mode", &mlir_bridge_safe_mode,
300 "When tf_mlir_enable_mlir_bridge is true, this field can enable "
301 "the MLIR bridge's safe mode. When the MLIR bridge is in safe mode, "
302 "it only runs for graphs that use features MLIR bridge currently "
303 "supports."),
304 Flag("tf_dump_graphs_in_tfg", &use_tfg_graph_dumper,
305 "When tf_dump_graphs_in_tfg is true, graphs after transformations "
306 "are dumped in MLIR TFG dialect and not in GraphDef")});
307
308 AppendMarkForCompilationPassFlagsInternal(flag_list);
309 xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
310
311 mlir_flags = new MlirCommonFlags;
312 if (!enable_mlir_bridge_is_explicit) {
313 mlir_flags->tf_mlir_enable_mlir_bridge =
314 (mlir_bridge_safe_mode)
315 ? ConfigProto::Experimental::
316 MLIR_BRIDGE_ROLLOUT_SAFE_MODE_FALLBACK_ENABLED
317 : ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
318 } else if (enable_mlir_bridge) {
319 mlir_flags->tf_mlir_enable_mlir_bridge =
320 (mlir_bridge_safe_mode)
321 ? ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED
322 : ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
323 } else {
324 mlir_flags->tf_mlir_enable_mlir_bridge =
325 ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
326 }
327 mlir_flags->tf_mlir_enable_merge_control_flow_pass =
328 enable_mlir_merge_control_flow_pass;
329 mlir_flags->tf_mlir_enable_convert_control_to_data_outputs_pass =
330 enable_mlir_convert_control_to_data_outputs_pass;
331
332 if (use_tfg_graph_dumper) {
333 UseMlirForGraphDump(MlirDumpConfig{}.elide_large_attributes().emit_dialect(
334 MlirDumpConfig::Dialect::kTFG));
335 }
336
337 AllocateAndParseJitRtFlags();
338 }
339
ResetFlags()340 void ResetFlags() {
341 delete build_ops_flags;
342 delete mark_for_compilation_flags;
343 delete device_flags;
344 delete ops_flags;
345 delete jitter_flags;
346 delete mlir_flags;
347 delete flag_list;
348 delete jitrt_flags;
349 delete jitrt_flag_list;
350 AllocateAndParseFlags();
351 }
352
353 } // namespace
354
SetXlaAutoJitFlagFromFlagString(const string & value)355 bool SetXlaAutoJitFlagFromFlagString(const string& value) {
356 absl::call_once(flags_init, &AllocateAndParseFlags);
357 return SetterForXlaAutoJitFlag(value);
358 }
359
GetBuildXlaOpsPassFlags()360 BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags() {
361 absl::call_once(flags_init, &AllocateAndParseFlags);
362 return build_ops_flags;
363 }
364
GetMarkForCompilationPassFlags()365 MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
366 absl::call_once(flags_init, &AllocateAndParseFlags);
367 return mark_for_compilation_flags;
368 }
369
GetXlaDeviceFlags()370 XlaDeviceFlags* GetXlaDeviceFlags() {
371 absl::call_once(flags_init, &AllocateAndParseFlags);
372 return device_flags;
373 }
374
GetXlaOpsCommonFlags()375 const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
376 absl::call_once(flags_init, &AllocateAndParseFlags);
377 return *ops_flags;
378 }
379
380 const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags()381 GetIntroduceFloatingPointJitterPassFlags() {
382 absl::call_once(flags_init, &AllocateAndParseFlags);
383 return *jitter_flags;
384 }
385
GetMlirCommonFlags()386 MlirCommonFlags* GetMlirCommonFlags() {
387 absl::call_once(flags_init, &AllocateAndParseFlags);
388 return mlir_flags;
389 }
390
ResetJitCompilerFlags()391 void ResetJitCompilerFlags() { ResetFlags(); }
392
GetJitRtFlags()393 const JitRtFlags& GetJitRtFlags() {
394 absl::call_once(flags_init, &AllocateAndParseFlags);
395 return *jitrt_flags;
396 }
397
GetMlirBridgeRolloutState(std::optional<const ConfigProto> config_proto)398 ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState(
399 std::optional<const ConfigProto> config_proto) {
400 // TF1 graphs that do not override Sessions's ConfigProto and TF2 graphs
401 // can enable/disable the graph via tf_mlir_enable_mlir_bridge.
402 auto tf_mlir_enable_mlir_bridge =
403 GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
404 if (tf_mlir_enable_mlir_bridge !=
405 ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED) {
406 return tf_mlir_enable_mlir_bridge;
407 }
408
409 // If a ConfigProto was not passed in, we can assume the caller is
410 // checking if TF2 graph should have the bridge enabled / disabled. In that
411 // case, we have already checked tf_mlir_enable_mlir_bridge so it is safe to
412 // return UNSPECIFIED here.
413 if (!config_proto.has_value()) {
414 return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
415 }
416
417 // TF1 graphs that do override Session's ConfigProto and set
418 // ConfigProto's enable_mlir_bridge or mlir_bridge_rollout fields will not
419 // update tf_mlir_enable_mlir_bridge so check their values.
420
421 // ConfigProto's enable_mlir_bridge defaults to false so only respect it
422 // when it is true.
423 if (config_proto.value().experimental().enable_mlir_bridge()) {
424 return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
425 }
426 return config_proto.value().experimental().mlir_bridge_rollout();
427 }
428
AppendMarkForCompilationPassFlags(std::vector<Flag> * flag_list)429 void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
430 absl::call_once(flags_init, &AllocateAndParseFlags);
431 AppendMarkForCompilationPassFlagsInternal(flag_list);
432 }
433
// Process-wide switch, initially false. Once set it is never cleared by any
// code in this block.
static std::atomic<bool> xla_compilation_disabled{false};

// Turns the switch on. Calling it more than once has no further effect.
void DisableXlaCompilation() {
  xla_compilation_disabled.store(true);
}

// Reports whether DisableXlaCompilation() has been called.
bool FailOnXlaCompilation() {
  return xla_compilation_disabled.load();
}
439
440 } // namespace tensorflow
441