/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_FLAGS_H_
#define TENSORFLOW_COMPILER_JIT_FLAGS_H_

#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {

struct XlaAutoJitFlag {
  // Controls compilation of operators into XLA computations on CPU and GPU
  // devices.  0 = use ConfigProto setting; -1 = off; 1 = on for operations
  // very likely to be improved; 2 = on for everything.
  //
  // If all non-CPU ops in the graph being optimized are placed on a single
  // GPU, and there is at least one node placed on that GPU, then
  // `optimization_level_single_gpu` applies.  Otherwise
  // `optimization_level_general` applies.
  //
  // Experimental.
  int32 optimization_level_single_gpu;
  int32 optimization_level_general;
};
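
// A sketch of how these levels are commonly set from the environment, via
// the TF_XLA_FLAGS parsing described later in this header (the script name
// is hypothetical; the flag is assumed to be exposed as --tf_xla_auto_jit,
// matching the setter below):
//
//   # Auto-cluster aggressively on every device:
//   TF_XLA_FLAGS="--tf_xla_auto_jit=2" python train.py
//
//   # Use level 2 only in the single-GPU case described above:
//   TF_XLA_FLAGS="--tf_xla_auto_jit=single-gpu(2)" python train.py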

// Sets the xla_auto_jit_flag based on the given flag string. Supported syntax
// is:
// <number>: sets both the general and single_gpu settings to the given number.
// single-gpu(<number>): sets only the single_gpu setting to the given number.
bool SetXlaAutoJitFlagFromFlagString(const string& value);
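
// For illustration, both syntax forms above (a sketch; the return value is
// assumed to indicate whether the string parsed successfully):
//
//   bool ok = SetXlaAutoJitFlagFromFlagString("2");               // both settings
//   bool ok_gpu = SetXlaAutoJitFlagFromFlagString("single-gpu(1)");  // single_gpu only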

// Flags associated with the XLA bridge's mark_for_compilation_pass module.
struct MarkForCompilationPassFlags {
  XlaAutoJitFlag xla_auto_jit_flag;

  // Minimum number of operators in an XLA compilation. Ignored for operators
  // placed on an XLA device or operators explicitly marked for compilation.
  int32 tf_xla_min_cluster_size;

  // Maximum number of operators in an XLA compilation.
  int32 tf_xla_max_cluster_size;

  // If non-empty, limits XLA clustering to the listed TF operations.
  string tf_xla_ops_to_cluster;

  // Dumps graphs during XLA compilation.
  bool tf_xla_clustering_debug;

  // Enables global JIT compilation for CPU via SessionOptions.
  bool tf_xla_cpu_global_jit;

  // "Compiler fuel" for clustering.  Only this many ops will be marked as
  // eligible for clustering.
  int64 tf_xla_clustering_fuel;

  // If tf_xla_disable_deadness_safety_checks_for_debugging is set to true,
  // we skip the deadness-related safety checks.  This is unsound in general,
  // but can be used as a debugging aid.
  bool tf_xla_disable_deadness_safety_checks_for_debugging;

  // If tf_xla_disable_resource_variable_safety_checks_for_debugging is set to
  // true, we skip the safety checks that preserve TensorFlow's resource
  // variable concurrency semantics.  This is unsound in general, but can be
  // used as a debugging aid.
  bool tf_xla_disable_resource_variable_safety_checks_for_debugging;
};

// Flags associated with the XLA bridge's xla_device module.
struct XlaDeviceFlags {
  // Switches the CPU device into "on-demand" mode, where, instead of being
  // auto-clustered, ops are compiled one by one just-in-time.
  // Enabling this mode via a legacy flag is a temporary mechanism. When this
  // feature is battle-tested, we will switch this to be a session option.
  bool tf_xla_compile_on_demand;

  // Enables "XLA" devices if this flag is set.
  bool tf_xla_enable_xla_devices;
};

// Flags common to the _Xla* ops and their kernels.
struct XlaOpsCommonFlags {
  // If true, _XlaCompile always refuses to compile the cluster, which means
  // XLA clusters always run in the TF executor.  Defaults to false.
  bool tf_xla_always_defer_compilation;
  // If true, _XlaCompile compiles the cluster asynchronously with respect to
  // the main execution. The fallback path is taken while compilation happens.
  bool tf_xla_async_compilation;
};

// Flags for the build_xla_ops pass.
struct BuildXlaOpsPassFlags {
  // Enables lazy compilation for TF/XLA (only when auto-clustering) if true.
  // Defaults to true.
  bool tf_xla_enable_lazy_compilation;

  // If true, inserts Print nodes to print out values produced by XLA
  // clusters.  Useful for debugging.
  bool tf_xla_print_cluster_outputs;

  // If true, inserts CheckNumerics nodes for every floating point typed
  // input to an XLA cluster.
  bool tf_xla_check_cluster_input_numerics;

  // If true, inserts CheckNumerics nodes for every floating point typed
  // output from an XLA cluster.
  bool tf_xla_check_cluster_output_numerics;

  // Disables all constant folding. The primary use for this is in testing,
  // to guarantee that tests run on XLA and not on TF's CPU implementation.
  bool tf_xla_disable_constant_folding;
};
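
// A sketch of turning the numerics checks above on from the environment
// (assuming each field is exposed under its own name via the TF_XLA_FLAGS
// parsing mentioned below; the script name is hypothetical):
//
//   TF_XLA_FLAGS="--tf_xla_check_cluster_input_numerics=true \
//                 --tf_xla_check_cluster_output_numerics=true" python train.py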

// Flags for the IntroduceFloatingPointJitter pass.
struct IntroduceFloatingPointJitterPassFlags {
  // The amount of jitter to introduce.  This amount is added to each element
  // in the tensors named in `tensor_names`.
  float jitter_amount;

  // The tensors to add the jitter to.  The tensors are named in the TensorId
  // format of <node name>:<output idx>.
  std::vector<string> tensor_names;
};
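
// For example (the node name is hypothetical), jittering the first two
// outputs of a node named "dense/MatMul" would use:
//
//   tensor_names = {"dense/MatMul:0", "dense/MatMul:1"};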

// Flags for common MLIR configurations.
struct MlirCommonFlags {
  ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge;

  bool tf_mlir_enable_merge_control_flow_pass;
};

// Getters for the flags structs defined above.  The first call to any of
// these parses TF_XLA_FLAGS for all of them.  Those functions which return a
// pointer always return the same pointer.
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();
BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags();
XlaDeviceFlags* GetXlaDeviceFlags();
const XlaOpsCommonFlags& GetXlaOpsCommonFlags();

const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags();
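
// Example usage (a sketch; any compiler pass reads its flags the same way):
//
//   MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
//   if (flags->tf_xla_clustering_debug) {
//     // ... dump graphs while clustering ...
//   }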

MlirCommonFlags* GetMlirCommonFlags();

// Returns the effective MLIR bridge rollout state based on the flags and the
// optional configuration.
ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState(
    absl::optional<const ConfigProto> config_proto);
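
// A sketch of how a caller might consult the rollout state (assuming the
// enum value names from the MlirBridgeRollout enum in config.pb.h, and a
// hypothetical `session_options` holding the session's ConfigProto):
//
//   auto state = GetMlirBridgeRolloutState(session_options.config);
//   if (state == ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) {
//     // ... run the MLIR bridge ...
//   }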

// Appends the flag definitions associated with MarkForCompilationPassFlags
// to `flag_list`.
//
// Has the side effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
void AppendMarkForCompilationPassFlags(
    std::vector<tensorflow::Flag>* flag_list);
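
// A minimal sketch of wiring these flags into a binary's own command line
// (Flags::Parse and Flags::Usage come from command_line_flags.h, included
// above; argc/argv are the usual main() parameters):
//
//   std::vector<tensorflow::Flag> flag_list;
//   AppendMarkForCompilationPassFlags(&flag_list);
//   // Parse consumes recognized flags from argv in place.
//   if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
//     LOG(ERROR) << tensorflow::Flags::Usage(argv[0], flag_list);
//   }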

// Disables XLA compilation: forces it to return an error message instead.
// Can be used by a server to ensure that JIT compilation is opt-in.
void DisableXlaCompilation();

// Returns `false` unless `DisableXlaCompilation` was called.
bool FailOnXlaCompilation();
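
// For instance, in the opt-in server scenario described above (a sketch):
//
//   DisableXlaCompilation();  // e.g. during server startup
//   ...
//   if (FailOnXlaCompilation()) {
//     // XLA compilation requests will now fail with an error instead.
//   }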

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_FLAGS_H_