1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
16
17 #include <string>
18 #include <unordered_map>
19
20 #include "tensorflow/core/platform/logging.h"
21
22 namespace tensorflow {
23 namespace grappler {
24 namespace {
25
26 typedef std::unordered_map<string, CustomGraphOptimizerRegistry::Creator>
27 RegistrationMap;
28 RegistrationMap* registered_optimizers = nullptr;
GetRegistrationMap()29 RegistrationMap* GetRegistrationMap() {
30 if (registered_optimizers == nullptr)
31 registered_optimizers = new RegistrationMap;
32 return registered_optimizers;
33 }
34
35 // This map is a global map for registered plugin optimizers. It contains the
36 // device_type as its key, and an optimizer creator as the value.
37 typedef std::unordered_map<string, PluginGraphOptimizerRegistry::Creator>
38 PluginRegistrationMap;
GetPluginRegistrationMap()39 PluginRegistrationMap* GetPluginRegistrationMap() {
40 static PluginRegistrationMap* registered_plugin_optimizers =
41 new PluginRegistrationMap;
42 return registered_plugin_optimizers;
43 }
44
45 // This map is a global map for registered plugin configs. It contains the
46 // device_type as its key, and ConfigList as the value.
47 typedef std::unordered_map<string, ConfigList> PluginConfigMap;
GetPluginConfigMap()48 PluginConfigMap* GetPluginConfigMap() {
49 static PluginConfigMap* plugin_config_map = new PluginConfigMap;
50 return plugin_config_map;
51 }
52
53 // Returns plugin's default configuration for each Grappler optimizer (on/off).
54 // See tensorflow/core/protobuf/rewriter_config.proto for more details about
55 // each optimizer.
DefaultPluginConfigs()56 const ConfigList& DefaultPluginConfigs() {
57 static ConfigList* default_plugin_configs =
58 new ConfigList(/*disable_model_pruning=*/false,
59 {{"implementation_selector", RewriterConfig::ON},
60 {"function_optimization", RewriterConfig::ON},
61 {"common_subgraph_elimination", RewriterConfig::ON},
62 {"arithmetic_optimization", RewriterConfig::ON},
63 {"debug_stripper", RewriterConfig::ON},
64 {"constant_folding", RewriterConfig::ON},
65 {"shape_optimization", RewriterConfig::ON},
66 {"auto_mixed_precision", RewriterConfig::ON},
67 {"auto_mixed_precision_mkl", RewriterConfig::ON},
68 {"pin_to_host_optimization", RewriterConfig::ON},
69 {"layout_optimizer", RewriterConfig::ON},
70 {"remapping", RewriterConfig::ON},
71 {"loop_optimization", RewriterConfig::ON},
72 {"dependency_optimization", RewriterConfig::ON},
73 {"auto_parallel", RewriterConfig::ON},
74 {"memory_optimization", RewriterConfig::ON},
75 {"scoped_allocator_optimization", RewriterConfig::ON}});
76 return *default_plugin_configs;
77 }
78
79 } // namespace
80
81 std::unique_ptr<CustomGraphOptimizer>
CreateByNameOrNull(const string & name)82 CustomGraphOptimizerRegistry::CreateByNameOrNull(const string& name) {
83 const auto it = GetRegistrationMap()->find(name);
84 if (it == GetRegistrationMap()->end()) return nullptr;
85 return std::unique_ptr<CustomGraphOptimizer>(it->second());
86 }
87
GetRegisteredOptimizers()88 std::vector<string> CustomGraphOptimizerRegistry::GetRegisteredOptimizers() {
89 std::vector<string> optimizer_names;
90 optimizer_names.reserve(GetRegistrationMap()->size());
91 for (const auto& opt : *GetRegistrationMap())
92 optimizer_names.emplace_back(opt.first);
93 return optimizer_names;
94 }
95
RegisterOptimizerOrDie(const Creator & optimizer_creator,const string & name)96 void CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(
97 const Creator& optimizer_creator, const string& name) {
98 const auto it = GetRegistrationMap()->find(name);
99 if (it != GetRegistrationMap()->end()) {
100 LOG(FATAL) << "CustomGraphOptimizer is registered twice: " << name;
101 }
102 GetRegistrationMap()->insert({name, optimizer_creator});
103 }
104
105 std::vector<std::unique_ptr<CustomGraphOptimizer>>
CreateOptimizers(const std::set<string> & device_types)106 PluginGraphOptimizerRegistry::CreateOptimizers(
107 const std::set<string>& device_types) {
108 std::vector<std::unique_ptr<CustomGraphOptimizer>> optimizer_list;
109 for (auto it = GetPluginRegistrationMap()->begin();
110 it != GetPluginRegistrationMap()->end(); ++it) {
111 if (device_types.find(it->first) == device_types.end()) continue;
112 LOG(INFO) << "Plugin optimizer for device_type " << it->first
113 << " is enabled.";
114 optimizer_list.emplace_back(
115 std::unique_ptr<CustomGraphOptimizer>(it->second()));
116 }
117 return optimizer_list;
118 }
119
RegisterPluginOptimizerOrDie(const Creator & optimizer_creator,const std::string & device_type,ConfigList & configs)120 void PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(
121 const Creator& optimizer_creator, const std::string& device_type,
122 ConfigList& configs) {
123 auto ret = GetPluginConfigMap()->insert({device_type, configs});
124 if (!ret.second) {
125 LOG(FATAL) << "PluginGraphOptimizer with device_type " // Crash OK
126 << device_type << " is registered twice.";
127 }
128 GetPluginRegistrationMap()->insert({device_type, optimizer_creator});
129 }
130
PrintPluginConfigsIfConflict(const std::set<string> & device_types)131 void PluginGraphOptimizerRegistry::PrintPluginConfigsIfConflict(
132 const std::set<string>& device_types) {
133 bool init = false, conflict = false;
134 ConfigList plugin_configs;
135 // Check if plugin's configs have conflict.
136 for (const auto& device_type : device_types) {
137 const auto it = GetPluginConfigMap()->find(device_type);
138 if (it == GetPluginConfigMap()->end()) continue;
139 auto cur_plugin_configs = it->second;
140
141 if (!init) {
142 plugin_configs = cur_plugin_configs;
143 init = true;
144 } else {
145 if (!(plugin_configs == cur_plugin_configs)) {
146 conflict = true;
147 break;
148 }
149 }
150 }
151 if (!conflict) return;
152 LOG(WARNING) << "Plugins have conflicting configs. Potential performance "
153 "regression may happen.";
154 for (const auto& device_type : device_types) {
155 const auto it = GetPluginConfigMap()->find(device_type);
156 if (it == GetPluginConfigMap()->end()) continue;
157 auto cur_plugin_configs = it->second;
158
159 // Print logs in following style:
160 // disable_model_pruning 0
161 // remapping 1
162 // ...
163 string logs = "";
164 strings::StrAppend(&logs, "disable_model_pruning\t\t",
165 cur_plugin_configs.disable_model_pruning, "\n");
166 for (auto const& pair : cur_plugin_configs.toggle_config) {
167 strings::StrAppend(&logs, pair.first, string(32 - pair.first.size(), ' '),
168 (pair.second != RewriterConfig::OFF), "\n");
169 }
170 LOG(WARNING) << "Plugin's configs for device_type " << device_type << ":\n"
171 << logs;
172 }
173 }
174
GetPluginConfigs(bool use_plugin_optimizers,const std::set<string> & device_types)175 ConfigList PluginGraphOptimizerRegistry::GetPluginConfigs(
176 bool use_plugin_optimizers, const std::set<string>& device_types) {
177 if (!use_plugin_optimizers) return DefaultPluginConfigs();
178
179 ConfigList ret_plugin_configs = DefaultPluginConfigs();
180 for (const auto& device_type : device_types) {
181 const auto it = GetPluginConfigMap()->find(device_type);
182 if (it == GetPluginConfigMap()->end()) continue;
183 auto cur_plugin_configs = it->second;
184 // If any of the plugin turns on `disable_model_pruning`,
185 // then `disable_model_pruning` should be true;
186 if (cur_plugin_configs.disable_model_pruning == true)
187 ret_plugin_configs.disable_model_pruning = true;
188
189 // If any of the plugin turns off a certain optimizer,
190 // then the optimizer should be turned off;
191 for (auto& pair : cur_plugin_configs.toggle_config) {
192 if (cur_plugin_configs.toggle_config[pair.first] == RewriterConfig::OFF)
193 ret_plugin_configs.toggle_config[pair.first] = RewriterConfig::OFF;
194 }
195 }
196
197 return ret_plugin_configs;
198 }
199
IsConfigsConflict(ConfigList & user_config,ConfigList & plugin_config)200 bool PluginGraphOptimizerRegistry::IsConfigsConflict(
201 ConfigList& user_config, ConfigList& plugin_config) {
202 if (plugin_config == DefaultPluginConfigs()) return false;
203 if (user_config.disable_model_pruning != plugin_config.disable_model_pruning)
204 return true;
205 // Returns true if user_config is turned on but plugin_config is turned off.
206 for (auto& pair : user_config.toggle_config) {
207 if ((user_config.toggle_config[pair.first] == RewriterConfig::ON) &&
208 (plugin_config.toggle_config[pair.first] == RewriterConfig::OFF))
209 return true;
210 }
211 return false;
212 }
213
214 } // end namespace grappler
215 } // end namespace tensorflow
216