• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/flags.h"
17 
18 #include <mutex>  // NOLINT
19 
20 #include "absl/base/call_once.h"
21 #include "absl/strings/numbers.h"
22 #include "absl/strings/str_split.h"
23 #include "absl/strings/strip.h"
24 #include "tensorflow/compiler/xla/parse_flags_from_env.h"
25 #include "tensorflow/core/platform/macros.h"
26 #include "tensorflow/core/util/command_line_flags.h"
27 
28 namespace tensorflow {
29 namespace {
30 
// Lazily-allocated singletons holding the parsed flag values. Each is
// created exactly once by AllocateAndParseFlags() (all accessors gate on
// `flags_init` below) and is never deleted in this file — they live for the
// rest of the process.
BuildXlaOpsPassFlags* build_ops_flags;
MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;
MlirCommonFlags* mlir_flags;

// Command-line flag descriptors registered for TF_XLA_FLAGS parsing; also
// allocated once in AllocateAndParseFlags().
std::vector<Flag>* flag_list;
// Guards the one-time initialization of all of the globals above.
absl::once_flag flags_init;
40 
SetterForXlaAutoJitFlag(const string & value)41 bool SetterForXlaAutoJitFlag(const string& value) {
42   int32_t opt_level;
43   // We need to use the mark_for_compilation_flags directly here instead of
44   // going via GetMarkForCompilationPassFlags() to avoid infinite recursion. The
45   // latter will try to setup and parse flags, which would bring us back to this
46   // setter.
47   if (absl::SimpleAtoi(value, &opt_level)) {
48     mark_for_compilation_flags->xla_auto_jit_flag
49         .optimization_level_single_gpu = opt_level;
50     mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general =
51         opt_level;
52     return true;
53   }
54 
55   if (value == "fusible") {
56     mark_for_compilation_flags->xla_auto_jit_flag
57         .optimization_level_single_gpu = 1;
58     mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general =
59         1;
60     mark_for_compilation_flags->tf_xla_ops_to_cluster = "FUSIBLE";
61     return true;
62   }
63 
64   absl::string_view value_sv(value);
65   if (!absl::ConsumePrefix(&value_sv, "single-gpu(") ||
66       !absl::ConsumeSuffix(&value_sv, ")") ||
67       !absl::SimpleAtoi(value_sv, &opt_level)) {
68     return false;
69   }
70 
71   mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu =
72       opt_level;
73   return true;
74 }
75 
AppendMarkForCompilationPassFlagsInternal(std::vector<Flag> * flag_list)76 void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
77   std::vector<Flag> new_flags = {
78       Flag("tf_xla_auto_jit", SetterForXlaAutoJitFlag, "0",
79            "Control compilation of operators into XLA computations on CPU and "
80            "GPU devices.  0 = use ConfigProto setting; -1 = off; 1 = on for "
81            "things very likely to be improved; 2 = on for everything; "
82            "(experimental) fusible = only for Tensorflow operations that XLA "
83            "knows how to fuse.  "
84            "If set to single-gpu(<N>) then this resolves to <N> for single-GPU "
85            "graphs (graphs that have at least one node placed on a GPU and no "
86            "more than one GPU is in use through the entire graph) and 0 "
87            "otherwise.  Experimental."),
88       Flag("tf_xla_min_cluster_size",
89            &mark_for_compilation_flags->tf_xla_min_cluster_size,
90            "Minimum number of operators in an XLA compilation. Ignored for "
91            "operators placed on an XLA device or operators explicitly marked "
92            "for compilation."),
93       Flag("tf_xla_max_cluster_size",
94            &mark_for_compilation_flags->tf_xla_max_cluster_size,
95            "Maximum number of operators in an XLA compilation."),
96       Flag(
97           "tf_xla_ops_to_cluster",
98           &mark_for_compilation_flags->tf_xla_ops_to_cluster,
99           "(experimental) "
100           "Limit the operations clustered by XLA to these operations. "
101           "If multiple, separate them with commas. Shortcuts: "
102           " PW: All point-wise operations."
103           " RED: All reduction operations."
104           " MISC: Mixed operations."
105           " PWRED: TF operations that get converted to PW+RED operation in XLA."
106           " REDUCEWINDOW: TF operations like MaxPool/AvgPool that get "
107           "converted to ReduceWindow in XLA."
108           " REDUCEWINDOWPW: Operation that get converted to ReduceWindow + PW "
109           "(LRN, LRNGrad)."
110           " BN: TF FusedBatchNorm* operations."
111           " FUSIBLE: All TF operations that XLA can fuse (All the above). "
112           "You can also put any TF operation name, e.g. 'FUSIBLE,MatMul'."),
113       Flag("tf_xla_clustering_debug",
114            &mark_for_compilation_flags->tf_xla_clustering_debug,
115            "Dump graphs during XLA compilation."),
116       Flag("tf_xla_cpu_global_jit",
117            &mark_for_compilation_flags->tf_xla_cpu_global_jit,
118            "Enables global JIT compilation for CPU via SessionOptions."),
119       Flag("tf_xla_clustering_fuel",
120            &mark_for_compilation_flags->tf_xla_clustering_fuel,
121            "Places an artificial limit on the number of ops marked as "
122            "eligible for clustering."),
123       Flag("tf_xla_disable_deadness_safety_checks_for_debugging",
124            &mark_for_compilation_flags
125                 ->tf_xla_disable_deadness_safety_checks_for_debugging,
126            "Disable deadness related safety checks when clustering (this is "
127            "unsound)."),
128       Flag("tf_xla_disable_resource_variable_safety_checks_for_debugging",
129            &mark_for_compilation_flags
130                 ->tf_xla_disable_resource_variable_safety_checks_for_debugging,
131            "Disable resource variables related safety checks when clustering "
132            "(this is unsound).")};
133   flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end());
134 }
135 
AllocateAndParseFlags()136 void AllocateAndParseFlags() {
137   build_ops_flags = new BuildXlaOpsPassFlags;
138   build_ops_flags->tf_xla_enable_lazy_compilation = true;
139   build_ops_flags->tf_xla_print_cluster_outputs = false;
140   build_ops_flags->tf_xla_check_cluster_input_numerics = false;
141   build_ops_flags->tf_xla_check_cluster_output_numerics = false;
142   build_ops_flags->tf_xla_disable_constant_folding = false;
143 
144   mark_for_compilation_flags = new MarkForCompilationPassFlags;
145   mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu =
146       0;
147   mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general = 0;
148   mark_for_compilation_flags->tf_xla_min_cluster_size = 4;
149   mark_for_compilation_flags->tf_xla_max_cluster_size =
150       std::numeric_limits<int32>::max();
151   mark_for_compilation_flags->tf_xla_clustering_debug = false;
152   mark_for_compilation_flags->tf_xla_cpu_global_jit = false;
153   mark_for_compilation_flags->tf_xla_clustering_fuel =
154       std::numeric_limits<int64>::max();
155   mark_for_compilation_flags
156       ->tf_xla_disable_deadness_safety_checks_for_debugging = false;
157   mark_for_compilation_flags
158       ->tf_xla_disable_resource_variable_safety_checks_for_debugging = false;
159 
160   device_flags = new XlaDeviceFlags;
161   device_flags->tf_xla_compile_on_demand = false;
162   device_flags->tf_xla_enable_xla_devices = false;
163 
164   ops_flags = new XlaOpsCommonFlags;
165   ops_flags->tf_xla_always_defer_compilation = false;
166   ops_flags->tf_xla_async_compilation = false;
167 
168   jitter_flags = new IntroduceFloatingPointJitterPassFlags;
169   jitter_flags->jitter_amount = 1e-5;
170 
171   // The `enable_mlir_bridge` flag allows the user to explicitly request that
172   // their program is (or isn't) compiled using the MLIR-based TF-to-XLA bridge.
173   //
174   // The `enable_mlir_bridge_is_explicit` variable tracks whether or not the
175   // user has made an explicit request. That is, if this variable is set to
176   // true, the program honors the user's request as per `enable_mlir_bridge`; if
177   // it's set to false, the default behavior is used (which may run either
178   // bridge, on a per-graph basis).
179   bool enable_mlir_bridge = false;
180   bool enable_mlir_bridge_is_explicit = false;
181   bool mlir_bridge_safe_mode = false;
182   bool enable_mlir_merge_control_flow_pass = false;
183 
184   auto setter_for_jitter_tensor_names = [](string sequence) {
185     jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
186     return true;
187   };
188 
189   flag_list = new std::vector<Flag>(
190       {Flag("tf_xla_enable_lazy_compilation",
191             &build_ops_flags->tf_xla_enable_lazy_compilation, ""),
192        Flag("tf_xla_print_cluster_outputs",
193             &build_ops_flags->tf_xla_print_cluster_outputs,
194             "If true then insert Print nodes to print out values produced by "
195             "XLA clusters."),
196        Flag("tf_xla_check_cluster_input_numerics",
197             &build_ops_flags->tf_xla_check_cluster_input_numerics,
198             "If true then insert CheckNumerics nodes to check all cluster "
199             "inputs."),
200        Flag("tf_xla_check_cluster_output_numerics",
201             &build_ops_flags->tf_xla_check_cluster_output_numerics,
202             "If true then insert CheckNumerics nodes to check all cluster "
203             "outputs."),
204        Flag("tf_xla_disable_constant_folding",
205             &build_ops_flags->tf_xla_disable_constant_folding,
206             "If true then disables constant folding on TF graph before XLA "
207             "compilation."),
208 
209        Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
210             "Switch a device into 'on-demand' mode, where instead of "
211             "autoclustering ops are compiled one by one just-in-time."),
212 
213        Flag("tf_xla_enable_xla_devices",
214             &device_flags->tf_xla_enable_xla_devices,
215             "Generate XLA_* devices, where placing a computation on such a "
216             "device"
217             "forces compilation by XLA. Deprecated."),
218 
219        Flag("tf_xla_always_defer_compilation",
220             &ops_flags->tf_xla_always_defer_compilation, ""),
221        Flag("tf_xla_async_compilation", &ops_flags->tf_xla_async_compilation,
222             "When lazy compilation is enabled, asynchronous compilation starts "
223             "the cluster compilation in the background, and the fallback path "
224             "is executed until the compilation has finished."),
225 
226        Flag("tf_introduce_floating_point_jitter_to_tensors",
227             setter_for_jitter_tensor_names, "",
228             "The Tensors to add the jitter to.  The tensors are named in the "
229             "TensorId format of <node name>:<output idx>."),
230        Flag("tf_introduce_floating_point_jitter_amount",
231             &jitter_flags->jitter_amount,
232             "The amount of jitter to introduce.  This amount is added to each "
233             "element in the tensors named in `tensor_names."),
234 
235        Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
236             "Enables experimental MLIR-Based TensorFlow Compiler Bridge.",
237             &enable_mlir_bridge_is_explicit),
238        Flag("tf_mlir_enable_merge_control_flow_pass",
239             &enable_mlir_merge_control_flow_pass,
240             "Enables MergeControlFlow pass for MLIR-Based TensorFlow Compiler "
241             "Bridge."),
242        Flag(
243            "tf_mlir_bridge_safe_mode", &mlir_bridge_safe_mode,
244            "When tf_mlir_enable_mlir_bridge is true, this field can enable "
245            "the MLIR bridge's safe mode. When the MLIR bridge is in safe mode, "
246            "it only runs for graphs that use features MLIR bridge currently "
247            "supports.")});
248 
249   AppendMarkForCompilationPassFlagsInternal(flag_list);
250   xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
251 
252   mlir_flags = new MlirCommonFlags;
253   if (!enable_mlir_bridge_is_explicit) {
254     mlir_flags->tf_mlir_enable_mlir_bridge =
255         (mlir_bridge_safe_mode)
256             ? ConfigProto::Experimental::
257                   MLIR_BRIDGE_ROLLOUT_SAFE_MODE_FALLBACK_ENABLED
258             : ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
259   } else if (enable_mlir_bridge) {
260     mlir_flags->tf_mlir_enable_mlir_bridge =
261         (mlir_bridge_safe_mode)
262             ? ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED
263             : ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
264   } else {
265     mlir_flags->tf_mlir_enable_mlir_bridge =
266         ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
267   }
268   mlir_flags->tf_mlir_enable_merge_control_flow_pass =
269       enable_mlir_merge_control_flow_pass;
270 }
271 
272 }  // namespace
273 
SetXlaAutoJitFlagFromFlagString(const string & value)274 bool SetXlaAutoJitFlagFromFlagString(const string& value) {
275   absl::call_once(flags_init, &AllocateAndParseFlags);
276   return SetterForXlaAutoJitFlag(value);
277 }
278 
GetBuildXlaOpsPassFlags()279 BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags() {
280   absl::call_once(flags_init, &AllocateAndParseFlags);
281   return build_ops_flags;
282 }
283 
GetMarkForCompilationPassFlags()284 MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
285   absl::call_once(flags_init, &AllocateAndParseFlags);
286   return mark_for_compilation_flags;
287 }
288 
GetXlaDeviceFlags()289 XlaDeviceFlags* GetXlaDeviceFlags() {
290   absl::call_once(flags_init, &AllocateAndParseFlags);
291   return device_flags;
292 }
293 
GetXlaOpsCommonFlags()294 const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
295   absl::call_once(flags_init, &AllocateAndParseFlags);
296   return *ops_flags;
297 }
298 
299 const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags()300 GetIntroduceFloatingPointJitterPassFlags() {
301   absl::call_once(flags_init, &AllocateAndParseFlags);
302   return *jitter_flags;
303 }
304 
GetMlirCommonFlags()305 MlirCommonFlags* GetMlirCommonFlags() {
306   absl::call_once(flags_init, &AllocateAndParseFlags);
307   return mlir_flags;
308 }
309 
GetMlirBridgeRolloutState(absl::optional<const ConfigProto> config_proto)310 ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState(
311     absl::optional<const ConfigProto> config_proto) {
312   // TF1 graphs that do not override Sessions's ConfigProto and TF2 graphs
313   // can enable/disable the graph via tf_mlir_enable_mlir_bridge.
314   auto tf_mlir_enable_mlir_bridge =
315       GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
316   if (tf_mlir_enable_mlir_bridge !=
317       ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED) {
318     return tf_mlir_enable_mlir_bridge;
319   }
320 
321   // If a ConfigProto was not passed in, we can assume the caller is
322   // checking if TF2 graph should have the bridge enabled / disabled. In that
323   // case, we have already checked tf_mlir_enable_mlir_bridge so it is safe to
324   // return UNSPECIFIED here.
325   if (!config_proto.has_value()) {
326     return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
327   }
328 
329   // TF1 graphs that do override Session's ConfigProto and set
330   // ConfigProto's enable_mlir_bridge or mlir_bridge_rollout fields will not
331   // update tf_mlir_enable_mlir_bridge so check their values.
332 
333   // ConfigProto's enable_mlir_bridge defaults to false so only respect it
334   // when it is true.
335   if (config_proto.value().experimental().enable_mlir_bridge()) {
336     return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
337   }
338   return config_proto.value().experimental().mlir_bridge_rollout();
339 }
340 
AppendMarkForCompilationPassFlags(std::vector<Flag> * flag_list)341 void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
342   absl::call_once(flags_init, &AllocateAndParseFlags);
343   AppendMarkForCompilationPassFlagsInternal(flag_list);
344 }
345 
// Process-wide kill switch for XLA compilation. Nothing in this file ever
// resets it, so once tripped it stays tripped for the process lifetime.
static std::atomic<bool> xla_compilation_disabled(false);

// Trips the kill switch: subsequent FailOnXlaCompilation() calls return true.
void DisableXlaCompilation() { xla_compilation_disabled.store(true); }

// Reports whether DisableXlaCompilation() has been called.
bool FailOnXlaCompilation() { return xla_compilation_disabled.load(); }
351 
352 }  // namespace tensorflow
353