• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_JIT_FLAGS_H_
17 #define TENSORFLOW_COMPILER_JIT_FLAGS_H_
18 
19 #include <vector>
20 
21 #include "tensorflow/core/platform/types.h"
22 #include "tensorflow/core/util/command_line_flags.h"
23 
24 namespace tensorflow {
25 
26 // Flags associated with the XLA bridge's mark_for_compilation_pass module.
27 struct MarkForCompilationPassFlags {
28   // Control compilation of operators into XLA computations on CPU and GPU
29   // devices.  0 = use ConfigProto setting; -1 = off; 1 = on for things very
30   // likely to be improved; 2 = on for everything.
31   //
32   // Experimental.
33   int32 tf_xla_auto_jit;
34 
35   // Minimum number of operators in an XLA compilation. Ignored for operators
36   // placed on an XLA device or operators explicitly marked for compilation.
37   int32 tf_xla_min_cluster_size;
38 
39   // Maximum number of operators in an XLA compilation.
40   int32 tf_xla_max_cluster_size;
41 
42   // Dump graphs during XLA compilation.
43   bool tf_xla_clustering_debug;
44 
45   // Enables global JIT compilation for CPU via SessionOptions.
46   bool tf_xla_cpu_global_jit;
47 
48   // "Compiler fuel" for clustering.  Only this many ops will be marked as
49   // eligible for clustering.
50   int64 tf_xla_clustering_fuel;
51 
52   // tf_xla_fusion_only is effective only when global_jit_level is set to ON*
53   // and overrides its behavior. If true, enable fusion of element-wise
54   // operations only using XLA.
55   bool tf_xla_fusion_only;
56 
57   // If tf_xla_disable_deadness_safety_checks_for_debugging is set to true then
58   // we do not do deadness related safety checks.  This is unsound in general,
59   // but can be used as a debugging aid.
60   bool tf_xla_disable_deadness_safety_checks_for_debugging;
61 };
62 
// Options consumed by the XLA bridge's xla_device module.
struct XlaDeviceFlags {
  // When true, the CPU device runs in "on-demand" mode: rather than being
  // auto-clustered, ops are compiled one at a time, just in time.
  // Turning this mode on through a legacy flag is a stopgap; once the feature
  // is battle-tested it is intended to become a session option instead.
  bool tf_xla_compile_on_demand;
};
71 
// Flags common to the _Xla* ops and their kernels.
struct XlaOpsCommonFlags {
  // If true, _XlaCompile always refuses to compile the cluster, which means the
  // XLA clusters always run in the TF executor.  Defaults to false.
  //
  // The in-class initializer enforces that documented default, so a
  // default-constructed instance (e.g. `new XlaOpsCommonFlags`) never holds an
  // indeterminate value.  The struct remains an aggregate (C++14 and later).
  bool tf_xla_always_defer_compilation = false;
};
78 
// Flags for the build_xla_ops pass.
struct BuildXlaOpsPassFlags {
  // Enables lazy compilation for TF/XLA (only when auto-clustering) if true.
  // Defaults to true.
  //
  // The in-class initializer enforces that documented default, so a
  // default-constructed instance (e.g. `new BuildXlaOpsPassFlags`) never holds
  // an indeterminate value.  The struct remains an aggregate (C++14 and later).
  bool tf_xla_enable_lazy_compilation = true;
};
85 
86 // Return a pointer to the DumpGraphFlags struct;
87 // repeated calls return the same pointer.
88 // This should be called only after Flags::Parse() has returned.
89 
90 // Getters for flags structs defined above.  The first call to any of these
91 // parses TF_XLA_FLAGS for all of them.  Those functions which return a pointer
92 // always return the same pointer.
93 MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();
94 const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags();
95 XlaDeviceFlags* GetXlaDeviceFlags();
96 const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
97 
98 // Appends the flag definitions associated with
99 // MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
100 //
101 // Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
102 void AppendMarkForCompilationPassFlags(
103     std::vector<tensorflow::Flag>* flag_list);
104 }  // namespace tensorflow
105 
106 #endif  // TENSORFLOW_COMPILER_JIT_FLAGS_H_
107