• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_JIT_FLAGS_H_
17 #define TENSORFLOW_COMPILER_JIT_FLAGS_H_
18 
19 #include <string>
20 #include <vector>
21 
22 #include "absl/types/optional.h"
23 #include "tensorflow/core/platform/types.h"
24 #include "tensorflow/core/protobuf/config.pb.h"
25 #include "tensorflow/core/util/command_line_flags.h"
26 
27 namespace tensorflow {
28 
29 struct XlaAutoJitFlag {
30   // Control compilation of operators into XLA computations on CPU and GPU
31   // devices.  0 = use ConfigProto setting; -1 = off; 1 = on for things very
32   // likely to be improved; 2 = on for everything.
33   //
34   // If all non-CPU ops in the graph being optimized are placed on a single GPU
35   // and there is at least one node placed on that GPU then
36   // `optimization_level_single_gpu` applies.  Otherwise
37   // `optimization_level_general` applies.
38   //
39   // Experimental.
40   int32 optimization_level_single_gpu;
41   int32 optimization_level_general;
42 };
43 
44 // Sets the xla_auto_jit_flag based on the given flag string. Supported syntax
45 // is:
46 // <number>: sets general and single_gpu setting to the provided number.
47 // single-gpu(<number>): sets the single_gpu setting to the provided number.
48 bool SetXlaAutoJitFlagFromFlagString(const string& value);
49 
50 // Flags associated with the XLA bridge's mark_for_compilation_pass module.
51 struct MarkForCompilationPassFlags {
52   XlaAutoJitFlag xla_auto_jit_flag;
53 
54   // Minimum number of operators in an XLA compilation. Ignored for operators
55   // placed on an XLA device or operators explicitly marked for compilation.
56   int32 tf_xla_min_cluster_size;
57 
58   // Maximum number of operators in an XLA compilation.
59   int32 tf_xla_max_cluster_size;
60 
61   // If non-empty, limit XLA clustering to the following TF operations.
62   string tf_xla_ops_to_cluster;
63 
64   // If non-empty, remove following operations from XLA clustering excludelist.
65   string tf_xla_cluster_exclude_ops;
66 
67   // Dump graphs during XLA compilation.
68   bool tf_xla_clustering_debug;
69 
70   // Enables global JIT compilation for CPU via SessionOptions.
71   bool tf_xla_cpu_global_jit;
72 
73   // "Compiler fuel" for clustering.  Only this many ops will be marked as
74   // eligible for clustering.
75   int64_t tf_xla_clustering_fuel;
76 
77   // If tf_xla_disable_deadness_safety_checks_for_debugging is set to true then
78   // we do not do deadness related safety checks.  This is unsound in general,
79   // but can be used as a debugging aid.
80   bool tf_xla_disable_deadness_safety_checks_for_debugging;
81 
82   // If tf_xla_disable_resource_variable_safety_checks_for_debugging is set to
83   // true then we do not do safety checks to preserve TensorFlow's resource
84   // variable concurrency semantics.  This is unsound in general, but can be
85   // used as a debugging aid.
86   bool tf_xla_disable_resource_variable_safety_checks_for_debugging;
87 
88   // If true names of clustered operations will be computed deterministically
89   // so that they remain stable from run to run of auto clusteing.
90   bool tf_xla_deterministic_cluster_names;
91 
92   // If non-empty, JIT-compiled executables are saved to and loaded from the
93   // specified file system directory path.
94   std::string tf_xla_persistent_cache_directory;
95 
96   // If true, entries loaded into the XLA compile cache will not have their
97   // signatures checked strictly. This should generally not be disabled except
98   // for debugging. Defaults to false.
99   bool tf_xla_disable_strict_signature_checks;
100 
101   // Specifies the persistance cache prefix. Default is "xla_compile_cache"
102   string tf_xla_persistent_cache_prefix;
103 };
104 
// Flags associated with the XLA bridge's xla_device module.
//
// Members are initialized to false so that a default-constructed instance
// never holds indeterminate values; the effective values are expected to be
// assigned when TF_XLA_FLAGS is parsed (that code is not part of this header).
struct XlaDeviceFlags {
  // Switch the CPU device into "on-demand" mode, where instead of
  // autoclustering ops are compiled one by one just-in-time.
  // Enabling this mode by a legacy flag is a temporary mechanism. When this
  // feature is battle-tested, we will switch this to be a session option.
  bool tf_xla_compile_on_demand = false;

  // Enables "XLA" devices if this flag is set.
  bool tf_xla_enable_xla_devices = false;
};
116 
// Flags common to the _Xla* ops and their kernels.
//
// Members are initialized to false so that a default-constructed instance
// never holds indeterminate values.
struct XlaOpsCommonFlags {
  // If true, _XlaCompile always refuses to compile the cluster, which means the
  // XLA clusters always run in the TF executor.  Defaults to false.
  bool tf_xla_always_defer_compilation = false;

  // If true, _XlaCompile compiles the cluster asynchronously with respect to
  // the main execution. The fallback path is taken while compilation happens.
  // The false initializer is a placeholder; no default is documented here.
  bool tf_xla_async_compilation = false;
};
126 
// Flags for the build_xla_ops pass.
//
// Every member has an in-class initializer so that a default-constructed
// instance never holds indeterminate values.  Where a member's comment states
// a default, the initializer matches it; otherwise false is a placeholder.
struct BuildXlaOpsPassFlags {
  // Enables lazy compilation for TF/XLA (only when auto-clustering) if true.
  // Defaults to true.
  bool tf_xla_enable_lazy_compilation = true;

  // If true then insert Print nodes to print out values produced by XLA
  // clusters.  Useful for debugging.
  bool tf_xla_print_cluster_outputs = false;

  // If true, insert CheckNumerics nodes for every floating point typed input to
  // an XLA cluster.
  bool tf_xla_check_cluster_input_numerics = false;

  // If true, insert CheckNumerics nodes for every floating point typed output
  // from an XLA cluster.
  bool tf_xla_check_cluster_output_numerics = false;

  // Disables all constant folding. The primary use for this is for testing to
  // guarantee that tests are run on XLA and not on TF's CPU implementation.
  bool tf_xla_disable_constant_folding = false;
};
149 
150 // Flags for the IntroduceFloatingPointJitter pass.
151 struct IntroduceFloatingPointJitterPassFlags {
152   // The amount of jitter to introduce.  This amount is added to each element in
153   // the tensors named in `tensor_names.
154   float jitter_amount;
155 
156   // The Tensors to add the jitter to.  The tensors are named in the TensorId
157   // format of <node name>:<output idx>.
158   std::vector<string> tensor_names;
159 };
160 
161 // Flags for common MLIR configurations.
162 struct MlirCommonFlags {
163   ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge;
164 
165   bool tf_mlir_enable_merge_control_flow_pass;
166   bool tf_mlir_enable_convert_control_to_data_outputs_pass;
167 };
168 
// Flags for the JitRt pipeline -- see tf_jitrt_pipeline.h for details.
//
// Members are initialized to false so that a default-constructed instance
// never holds indeterminate values.  No defaults are documented in this
// header, so false is only a placeholder.
struct JitRtFlags {
  bool always_specialize = false;
  bool cost_driven_async_parallel_for = false;

  // Enables tracking of the "live" JitRt queries to, on a crash, identify the
  // "query of death". See TfJitRtQueryOfDeathLogger.
  bool log_query_of_death = false;

  bool vectorize = false;

  // Enables crash reproducer for JitRt MLIR pass manager.
  bool enable_crash_reproducer = false;
};
183 
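// Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
// NOTE(review): no DumpGraphFlags getter is declared after this comment in
// this header, so the comment looks stale -- confirm before relying on it.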
187 
188 // Getters for flags structs defined above.  The first call to any of these
189 // parses TF_XLA_FLAGS for all of them.  Those functions which return a pointer
190 // always return the same pointer.
191 MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();
192 BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags();
193 XlaDeviceFlags* GetXlaDeviceFlags();
194 const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
195 
196 const IntroduceFloatingPointJitterPassFlags&
197 GetIntroduceFloatingPointJitterPassFlags();
198 
199 MlirCommonFlags* GetMlirCommonFlags();
200 
201 void ResetJitCompilerFlags();
202 
203 const JitRtFlags& GetJitRtFlags();
204 
205 // Returns the effective MLIR bridge rollout state based on the flags and the
206 // optional configuration.
207 ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState(
208     std::optional<const ConfigProto> config_proto);
209 
210 // Appends the flag definitions associated with
211 // MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
212 //
213 // Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
214 void AppendMarkForCompilationPassFlags(
215     std::vector<tensorflow::Flag>* flag_list);
216 
// Turns off XLA compilation so that it is forced to return an error message
// instead.  A server can call this to ensure that JIT compilation is opt-in.
void DisableXlaCompilation();
220 
// Returns `false` as long as `DisableXlaCompilation` has not been called.
bool FailOnXlaCompilation();
223 
224 }  // namespace tensorflow
225 
226 #endif  // TENSORFLOW_COMPILER_JIT_FLAGS_H_
227