• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
16 
17 #include <string>
18 #include <unordered_map>
19 
20 #include "tensorflow/core/platform/logging.h"
21 
22 namespace tensorflow {
23 namespace grappler {
24 namespace {
25 
26 typedef std::unordered_map<string, CustomGraphOptimizerRegistry::Creator>
27     RegistrationMap;
28 RegistrationMap* registered_optimizers = nullptr;
GetRegistrationMap()29 RegistrationMap* GetRegistrationMap() {
30   if (registered_optimizers == nullptr)
31     registered_optimizers = new RegistrationMap;
32   return registered_optimizers;
33 }
34 
35 // This map is a global map for registered plugin optimizers. It contains the
36 // device_type as its key, and an optimizer creator as the value.
37 typedef std::unordered_map<string, PluginGraphOptimizerRegistry::Creator>
38     PluginRegistrationMap;
GetPluginRegistrationMap()39 PluginRegistrationMap* GetPluginRegistrationMap() {
40   static PluginRegistrationMap* registered_plugin_optimizers =
41       new PluginRegistrationMap;
42   return registered_plugin_optimizers;
43 }
44 
45 // This map is a global map for registered plugin configs. It contains the
46 // device_type as its key, and ConfigList as the value.
47 typedef std::unordered_map<string, ConfigList> PluginConfigMap;
GetPluginConfigMap()48 PluginConfigMap* GetPluginConfigMap() {
49   static PluginConfigMap* plugin_config_map = new PluginConfigMap;
50   return plugin_config_map;
51 }
52 
53 // Returns plugin's default configuration for each Grappler optimizer (on/off).
54 // See tensorflow/core/protobuf/rewriter_config.proto for more details about
55 // each optimizer.
DefaultPluginConfigs()56 const ConfigList& DefaultPluginConfigs() {
57   static ConfigList* default_plugin_configs =
58       new ConfigList(/*disable_model_pruning=*/false,
59                      {{"implementation_selector", RewriterConfig::ON},
60                       {"function_optimization", RewriterConfig::ON},
61                       {"common_subgraph_elimination", RewriterConfig::ON},
62                       {"arithmetic_optimization", RewriterConfig::ON},
63                       {"debug_stripper", RewriterConfig::ON},
64                       {"constant_folding", RewriterConfig::ON},
65                       {"shape_optimization", RewriterConfig::ON},
66                       {"auto_mixed_precision", RewriterConfig::ON},
67                       {"auto_mixed_precision_mkl", RewriterConfig::ON},
68                       {"pin_to_host_optimization", RewriterConfig::ON},
69                       {"layout_optimizer", RewriterConfig::ON},
70                       {"remapping", RewriterConfig::ON},
71                       {"loop_optimization", RewriterConfig::ON},
72                       {"dependency_optimization", RewriterConfig::ON},
73                       {"auto_parallel", RewriterConfig::ON},
74                       {"memory_optimization", RewriterConfig::ON},
75                       {"scoped_allocator_optimization", RewriterConfig::ON}});
76   return *default_plugin_configs;
77 }
78 
79 }  // namespace
80 
81 std::unique_ptr<CustomGraphOptimizer>
CreateByNameOrNull(const string & name)82 CustomGraphOptimizerRegistry::CreateByNameOrNull(const string& name) {
83   const auto it = GetRegistrationMap()->find(name);
84   if (it == GetRegistrationMap()->end()) return nullptr;
85   return std::unique_ptr<CustomGraphOptimizer>(it->second());
86 }
87 
GetRegisteredOptimizers()88 std::vector<string> CustomGraphOptimizerRegistry::GetRegisteredOptimizers() {
89   std::vector<string> optimizer_names;
90   optimizer_names.reserve(GetRegistrationMap()->size());
91   for (const auto& opt : *GetRegistrationMap())
92     optimizer_names.emplace_back(opt.first);
93   return optimizer_names;
94 }
95 
RegisterOptimizerOrDie(const Creator & optimizer_creator,const string & name)96 void CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(
97     const Creator& optimizer_creator, const string& name) {
98   const auto it = GetRegistrationMap()->find(name);
99   if (it != GetRegistrationMap()->end()) {
100     LOG(FATAL) << "CustomGraphOptimizer is registered twice: " << name;
101   }
102   GetRegistrationMap()->insert({name, optimizer_creator});
103 }
104 
105 std::vector<std::unique_ptr<CustomGraphOptimizer>>
CreateOptimizers(const std::set<string> & device_types)106 PluginGraphOptimizerRegistry::CreateOptimizers(
107     const std::set<string>& device_types) {
108   std::vector<std::unique_ptr<CustomGraphOptimizer>> optimizer_list;
109   for (auto it = GetPluginRegistrationMap()->begin();
110        it != GetPluginRegistrationMap()->end(); ++it) {
111     if (device_types.find(it->first) == device_types.end()) continue;
112     LOG(INFO) << "Plugin optimizer for device_type " << it->first
113               << " is enabled.";
114     optimizer_list.emplace_back(
115         std::unique_ptr<CustomGraphOptimizer>(it->second()));
116   }
117   return optimizer_list;
118 }
119 
RegisterPluginOptimizerOrDie(const Creator & optimizer_creator,const std::string & device_type,ConfigList & configs)120 void PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie(
121     const Creator& optimizer_creator, const std::string& device_type,
122     ConfigList& configs) {
123   auto ret = GetPluginConfigMap()->insert({device_type, configs});
124   if (!ret.second) {
125     LOG(FATAL) << "PluginGraphOptimizer with device_type "  // Crash OK
126                << device_type << " is registered twice.";
127   }
128   GetPluginRegistrationMap()->insert({device_type, optimizer_creator});
129 }
130 
PrintPluginConfigsIfConflict(const std::set<string> & device_types)131 void PluginGraphOptimizerRegistry::PrintPluginConfigsIfConflict(
132     const std::set<string>& device_types) {
133   bool init = false, conflict = false;
134   ConfigList plugin_configs;
135   // Check if plugin's configs have conflict.
136   for (const auto& device_type : device_types) {
137     const auto it = GetPluginConfigMap()->find(device_type);
138     if (it == GetPluginConfigMap()->end()) continue;
139     auto cur_plugin_configs = it->second;
140 
141     if (!init) {
142       plugin_configs = cur_plugin_configs;
143       init = true;
144     } else {
145       if (!(plugin_configs == cur_plugin_configs)) {
146         conflict = true;
147         break;
148       }
149     }
150   }
151   if (!conflict) return;
152   LOG(WARNING) << "Plugins have conflicting configs. Potential performance "
153                   "regression may happen.";
154   for (const auto& device_type : device_types) {
155     const auto it = GetPluginConfigMap()->find(device_type);
156     if (it == GetPluginConfigMap()->end()) continue;
157     auto cur_plugin_configs = it->second;
158 
159     // Print logs in following style:
160     // disable_model_pruning    0
161     // remapping                1
162     // ...
163     string logs = "";
164     strings::StrAppend(&logs, "disable_model_pruning\t\t",
165                        cur_plugin_configs.disable_model_pruning, "\n");
166     for (auto const& pair : cur_plugin_configs.toggle_config) {
167       strings::StrAppend(&logs, pair.first, string(32 - pair.first.size(), ' '),
168                          (pair.second != RewriterConfig::OFF), "\n");
169     }
170     LOG(WARNING) << "Plugin's configs for device_type " << device_type << ":\n"
171                  << logs;
172   }
173 }
174 
GetPluginConfigs(bool use_plugin_optimizers,const std::set<string> & device_types)175 ConfigList PluginGraphOptimizerRegistry::GetPluginConfigs(
176     bool use_plugin_optimizers, const std::set<string>& device_types) {
177   if (!use_plugin_optimizers) return DefaultPluginConfigs();
178 
179   ConfigList ret_plugin_configs = DefaultPluginConfigs();
180   for (const auto& device_type : device_types) {
181     const auto it = GetPluginConfigMap()->find(device_type);
182     if (it == GetPluginConfigMap()->end()) continue;
183     auto cur_plugin_configs = it->second;
184     // If any of the plugin turns on `disable_model_pruning`,
185     // then `disable_model_pruning` should be true;
186     if (cur_plugin_configs.disable_model_pruning == true)
187       ret_plugin_configs.disable_model_pruning = true;
188 
189     // If any of the plugin turns off a certain optimizer,
190     // then the optimizer should be turned off;
191     for (auto& pair : cur_plugin_configs.toggle_config) {
192       if (cur_plugin_configs.toggle_config[pair.first] == RewriterConfig::OFF)
193         ret_plugin_configs.toggle_config[pair.first] = RewriterConfig::OFF;
194     }
195   }
196 
197   return ret_plugin_configs;
198 }
199 
IsConfigsConflict(ConfigList & user_config,ConfigList & plugin_config)200 bool PluginGraphOptimizerRegistry::IsConfigsConflict(
201     ConfigList& user_config, ConfigList& plugin_config) {
202   if (plugin_config == DefaultPluginConfigs()) return false;
203   if (user_config.disable_model_pruning != plugin_config.disable_model_pruning)
204     return true;
205   // Returns true if user_config is turned on but plugin_config is turned off.
206   for (auto& pair : user_config.toggle_config) {
207     if ((user_config.toggle_config[pair.first] == RewriterConfig::ON) &&
208         (plugin_config.toggle_config[pair.first] == RewriterConfig::OFF))
209       return true;
210   }
211   return false;
212 }
213 
214 }  // end namespace grappler
215 }  // end namespace tensorflow
216