1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <limits>
#include <mutex>  // NOLINT
#include <vector>

#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/util/command_line_flags.h"
21
22 namespace tensorflow {
23 namespace {
24
// Flag value structs, allocated lazily by AllocateAndParseFlags() and
// intentionally never freed (they live for the lifetime of the process).
BuildXlaOpsPassFlags* build_ops_flags;
MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;

// Command-line descriptors for the flags above, registered for TF_XLA_FLAGS
// environment-variable parsing.
std::vector<Flag>* flag_list;
// Guards the one-time allocation and parsing in AllocateAndParseFlags().
std::once_flag flags_init;
32
AppendMarkForCompilationPassFlagsInternal(std::vector<Flag> * flag_list)33 void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
34 std::vector<Flag> new_flags = {
35 Flag("tf_xla_auto_jit", &mark_for_compilation_flags->tf_xla_auto_jit,
36 "Control compilation of operators into XLA computations on CPU and "
37 "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
38 "things very likely to be improved; 2 = on for everything. "
39 "Experimental."),
40 Flag("tf_xla_min_cluster_size",
41 &mark_for_compilation_flags->tf_xla_min_cluster_size,
42 "Minimum number of operators in an XLA compilation. Ignored for "
43 "operators placed on an XLA device or operators explicitly marked "
44 "for compilation."),
45 Flag("tf_xla_max_cluster_size",
46 &mark_for_compilation_flags->tf_xla_max_cluster_size,
47 "Maximum number of operators in an XLA compilation."),
48 Flag("tf_xla_clustering_debug",
49 &mark_for_compilation_flags->tf_xla_clustering_debug,
50 "Dump graphs during XLA compilation."),
51 Flag("tf_xla_cpu_global_jit",
52 &mark_for_compilation_flags->tf_xla_cpu_global_jit,
53 "Enables global JIT compilation for CPU via SessionOptions."),
54 Flag("tf_xla_clustering_fuel",
55 &mark_for_compilation_flags->tf_xla_clustering_fuel,
56 "Places an artificial limit on the number of ops marked as "
57 "eligible for clustering."),
58 Flag("tf_xla_fusion_only",
59 &mark_for_compilation_flags->tf_xla_fusion_only,
60 "enable fusion of element-wise operations only using XLA when "
61 "global_jit_level is ON*."),
62 Flag("tf_xla_disable_deadness_safety_checks_for_debugging",
63 &mark_for_compilation_flags
64 ->tf_xla_disable_deadness_safety_checks_for_debugging,
65 "Disable deadness related safety checks when clustering (this is "
66 "unsound).")};
67 flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end());
68 }
69
AllocateAndParseFlags()70 void AllocateAndParseFlags() {
71 build_ops_flags = new BuildXlaOpsPassFlags;
72 build_ops_flags->tf_xla_enable_lazy_compilation = true;
73
74 mark_for_compilation_flags = new MarkForCompilationPassFlags;
75 mark_for_compilation_flags->tf_xla_auto_jit = 0;
76 mark_for_compilation_flags->tf_xla_min_cluster_size = 4;
77 mark_for_compilation_flags->tf_xla_max_cluster_size =
78 std::numeric_limits<int32>::max();
79 mark_for_compilation_flags->tf_xla_clustering_debug = false;
80 mark_for_compilation_flags->tf_xla_cpu_global_jit = false;
81 mark_for_compilation_flags->tf_xla_clustering_fuel =
82 std::numeric_limits<int64>::max();
83 mark_for_compilation_flags->tf_xla_fusion_only = false;
84 mark_for_compilation_flags
85 ->tf_xla_disable_deadness_safety_checks_for_debugging = false;
86
87 device_flags = new XlaDeviceFlags;
88 device_flags->tf_xla_compile_on_demand = false;
89
90 ops_flags = new XlaOpsCommonFlags;
91 ops_flags->tf_xla_always_defer_compilation = false;
92
93 flag_list = new std::vector<Flag>({
94 Flag("tf_xla_enable_lazy_compilation",
95 &build_ops_flags->tf_xla_enable_lazy_compilation, ""),
96
97 Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
98 "Switch a device into 'on-demand' mode, where instead of "
99 "autoclustering ops are compiled one by one just-in-time."),
100
101 Flag("tf_xla_always_defer_compilation",
102 &ops_flags->tf_xla_always_defer_compilation, ""),
103 });
104 AppendMarkForCompilationPassFlagsInternal(flag_list);
105 xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
106 }
107
108 } // namespace
109
GetBuildXlaOpsPassFlags()110 const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() {
111 std::call_once(flags_init, &AllocateAndParseFlags);
112 return *build_ops_flags;
113 }
114
GetMarkForCompilationPassFlags()115 MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
116 std::call_once(flags_init, &AllocateAndParseFlags);
117 return mark_for_compilation_flags;
118 }
119
GetXlaDeviceFlags()120 XlaDeviceFlags* GetXlaDeviceFlags() {
121 std::call_once(flags_init, &AllocateAndParseFlags);
122 return device_flags;
123 }
124
GetXlaOpsCommonFlags()125 const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
126 std::call_once(flags_init, &AllocateAndParseFlags);
127 return *ops_flags;
128 }
129
AppendMarkForCompilationPassFlags(std::vector<Flag> * flag_list)130 void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
131 std::call_once(flags_init, &AllocateAndParseFlags);
132 AppendMarkForCompilationPassFlagsInternal(flag_list);
133 }
134
135 } // namespace tensorflow
136