1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/flags.h"
17
18 #include <mutex> // NOLINT
19
20 #include "absl/base/call_once.h"
21 #include "absl/strings/numbers.h"
22 #include "absl/strings/str_split.h"
23 #include "absl/strings/strip.h"
24 #include "tensorflow/compiler/xla/parse_flags_from_env.h"
25 #include "tensorflow/core/platform/macros.h"
26 #include "tensorflow/core/util/command_line_flags.h"
27
28 namespace tensorflow {
29 namespace {
30
31 BuildXlaOpsPassFlags* build_ops_flags;
32 MarkForCompilationPassFlags* mark_for_compilation_flags;
33 XlaDeviceFlags* device_flags;
34 XlaOpsCommonFlags* ops_flags;
35 IntroduceFloatingPointJitterPassFlags* jitter_flags;
36 MlirCommonFlags* mlir_flags;
37
38 std::vector<Flag>* flag_list;
39 absl::once_flag flags_init;
40
SetterForXlaAutoJitFlag(const string & value)41 bool SetterForXlaAutoJitFlag(const string& value) {
42 int32_t opt_level;
43 // We need to use the mark_for_compilation_flags directly here instead of
44 // going via GetMarkForCompilationPassFlags() to avoid infinite recursion. The
45 // latter will try to setup and parse flags, which would bring us back to this
46 // setter.
47 if (absl::SimpleAtoi(value, &opt_level)) {
48 mark_for_compilation_flags->xla_auto_jit_flag
49 .optimization_level_single_gpu = opt_level;
50 mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general =
51 opt_level;
52 return true;
53 }
54
55 if (value == "fusible") {
56 mark_for_compilation_flags->xla_auto_jit_flag
57 .optimization_level_single_gpu = 1;
58 mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general =
59 1;
60 mark_for_compilation_flags->tf_xla_ops_to_cluster = "FUSIBLE";
61 return true;
62 }
63
64 absl::string_view value_sv(value);
65 if (!absl::ConsumePrefix(&value_sv, "single-gpu(") ||
66 !absl::ConsumeSuffix(&value_sv, ")") ||
67 !absl::SimpleAtoi(value_sv, &opt_level)) {
68 return false;
69 }
70
71 mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu =
72 opt_level;
73 return true;
74 }
75
AppendMarkForCompilationPassFlagsInternal(std::vector<Flag> * flag_list)76 void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
77 std::vector<Flag> new_flags = {
78 Flag("tf_xla_auto_jit", SetterForXlaAutoJitFlag, "0",
79 "Control compilation of operators into XLA computations on CPU and "
80 "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
81 "things very likely to be improved; 2 = on for everything; "
82 "(experimental) fusible = only for Tensorflow operations that XLA "
83 "knows how to fuse. "
84 "If set to single-gpu(<N>) then this resolves to <N> for single-GPU "
85 "graphs (graphs that have at least one node placed on a GPU and no "
86 "more than one GPU is in use through the entire graph) and 0 "
87 "otherwise. Experimental."),
88 Flag("tf_xla_min_cluster_size",
89 &mark_for_compilation_flags->tf_xla_min_cluster_size,
90 "Minimum number of operators in an XLA compilation. Ignored for "
91 "operators placed on an XLA device or operators explicitly marked "
92 "for compilation."),
93 Flag("tf_xla_max_cluster_size",
94 &mark_for_compilation_flags->tf_xla_max_cluster_size,
95 "Maximum number of operators in an XLA compilation."),
96 Flag(
97 "tf_xla_ops_to_cluster",
98 &mark_for_compilation_flags->tf_xla_ops_to_cluster,
99 "(experimental) "
100 "Limit the operations clustered by XLA to these operations. "
101 "If multiple, separate them with commas. Shortcuts: "
102 " PW: All point-wise operations."
103 " RED: All reduction operations."
104 " MISC: Mixed operations."
105 " PWRED: TF operations that get converted to PW+RED operation in XLA."
106 " REDUCEWINDOW: TF operations like MaxPool/AvgPool that get "
107 "converted to ReduceWindow in XLA."
108 " REDUCEWINDOWPW: Operation that get converted to ReduceWindow + PW "
109 "(LRN, LRNGrad)."
110 " BN: TF FusedBatchNorm* operations."
111 " FUSIBLE: All TF operations that XLA can fuse (All the above). "
112 "You can also put any TF operation name, e.g. 'FUSIBLE,MatMul'."),
113 Flag("tf_xla_clustering_debug",
114 &mark_for_compilation_flags->tf_xla_clustering_debug,
115 "Dump graphs during XLA compilation."),
116 Flag("tf_xla_cpu_global_jit",
117 &mark_for_compilation_flags->tf_xla_cpu_global_jit,
118 "Enables global JIT compilation for CPU via SessionOptions."),
119 Flag("tf_xla_clustering_fuel",
120 &mark_for_compilation_flags->tf_xla_clustering_fuel,
121 "Places an artificial limit on the number of ops marked as "
122 "eligible for clustering."),
123 Flag("tf_xla_disable_deadness_safety_checks_for_debugging",
124 &mark_for_compilation_flags
125 ->tf_xla_disable_deadness_safety_checks_for_debugging,
126 "Disable deadness related safety checks when clustering (this is "
127 "unsound)."),
128 Flag("tf_xla_disable_resource_variable_safety_checks_for_debugging",
129 &mark_for_compilation_flags
130 ->tf_xla_disable_resource_variable_safety_checks_for_debugging,
131 "Disable resource variables related safety checks when clustering "
132 "(this is unsound).")};
133 flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end());
134 }
135
AllocateAndParseFlags()136 void AllocateAndParseFlags() {
137 build_ops_flags = new BuildXlaOpsPassFlags;
138 build_ops_flags->tf_xla_enable_lazy_compilation = true;
139 build_ops_flags->tf_xla_print_cluster_outputs = false;
140 build_ops_flags->tf_xla_check_cluster_input_numerics = false;
141 build_ops_flags->tf_xla_check_cluster_output_numerics = false;
142 build_ops_flags->tf_xla_disable_constant_folding = false;
143
144 mark_for_compilation_flags = new MarkForCompilationPassFlags;
145 mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu =
146 0;
147 mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general = 0;
148 mark_for_compilation_flags->tf_xla_min_cluster_size = 4;
149 mark_for_compilation_flags->tf_xla_max_cluster_size =
150 std::numeric_limits<int32>::max();
151 mark_for_compilation_flags->tf_xla_clustering_debug = false;
152 mark_for_compilation_flags->tf_xla_cpu_global_jit = false;
153 mark_for_compilation_flags->tf_xla_clustering_fuel =
154 std::numeric_limits<int64>::max();
155 mark_for_compilation_flags
156 ->tf_xla_disable_deadness_safety_checks_for_debugging = false;
157 mark_for_compilation_flags
158 ->tf_xla_disable_resource_variable_safety_checks_for_debugging = false;
159
160 device_flags = new XlaDeviceFlags;
161 device_flags->tf_xla_compile_on_demand = false;
162 device_flags->tf_xla_enable_xla_devices = false;
163
164 ops_flags = new XlaOpsCommonFlags;
165 ops_flags->tf_xla_always_defer_compilation = false;
166 ops_flags->tf_xla_async_compilation = false;
167
168 jitter_flags = new IntroduceFloatingPointJitterPassFlags;
169 jitter_flags->jitter_amount = 1e-5;
170
171 // The `enable_mlir_bridge` flag allows the user to explicitly request that
172 // their program is (or isn't) compiled using the MLIR-based TF-to-XLA bridge.
173 //
174 // The `enable_mlir_bridge_is_explicit` variable tracks whether or not the
175 // user has made an explicit request. That is, if this variable is set to
176 // true, the program honors the user's request as per `enable_mlir_bridge`; if
177 // it's set to false, the default behavior is used (which may run either
178 // bridge, on a per-graph basis).
179 bool enable_mlir_bridge = false;
180 bool enable_mlir_bridge_is_explicit = false;
181 bool mlir_bridge_safe_mode = false;
182 bool enable_mlir_merge_control_flow_pass = false;
183
184 auto setter_for_jitter_tensor_names = [](string sequence) {
185 jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
186 return true;
187 };
188
189 flag_list = new std::vector<Flag>(
190 {Flag("tf_xla_enable_lazy_compilation",
191 &build_ops_flags->tf_xla_enable_lazy_compilation, ""),
192 Flag("tf_xla_print_cluster_outputs",
193 &build_ops_flags->tf_xla_print_cluster_outputs,
194 "If true then insert Print nodes to print out values produced by "
195 "XLA clusters."),
196 Flag("tf_xla_check_cluster_input_numerics",
197 &build_ops_flags->tf_xla_check_cluster_input_numerics,
198 "If true then insert CheckNumerics nodes to check all cluster "
199 "inputs."),
200 Flag("tf_xla_check_cluster_output_numerics",
201 &build_ops_flags->tf_xla_check_cluster_output_numerics,
202 "If true then insert CheckNumerics nodes to check all cluster "
203 "outputs."),
204 Flag("tf_xla_disable_constant_folding",
205 &build_ops_flags->tf_xla_disable_constant_folding,
206 "If true then disables constant folding on TF graph before XLA "
207 "compilation."),
208
209 Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
210 "Switch a device into 'on-demand' mode, where instead of "
211 "autoclustering ops are compiled one by one just-in-time."),
212
213 Flag("tf_xla_enable_xla_devices",
214 &device_flags->tf_xla_enable_xla_devices,
215 "Generate XLA_* devices, where placing a computation on such a "
216 "device"
217 "forces compilation by XLA. Deprecated."),
218
219 Flag("tf_xla_always_defer_compilation",
220 &ops_flags->tf_xla_always_defer_compilation, ""),
221 Flag("tf_xla_async_compilation", &ops_flags->tf_xla_async_compilation,
222 "When lazy compilation is enabled, asynchronous compilation starts "
223 "the cluster compilation in the background, and the fallback path "
224 "is executed until the compilation has finished."),
225
226 Flag("tf_introduce_floating_point_jitter_to_tensors",
227 setter_for_jitter_tensor_names, "",
228 "The Tensors to add the jitter to. The tensors are named in the "
229 "TensorId format of <node name>:<output idx>."),
230 Flag("tf_introduce_floating_point_jitter_amount",
231 &jitter_flags->jitter_amount,
232 "The amount of jitter to introduce. This amount is added to each "
233 "element in the tensors named in `tensor_names."),
234
235 Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
236 "Enables experimental MLIR-Based TensorFlow Compiler Bridge.",
237 &enable_mlir_bridge_is_explicit),
238 Flag("tf_mlir_enable_merge_control_flow_pass",
239 &enable_mlir_merge_control_flow_pass,
240 "Enables MergeControlFlow pass for MLIR-Based TensorFlow Compiler "
241 "Bridge."),
242 Flag(
243 "tf_mlir_bridge_safe_mode", &mlir_bridge_safe_mode,
244 "When tf_mlir_enable_mlir_bridge is true, this field can enable "
245 "the MLIR bridge's safe mode. When the MLIR bridge is in safe mode, "
246 "it only runs for graphs that use features MLIR bridge currently "
247 "supports.")});
248
249 AppendMarkForCompilationPassFlagsInternal(flag_list);
250 xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
251
252 mlir_flags = new MlirCommonFlags;
253 if (!enable_mlir_bridge_is_explicit) {
254 mlir_flags->tf_mlir_enable_mlir_bridge =
255 (mlir_bridge_safe_mode)
256 ? ConfigProto::Experimental::
257 MLIR_BRIDGE_ROLLOUT_SAFE_MODE_FALLBACK_ENABLED
258 : ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
259 } else if (enable_mlir_bridge) {
260 mlir_flags->tf_mlir_enable_mlir_bridge =
261 (mlir_bridge_safe_mode)
262 ? ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED
263 : ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
264 } else {
265 mlir_flags->tf_mlir_enable_mlir_bridge =
266 ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
267 }
268 mlir_flags->tf_mlir_enable_merge_control_flow_pass =
269 enable_mlir_merge_control_flow_pass;
270 }
271
272 } // namespace
273
SetXlaAutoJitFlagFromFlagString(const string & value)274 bool SetXlaAutoJitFlagFromFlagString(const string& value) {
275 absl::call_once(flags_init, &AllocateAndParseFlags);
276 return SetterForXlaAutoJitFlag(value);
277 }
278
GetBuildXlaOpsPassFlags()279 BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags() {
280 absl::call_once(flags_init, &AllocateAndParseFlags);
281 return build_ops_flags;
282 }
283
GetMarkForCompilationPassFlags()284 MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
285 absl::call_once(flags_init, &AllocateAndParseFlags);
286 return mark_for_compilation_flags;
287 }
288
GetXlaDeviceFlags()289 XlaDeviceFlags* GetXlaDeviceFlags() {
290 absl::call_once(flags_init, &AllocateAndParseFlags);
291 return device_flags;
292 }
293
GetXlaOpsCommonFlags()294 const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
295 absl::call_once(flags_init, &AllocateAndParseFlags);
296 return *ops_flags;
297 }
298
299 const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags()300 GetIntroduceFloatingPointJitterPassFlags() {
301 absl::call_once(flags_init, &AllocateAndParseFlags);
302 return *jitter_flags;
303 }
304
GetMlirCommonFlags()305 MlirCommonFlags* GetMlirCommonFlags() {
306 absl::call_once(flags_init, &AllocateAndParseFlags);
307 return mlir_flags;
308 }
309
GetMlirBridgeRolloutState(absl::optional<const ConfigProto> config_proto)310 ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState(
311 absl::optional<const ConfigProto> config_proto) {
312 // TF1 graphs that do not override Sessions's ConfigProto and TF2 graphs
313 // can enable/disable the graph via tf_mlir_enable_mlir_bridge.
314 auto tf_mlir_enable_mlir_bridge =
315 GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
316 if (tf_mlir_enable_mlir_bridge !=
317 ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED) {
318 return tf_mlir_enable_mlir_bridge;
319 }
320
321 // If a ConfigProto was not passed in, we can assume the caller is
322 // checking if TF2 graph should have the bridge enabled / disabled. In that
323 // case, we have already checked tf_mlir_enable_mlir_bridge so it is safe to
324 // return UNSPECIFIED here.
325 if (!config_proto.has_value()) {
326 return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
327 }
328
329 // TF1 graphs that do override Session's ConfigProto and set
330 // ConfigProto's enable_mlir_bridge or mlir_bridge_rollout fields will not
331 // update tf_mlir_enable_mlir_bridge so check their values.
332
333 // ConfigProto's enable_mlir_bridge defaults to false so only respect it
334 // when it is true.
335 if (config_proto.value().experimental().enable_mlir_bridge()) {
336 return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
337 }
338 return config_proto.value().experimental().mlir_bridge_rollout();
339 }
340
AppendMarkForCompilationPassFlags(std::vector<Flag> * flag_list)341 void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
342 absl::call_once(flags_init, &AllocateAndParseFlags);
343 AppendMarkForCompilationPassFlagsInternal(flag_list);
344 }
345
// Switch recording whether DisableXlaCompilation() has been called. Starts
// out false; nothing in this file sets it back to false.
static std::atomic<bool> xla_compilation_disabled{false};

// Turns the switch on.
void DisableXlaCompilation() { xla_compilation_disabled.store(true); }

// Reports whether DisableXlaCompilation() has been called.
bool FailOnXlaCompilation() { return xla_compilation_disabled.load(); }
351
352 } // namespace tensorflow
353