/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <limits>
#include <mutex>  // NOLINT
#include <vector>

#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/util/command_line_flags.h"

namespace tensorflow {
namespace {

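// Flag state is allocated lazily and parsed from the TF_XLA_FLAGS environment
// variable exactly once, guarded by flags_init.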
BuildXlaOpsPassFlags* build_ops_flags;
MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;

std::vector<Flag>* flag_list;
std::once_flag flags_init;

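// Appends the flags controlling the MarkForCompilation pass to *flag_list.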
void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
  std::vector<Flag> new_flags = {
      Flag("tf_xla_auto_jit", &mark_for_compilation_flags->tf_xla_auto_jit,
           "Control compilation of operators into XLA computations on CPU and "
           "GPU devices.  0 = use ConfigProto setting; -1 = off; 1 = on for "
           "things very likely to be improved; 2 = on for everything.  "
           "Experimental."),
      Flag("tf_xla_min_cluster_size",
           &mark_for_compilation_flags->tf_xla_min_cluster_size,
           "Minimum number of operators in an XLA compilation. Ignored for "
           "operators placed on an XLA device or operators explicitly marked "
           "for compilation."),
      Flag("tf_xla_max_cluster_size",
           &mark_for_compilation_flags->tf_xla_max_cluster_size,
           "Maximum number of operators in an XLA compilation."),
      Flag("tf_xla_clustering_debug",
           &mark_for_compilation_flags->tf_xla_clustering_debug,
           "Dump graphs during XLA compilation."),
      Flag("tf_xla_cpu_global_jit",
           &mark_for_compilation_flags->tf_xla_cpu_global_jit,
           "Enables global JIT compilation for CPU via SessionOptions."),
      Flag("tf_xla_clustering_fuel",
           &mark_for_compilation_flags->tf_xla_clustering_fuel,
           "Places an artificial limit on the number of ops marked as "
           "eligible for clustering."),
      Flag("tf_xla_fusion_only",
           &mark_for_compilation_flags->tf_xla_fusion_only,
           "Enable fusion of element-wise operations only using XLA when "
           "global_jit_level is ON*."),
      Flag("tf_xla_disable_deadness_safety_checks_for_debugging",
           &mark_for_compilation_flags
                ->tf_xla_disable_deadness_safety_checks_for_debugging,
           "Disable deadness-related safety checks when clustering (this is "
           "unsound).")};
  flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end());
}

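// Allocates all flag structs with their default values, registers the
// corresponding command-line flags, and applies any overrides found in the
// TF_XLA_FLAGS environment variable.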
void AllocateAndParseFlags() {
  build_ops_flags = new BuildXlaOpsPassFlags;
  build_ops_flags->tf_xla_enable_lazy_compilation = true;

  mark_for_compilation_flags = new MarkForCompilationPassFlags;
  mark_for_compilation_flags->tf_xla_auto_jit = 0;
  mark_for_compilation_flags->tf_xla_min_cluster_size = 4;
  mark_for_compilation_flags->tf_xla_max_cluster_size =
      std::numeric_limits<int32>::max();
  mark_for_compilation_flags->tf_xla_clustering_debug = false;
  mark_for_compilation_flags->tf_xla_cpu_global_jit = false;
  mark_for_compilation_flags->tf_xla_clustering_fuel =
      std::numeric_limits<int64>::max();
  mark_for_compilation_flags->tf_xla_fusion_only = false;
  mark_for_compilation_flags
      ->tf_xla_disable_deadness_safety_checks_for_debugging = false;

  device_flags = new XlaDeviceFlags;
  device_flags->tf_xla_compile_on_demand = false;

  ops_flags = new XlaOpsCommonFlags;
  ops_flags->tf_xla_always_defer_compilation = false;

  flag_list = new std::vector<Flag>({
      Flag("tf_xla_enable_lazy_compilation",
           &build_ops_flags->tf_xla_enable_lazy_compilation, ""),

      Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
           "Switch a device into 'on-demand' mode, where, instead of "
           "autoclustering, ops are compiled one by one just-in-time."),

      Flag("tf_xla_always_defer_compilation",
           &ops_flags->tf_xla_always_defer_compilation, ""),
  });
  AppendMarkForCompilationPassFlagsInternal(flag_list);
  xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
}

}  // namespace

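// Public accessors.  Each one triggers flag allocation and TF_XLA_FLAGS
// parsing exactly once before returning the requested flag struct.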
const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  return *build_ops_flags;
}

MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  return mark_for_compilation_flags;
}

XlaDeviceFlags* GetXlaDeviceFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  return device_flags;
}

const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
  std::call_once(flags_init, &AllocateAndParseFlags);
  return *ops_flags;
}

void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
  std::call_once(flags_init, &AllocateAndParseFlags);
  AppendMarkForCompilationPassFlagsInternal(flag_list);
}

}  // namespace tensorflow