/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/flags.h"

#include <atomic>
#include <cstdint>
#include <limits>
#include <mutex>  // NOLINT
#include <optional>
#include <vector>

#include "absl/base/call_once.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {
namespace {

BuildXlaOpsPassFlags* build_ops_flags;
MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;
MlirCommonFlags* mlir_flags;
JitRtFlags* jitrt_flags;
std::vector<Flag>* jitrt_flag_list;

std::vector<Flag>* flag_list;
absl::once_flag flags_init;

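// Setter for the tf_xla_auto_jit flag. Per the parsing logic below, the
// accepted forms are: a bare integer (e.g. "-1", "0", "1", "2"), the literal
// string "fusible", or "single-gpu(<N>)", e.g. "single-gpu(2)".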
bool SetterForXlaAutoJitFlag(const string& value) {
  int32_t opt_level;
  // We need to use the mark_for_compilation_flags directly here instead of
  // going via GetMarkForCompilationPassFlags() to avoid infinite recursion.
  // The latter will try to set up and parse flags, which would bring us back
  // to this setter.
  if (absl::SimpleAtoi(value, &opt_level)) {
    mark_for_compilation_flags->xla_auto_jit_flag
        .optimization_level_single_gpu = opt_level;
    mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general =
        opt_level;
    return true;
  }

  if (value == "fusible") {
    mark_for_compilation_flags->xla_auto_jit_flag
        .optimization_level_single_gpu = 1;
    mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general =
        1;
    mark_for_compilation_flags->tf_xla_ops_to_cluster = "FUSIBLE";
    return true;
  }

  absl::string_view value_sv(value);
  if (!absl::ConsumePrefix(&value_sv, "single-gpu(") ||
      !absl::ConsumeSuffix(&value_sv, ")") ||
      !absl::SimpleAtoi(value_sv, &opt_level)) {
    return false;
  }

  mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu =
      opt_level;
  return true;
}

void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
  std::vector<Flag> new_flags = {
      Flag("tf_xla_auto_jit", SetterForXlaAutoJitFlag, "0",
           "Control compilation of operators into XLA computations on CPU and "
           "GPU devices.  0 = use ConfigProto setting; -1 = off; 1 = on for "
           "things very likely to be improved; 2 = on for everything; "
           "(experimental) fusible = only for TensorFlow operations that XLA "
           "knows how to fuse.  "
           "If set to single-gpu(<N>) then this resolves to <N> for single-GPU "
           "graphs (graphs that have at least one node placed on a GPU and no "
           "more than one GPU is in use through the entire graph) and 0 "
           "otherwise.  Experimental."),
      Flag("tf_xla_min_cluster_size",
           &mark_for_compilation_flags->tf_xla_min_cluster_size,
           "Minimum number of operators in an XLA compilation. Ignored for "
           "operators placed on an XLA device or operators explicitly marked "
           "for compilation."),
      Flag("tf_xla_max_cluster_size",
           &mark_for_compilation_flags->tf_xla_max_cluster_size,
           "Maximum number of operators in an XLA compilation."),
      Flag(
          "tf_xla_ops_to_cluster",
          &mark_for_compilation_flags->tf_xla_ops_to_cluster,
          "(experimental) "
          "Limit the operations clustered by XLA to these operations. "
          "If multiple, separate them with commas. Shortcuts: "
          " PW: All point-wise operations."
          " RED: All reduction operations."
          " MISC: Mixed operations."
          " PWRED: TF operations that get converted to PW+RED operations in "
          "XLA."
          " REDUCEWINDOW: TF operations like MaxPool/AvgPool that get "
          "converted to ReduceWindow in XLA."
          " REDUCEWINDOWPW: Operations that get converted to ReduceWindow + "
          "PW (LRN, LRNGrad)."
          " BN: TF FusedBatchNorm* operations."
          " FUSIBLE: All TF operations that XLA can fuse (all of the above). "
          "You can also put any TF operation name, e.g. 'FUSIBLE,MatMul'."),
      Flag("tf_xla_cluster_exclude_ops",
           &mark_for_compilation_flags->tf_xla_cluster_exclude_ops,
           "(experimental) "
           "Exclude these operations from auto-clustering. "
           "If multiple, separate them with commas."),
      Flag("tf_xla_clustering_debug",
           &mark_for_compilation_flags->tf_xla_clustering_debug,
           "Dump graphs during XLA compilation."),
      Flag("tf_xla_cpu_global_jit",
           &mark_for_compilation_flags->tf_xla_cpu_global_jit,
           "Enables global JIT compilation for CPU via SessionOptions."),
      Flag("tf_xla_clustering_fuel",
           &mark_for_compilation_flags->tf_xla_clustering_fuel,
           "Places an artificial limit on the number of ops marked as "
           "eligible for clustering."),
      Flag("tf_xla_disable_deadness_safety_checks_for_debugging",
           &mark_for_compilation_flags
                ->tf_xla_disable_deadness_safety_checks_for_debugging,
           "Disable deadness-related safety checks when clustering (this is "
           "unsound)."),
      Flag("tf_xla_disable_resource_variable_safety_checks_for_debugging",
           &mark_for_compilation_flags
                ->tf_xla_disable_resource_variable_safety_checks_for_debugging,
           "Disable resource-variable-related safety checks when clustering "
           "(this is unsound)."),
      Flag("tf_xla_deterministic_cluster_names",
           &mark_for_compilation_flags->tf_xla_deterministic_cluster_names,
           "Causes the function names assigned by auto-clustering to be "
           "deterministic from run to run."),
      Flag("tf_xla_persistent_cache_directory",
           &mark_for_compilation_flags->tf_xla_persistent_cache_directory,
           "If non-empty, JIT-compiled executables are saved to and loaded "
           "from the specified file system directory path. Empty by default."),
      Flag("tf_xla_disable_strict_signature_checks",
           &mark_for_compilation_flags->tf_xla_disable_strict_signature_checks,
           "If true, entries loaded into the XLA compile cache will not have "
           "their signatures checked strictly. Defaults to false."),
      Flag("tf_xla_persistent_cache_prefix",
           &mark_for_compilation_flags->tf_xla_persistent_cache_prefix,
           "Specifies the persistence cache prefix. Default is "
           "\"xla_compile_cache\".")};
  flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end());
}

void AllocateAndParseJitRtFlags() {
  jitrt_flags = new JitRtFlags;
  jitrt_flags->always_specialize = false;
  jitrt_flags->cost_driven_async_parallel_for = false;
  jitrt_flags->log_query_of_death = false;
  jitrt_flags->vectorize = false;
  jitrt_flags->enable_crash_reproducer = false;
  jitrt_flag_list = new std::vector<Flag>({
      Flag("always_specialize", &jitrt_flags->always_specialize, ""),
      Flag("cost_driven_async_parallel_for",
           &jitrt_flags->cost_driven_async_parallel_for, ""),
      Flag("log_query_of_death", &jitrt_flags->log_query_of_death, ""),
      Flag("vectorize", &jitrt_flags->vectorize, ""),
      Flag("enable_crash_reproducer", &jitrt_flags->enable_crash_reproducer,
           ""),
  });
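  // Illustrative usage (an assumption about typical invocation, not exercised
  // in this file): the flags above are read from the TF_JITRT_FLAGS
  // environment variable, e.g.
  //   TF_JITRT_FLAGS="--vectorize=true --log_query_of_death=true" ./my_binary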
  xla::ParseFlagsFromEnvAndDieIfUnknown("TF_JITRT_FLAGS", *jitrt_flag_list);
}

void AllocateAndParseFlags() {
  build_ops_flags = new BuildXlaOpsPassFlags;
  build_ops_flags->tf_xla_enable_lazy_compilation = true;
  build_ops_flags->tf_xla_print_cluster_outputs = false;
  build_ops_flags->tf_xla_check_cluster_input_numerics = false;
  build_ops_flags->tf_xla_check_cluster_output_numerics = false;
  build_ops_flags->tf_xla_disable_constant_folding = false;

  mark_for_compilation_flags = new MarkForCompilationPassFlags;
  mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu =
      0;
  mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general = 0;
  mark_for_compilation_flags->tf_xla_min_cluster_size = 4;
  mark_for_compilation_flags->tf_xla_max_cluster_size =
      std::numeric_limits<int32>::max();
  mark_for_compilation_flags->tf_xla_clustering_debug = false;
  mark_for_compilation_flags->tf_xla_cpu_global_jit = false;
  mark_for_compilation_flags->tf_xla_clustering_fuel =
      std::numeric_limits<int64_t>::max();
  mark_for_compilation_flags
      ->tf_xla_disable_deadness_safety_checks_for_debugging = false;
  mark_for_compilation_flags
      ->tf_xla_disable_resource_variable_safety_checks_for_debugging = false;
  mark_for_compilation_flags->tf_xla_deterministic_cluster_names = false;
  mark_for_compilation_flags->tf_xla_persistent_cache_directory = "";
  mark_for_compilation_flags->tf_xla_disable_strict_signature_checks = false;
  mark_for_compilation_flags->tf_xla_persistent_cache_prefix =
      "xla_compile_cache";

  device_flags = new XlaDeviceFlags;
  device_flags->tf_xla_compile_on_demand = false;
  device_flags->tf_xla_enable_xla_devices = false;

  ops_flags = new XlaOpsCommonFlags;
  ops_flags->tf_xla_always_defer_compilation = false;
  ops_flags->tf_xla_async_compilation = false;

  jitter_flags = new IntroduceFloatingPointJitterPassFlags;
  jitter_flags->jitter_amount = 1e-5;

  // The `enable_mlir_bridge` flag allows the user to explicitly request that
  // their program is (or isn't) compiled using the MLIR-based TF-to-XLA
  // bridge.
  //
  // The `enable_mlir_bridge_is_explicit` variable tracks whether or not the
  // user has made an explicit request. That is, if this variable is set to
  // true, the program honors the user's request as per `enable_mlir_bridge`;
  // if it's set to false, the default behavior is used (which may run either
  // bridge, on a per-graph basis).
  bool enable_mlir_bridge = false;
  bool enable_mlir_bridge_is_explicit = false;
  bool mlir_bridge_safe_mode = false;
  bool enable_mlir_merge_control_flow_pass = true;
  bool enable_mlir_convert_control_to_data_outputs_pass = false;
  auto setter_for_jitter_tensor_names = [](string sequence) {
    jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
    return true;
  };
  // Dump graphs in TFG dialect.
  bool use_tfg_graph_dumper = false;

  flag_list = new std::vector<Flag>(
      {Flag("tf_xla_enable_lazy_compilation",
            &build_ops_flags->tf_xla_enable_lazy_compilation, ""),
       Flag("tf_xla_print_cluster_outputs",
            &build_ops_flags->tf_xla_print_cluster_outputs,
            "If true then insert Print nodes to print out values produced by "
            "XLA clusters."),
       Flag("tf_xla_check_cluster_input_numerics",
            &build_ops_flags->tf_xla_check_cluster_input_numerics,
            "If true then insert CheckNumerics nodes to check all cluster "
            "inputs."),
       Flag("tf_xla_check_cluster_output_numerics",
            &build_ops_flags->tf_xla_check_cluster_output_numerics,
            "If true then insert CheckNumerics nodes to check all cluster "
            "outputs."),
       Flag("tf_xla_disable_constant_folding",
            &build_ops_flags->tf_xla_disable_constant_folding,
            "If true then disables constant folding on the TF graph before "
            "XLA compilation."),

       Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
            "Switch a device into 'on-demand' mode, where instead of "
            "autoclustering, ops are compiled one by one just-in-time."),

       Flag("tf_xla_enable_xla_devices",
            &device_flags->tf_xla_enable_xla_devices,
            "Generate XLA_* devices, where placing a computation on such a "
            "device forces compilation by XLA. Deprecated."),

       Flag("tf_xla_always_defer_compilation",
            &ops_flags->tf_xla_always_defer_compilation, ""),
       Flag("tf_xla_async_compilation", &ops_flags->tf_xla_async_compilation,
            "When lazy compilation is enabled, asynchronous compilation starts "
            "the cluster compilation in the background, and the fallback path "
            "is executed until the compilation has finished."),

       Flag("tf_introduce_floating_point_jitter_to_tensors",
            setter_for_jitter_tensor_names, "",
            "The Tensors to add the jitter to.  The tensors are named in the "
            "TensorId format of <node name>:<output idx>."),
       Flag("tf_introduce_floating_point_jitter_amount",
            &jitter_flags->jitter_amount,
            "The amount of jitter to introduce.  This amount is added to each "
            "element in the tensors named in `tensor_names`."),

       Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
            "Enables the experimental MLIR-based TensorFlow compiler bridge.",
            &enable_mlir_bridge_is_explicit),
       Flag("tf_mlir_enable_merge_control_flow_pass",
            &enable_mlir_merge_control_flow_pass,
            "Enables the MergeControlFlow pass for the MLIR-based TensorFlow "
            "compiler bridge."),
       Flag("tf_mlir_enable_convert_control_to_data_outputs_pass",
            &enable_mlir_convert_control_to_data_outputs_pass,
            "Enables the `tf-executor-convert-control-to-data-outputs` pass "
            "for the MLIR-based TensorFlow compiler bridge."),
       Flag(
           "tf_mlir_bridge_safe_mode", &mlir_bridge_safe_mode,
           "When tf_mlir_enable_mlir_bridge is true, this field can enable "
           "the MLIR bridge's safe mode. When the MLIR bridge is in safe "
           "mode, it only runs for graphs that use features the MLIR bridge "
           "currently supports."),
       Flag("tf_dump_graphs_in_tfg", &use_tfg_graph_dumper,
            "When tf_dump_graphs_in_tfg is true, graphs after transformations "
            "are dumped in the MLIR TFG dialect rather than as GraphDef.")});

  AppendMarkForCompilationPassFlagsInternal(flag_list);
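  // Example (illustrative): all of the flags registered above can be set via
  // the TF_XLA_FLAGS environment variable, e.g.
  //   TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit=true" \
  //       ./my_binary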
  xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);

  mlir_flags = new MlirCommonFlags;
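  // Summary of the mapping below from (explicit?, enabled?, safe mode?) to
  // the rollout state:
  //   not explicit, safe mode      -> SAFE_MODE_FALLBACK_ENABLED
  //   not explicit, no safe mode   -> UNSPECIFIED
  //   explicit true, safe mode     -> SAFE_MODE_ENABLED
  //   explicit true, no safe mode  -> ENABLED
  //   explicit false               -> DISABLED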
  if (!enable_mlir_bridge_is_explicit) {
    mlir_flags->tf_mlir_enable_mlir_bridge =
        (mlir_bridge_safe_mode)
            ? ConfigProto::Experimental::
                  MLIR_BRIDGE_ROLLOUT_SAFE_MODE_FALLBACK_ENABLED
            : ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
  } else if (enable_mlir_bridge) {
    mlir_flags->tf_mlir_enable_mlir_bridge =
        (mlir_bridge_safe_mode)
            ? ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED
            : ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
  } else {
    mlir_flags->tf_mlir_enable_mlir_bridge =
        ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
  }
  mlir_flags->tf_mlir_enable_merge_control_flow_pass =
      enable_mlir_merge_control_flow_pass;
  mlir_flags->tf_mlir_enable_convert_control_to_data_outputs_pass =
      enable_mlir_convert_control_to_data_outputs_pass;

  if (use_tfg_graph_dumper) {
    UseMlirForGraphDump(MlirDumpConfig{}.elide_large_attributes().emit_dialect(
        MlirDumpConfig::Dialect::kTFG));
  }

  AllocateAndParseJitRtFlags();
}

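// Frees all flag state and re-runs AllocateAndParseFlags(), so the
// environment variables are re-read. Exposed publicly through
// ResetJitCompilerFlags() below.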
void ResetFlags() {
  delete build_ops_flags;
  delete mark_for_compilation_flags;
  delete device_flags;
  delete ops_flags;
  delete jitter_flags;
  delete mlir_flags;
  delete flag_list;
  delete jitrt_flags;
  delete jitrt_flag_list;
  AllocateAndParseFlags();
}

}  // namespace

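// Each accessor below lazily initializes the flag state on first use via
// absl::call_once, so the environment variables are parsed at most once per
// process.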
bool SetXlaAutoJitFlagFromFlagString(const string& value) {
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return SetterForXlaAutoJitFlag(value);
}

BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags() {
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return build_ops_flags;
}

MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return mark_for_compilation_flags;
}

XlaDeviceFlags* GetXlaDeviceFlags() {
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return device_flags;
}

const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return *ops_flags;
}

const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags() {
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return *jitter_flags;
}

MlirCommonFlags* GetMlirCommonFlags() {
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return mlir_flags;
}

void ResetJitCompilerFlags() { ResetFlags(); }

const JitRtFlags& GetJitRtFlags() {
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return *jitrt_flags;
}

ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState(
    std::optional<const ConfigProto> config_proto) {
  // TF1 graphs that do not override the Session's ConfigProto, and TF2
  // graphs, can enable/disable the bridge via tf_mlir_enable_mlir_bridge.
  auto tf_mlir_enable_mlir_bridge =
      GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
  if (tf_mlir_enable_mlir_bridge !=
      ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED) {
    return tf_mlir_enable_mlir_bridge;
  }

  // If a ConfigProto was not passed in, we can assume the caller is checking
  // whether a TF2 graph should have the bridge enabled or disabled. In that
  // case, we have already checked tf_mlir_enable_mlir_bridge, so it is safe
  // to return UNSPECIFIED here.
  if (!config_proto.has_value()) {
    return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
  }

  // TF1 graphs that do override the Session's ConfigProto and set
  // ConfigProto's enable_mlir_bridge or mlir_bridge_rollout fields will not
  // update tf_mlir_enable_mlir_bridge, so check their values.

  // ConfigProto's enable_mlir_bridge defaults to false, so only respect it
  // when it is true.
  if (config_proto.value().experimental().enable_mlir_bridge()) {
    return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
  }
  return config_proto.value().experimental().mlir_bridge_rollout();
}

void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
  absl::call_once(flags_init, &AllocateAndParseFlags);
  AppendMarkForCompilationPassFlagsInternal(flag_list);
}

static std::atomic<bool> xla_compilation_disabled(false);

void DisableXlaCompilation() { xla_compilation_disabled = true; }

bool FailOnXlaCompilation() { return xla_compilation_disabled; }

}  // namespace tensorflow