/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
#include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer.h"
#include "tensorflow/core/grappler/optimizers/implementation_selector.h"
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
#include "tensorflow/core/grappler/utils/canonicalizer.h"
#include "tensorflow/core/grappler/utils/colocation.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/utils/tpu.h"
#include "tensorflow/core/grappler/verifiers/structure_verifier.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/util.h"
#include "tensorflow/core/util/xla_config_registry.h"

namespace tensorflow {
namespace grappler {

namespace {

constexpr int kDefaultNumberOfIterations = 2;
constexpr int kDefaultMinGraphNodes = 4;

int64 NumEdges(const GraphDef& graph) {
  int64_t num_edges = 0;
  for (const auto& node : graph.node()) {
    num_edges += node.input_size();
  }
  return num_edges;
}

string PrintSizesBeforeAfter(const GraphDef& before, const GraphDef& after) {
  return strings::StrCat("Graph size after: ", after.node_size(), " nodes (",
                         after.node_size() - before.node_size(), "), ",
                         NumEdges(after), " edges (",
                         NumEdges(after) - NumEdges(before), ")");
}

int NumIterations(const RewriterConfig& cfg) {
  return cfg.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS
             ? kDefaultNumberOfIterations
             : cfg.meta_optimizer_iterations();
}

// Check if the optimizer is allowed to run only once.
bool IsRunOnceOptimizer(const string& name) {
  return name == "layout" || name == "memory_optimizer" ||
         name == "loop_optimizer" || name == "auto_mixed_precision" ||
         name == "auto_mixed_precision_mkl";
}

// Creates a function library stub from a real function library: copies only
// the signatures and attributes of all the functions defined in fdef_lib.
// This stub can be swapped with the real function library in a graph before
// passing it to an optimizer, if the optimizer doesn't instantiate functions.
FunctionDefLibrary GetFunctionDefLibraryStub(
    const FunctionDefLibrary& fdef_lib) {
  FunctionDefLibrary stub;
  for (const FunctionDef& fn : fdef_lib.function()) {
    FunctionDef* fn_stub = stub.mutable_function()->Add();
    *(fn_stub->mutable_signature()) = fn.signature();
    *(fn_stub->mutable_attr()) = fn.attr();
    *(fn_stub->mutable_arg_attr()) = fn.arg_attr();
    *(fn_stub->mutable_resource_arg_unique_id()) = fn.resource_arg_unique_id();
  }
  *stub.mutable_gradient() = fdef_lib.gradient();
  return stub;
}
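
// Illustrative sketch (not part of this file): the stub round-trip as seen by
// a caller, mirroring the swap done in MetaOptimizer::RunOptimizer() below.
// `graph` is a hypothetical GraphDef handled by a function-agnostic optimizer.
//
//   FunctionDefLibrary real_library;
//   real_library.Swap(graph.mutable_library());        // stash the real lib
//   *graph.mutable_library() = GetFunctionDefLibraryStub(real_library);
//   // ... run an optimizer that doesn't instantiate functions ...
//   graph.mutable_library()->Swap(&real_library);      // restore the real lib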

uint64 DeadlineMicroSeconds(const RewriterConfig& cfg) {
  if (cfg.meta_optimizer_timeout_ms() <= 0) return 0;  // no deadline
  return Env::Default()->NowMicros() + cfg.meta_optimizer_timeout_ms() * 1000;
}
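
// Worked example (illustrative): with meta_optimizer_timeout_ms = 2500, this
// returns Env::Default()->NowMicros() + 2'500'000, i.e. an absolute wall-clock
// deadline in microseconds; any timeout <= 0 yields 0 ("no deadline").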

// A helper function to decide whether to enable the automatic mixed precision
// optimizer.
bool AutoMixedPrecisionEnabled(RewriterConfig::Toggle opt_level) {
  if (opt_level == RewriterConfig::ON ||
      opt_level == RewriterConfig::AGGRESSIVE) {
    return true;
  }
  return false;
}

bool IsXlaGlobalJitOn(
    const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
  xla_config_registry::XlaGlobalJitLevel xla_global_jit_level =
      xla_config_registry::GetGlobalJitLevel(jit_level_in_session_opts);
  // Return true only if XLA JIT is ON for both single-gpu and multi-gpu
  // graphs. This is a conservative approach that turns off the memory
  // optimizer only when we are sure that all graphs will be processed by
  // XLA JIT.
  bool is_on = (xla_global_jit_level.single_gpu == OptimizerOptions::ON_1 ||
                xla_global_jit_level.single_gpu == OptimizerOptions::ON_2) &&
               (xla_global_jit_level.general == OptimizerOptions::ON_1 ||
                xla_global_jit_level.general == OptimizerOptions::ON_2);
  return is_on;
}

// A helper function to decide whether to enable the memory optimizer.
bool MemoryOptimizerEnabled(RewriterConfig::MemOptType mem_opt_type,
                            bool xla_auto_clustering_on) {
  // Disable the default memory optimizer when XLA JIT is ON, as it hurts XLA
  // JIT performance. The (current) XLA clustering can result in loss of
  // concurrency between kernel compute and memory copies. As such, it usually
  // loses the concurrency needed to hide the latencies of the inserted swap-ins
  // and swap-outs, and incurs a large performance overhead. Remove this check
  // when the XLA JIT can better deal with the concurrency.
  if (mem_opt_type == RewriterConfig::DEFAULT_MEM_OPT &&
      xla_auto_clustering_on) {
    return false;
  }

  return mem_opt_type != RewriterConfig::NO_MEM_OPT;
}
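
// Decision table implied by MemoryOptimizerEnabled() above (illustrative):
//
//   mem_opt_type      xla_auto_clustering_on   enabled?
//   DEFAULT_MEM_OPT   true                     no
//   DEFAULT_MEM_OPT   false                    yes
//   NO_MEM_OPT        true/false               no
//   any other value   true/false               yes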

Status GetGraphDevice(const GraphDef& g_def, std::set<std::string>* devices) {
  for (auto& node : g_def.node()) {
    DeviceNameUtils::ParsedName parsed_name;
    if (!DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
      return errors::InvalidArgument("Unable to parse ", node.device(),
                                     " as a device name");
    }
    devices->insert(parsed_name.type);
  }
  return Status::OK();
}

}  // namespace

#define MK_OPT(NAME, CONFIG, VALUE)                                    \
  if (optimizer == NAME) {                                             \
    if (plugin_configs.toggle_config[CONFIG] != RewriterConfig::OFF) { \
      return std::unique_ptr<GraphOptimizer>(VALUE);                   \
    }                                                                  \
  }
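
// Illustrative expansion (assumes `optimizer` and `plugin_configs` are in
// scope, as in MakeNewOptimizer() below). For example,
// MK_OPT("shape", "shape_optimization", new ShapeOptimizer()) expands to
// roughly:
//
//   if (optimizer == "shape") {
//     if (plugin_configs.toggle_config["shape_optimization"] !=
//         RewriterConfig::OFF) {
//       return std::unique_ptr<GraphOptimizer>(new ShapeOptimizer());
//     }
//   }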

bool MetaOptimizer::LowerControlFlow() const {
  if (config_proto_.experimental().executor_type() ==
      "SINGLE_THREADED_EXECUTOR")
    return false;

  if (config_proto_.experimental().use_tfrt()) return false;

  return true;
}

std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
    const string& optimizer, const std::set<string>& device_types) const {
  ConfigList plugin_configs = PluginGraphOptimizerRegistry::GetPluginConfigs(
      cfg_.use_plugin_optimizers() != RewriterConfig::OFF, device_types);
  if (optimizer == "pruning" && !plugin_configs.disable_model_pruning)
    return std::unique_ptr<GraphOptimizer>(new ModelPruner());
  MK_OPT("function", "function_optimization",
         new FunctionOptimizer(cfg_.function_optimization(),
                               /*lower_control_flow=*/LowerControlFlow()));
  MK_OPT("constfold", "constant_folding",
         new ConstantFolding(
             cpu_device_,
             cfg_.experimental_disable_compressed_tensor_optimization(),
             !cfg_.experimental_disable_folding_quantization_emulation()));
  MK_OPT("shape", "shape_optimization", new ShapeOptimizer());
  MK_OPT("remap", "remapping",
         new Remapper(cfg_.remapping(), xla_auto_clustering_on_));
  MK_OPT("layout", "layout_optimizer",
         new GenericLayoutOptimizer(
             /*optimization level*/ cfg_.layout_optimizer(),
             /*CPU layout conversion*/ cfg_.cpu_layout_conversion()));
  MK_OPT("auto_mixed_precision", "auto_mixed_precision",
         new AutoMixedPrecision(AutoMixedPrecisionMode::CUDA));
#ifdef INTEL_MKL
  if (IsMKLEnabled()) {
    MK_OPT("auto_mixed_precision_mkl", "auto_mixed_precision_mkl",
           new AutoMixedPrecision(AutoMixedPrecisionMode::MKL));
  }
#endif
  MK_OPT("memory", "memory_optimization",
         new MemoryOptimizer(RewriterConfig::MANUAL));
  MK_OPT("common_subgraph_elimination", "common_subgraph_elimination",
         new CommonSubgraphElimination(cfg_.common_subgraph_elimination()));
  MK_OPT("arithmetic", "arithmetic_optimization",
         new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
  MK_OPT("autoparallel", "auto_parallel",
         new AutoParallel(cfg_.auto_parallel().num_replicas()));
  MK_OPT("loop", "loop_optimization",
         new LoopOptimizer(cfg_.loop_optimization(), cpu_device_));
  MK_OPT("dependency", "dependency_optimization",
         new DependencyOptimizer(cfg_.dependency_optimization()));
  MK_OPT("debug_stripper", "debug_stripper", new DebugStripper());
  MK_OPT("scoped_allocator", "scoped_allocator_optimization",
         new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
                                      cfg_.scoped_allocator_opts()));
  MK_OPT("pin_to_host", "pin_to_host_optimization",
         new PinToHostOptimizer(cfg_.pin_to_host_optimization()));

  return std::unique_ptr<GraphOptimizer>();
}

#undef MK_OPT

MetaOptimizer::MetaOptimizer(DeviceBase* cpu_device, const ConfigProto& cfg)
    : cpu_device_(cpu_device),
      config_proto_(cfg),
      cfg_(*config_proto_.mutable_graph_options()->mutable_rewrite_options()) {
  DCHECK(cpu_device_ == nullptr ||
         cpu_device_->attributes().device_type() == "CPU");
  auto global_jit_level =
      cfg.graph_options().optimizer_options().global_jit_level();
  xla_auto_clustering_on_ = IsXlaGlobalJitOn(global_jit_level);
}

Status MetaOptimizer::InitializeOptimizers(
    const std::set<string>& device_types,
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  if (cfg_.disable_meta_optimizer()) {
    return Status::OK();
  }

  ConfigList plugin_configs = PluginGraphOptimizerRegistry::GetPluginConfigs(
      cfg_.use_plugin_optimizers() != RewriterConfig::OFF, device_types);
  if (!cfg_.disable_model_pruning() && !plugin_configs.disable_model_pruning) {
    optimizers->push_back(MakeUnique<ModelPruner>());
  }

#define USER_IS_ON(CFG) cfg_.CFG() == RewriterConfig::ON
#define USER_NOT_OFF(CFG) cfg_.CFG() != RewriterConfig::OFF
#define PLUGIN_IS_ON(CFG) \
  plugin_configs.toggle_config[#CFG] == RewriterConfig::ON
#define PLUGIN_NOT_OFF(CFG) \
  plugin_configs.toggle_config[#CFG] != RewriterConfig::OFF
#define BOTH_ARE_ON(CFG) USER_IS_ON(CFG) && PLUGIN_IS_ON(CFG)
#define BOTH_NOT_OFF(CFG) USER_NOT_OFF(CFG) && PLUGIN_NOT_OFF(CFG)
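  // Illustrative expansion: BOTH_NOT_OFF(remapping) becomes
  //   cfg_.remapping() != RewriterConfig::OFF &&
  //   plugin_configs.toggle_config["remapping"] != RewriterConfig::OFF
  // i.e. an optimizer is registered only if neither the user config nor the
  // plugin config explicitly turns it off.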
  if (BOTH_NOT_OFF(implementation_selector)) {
    optimizers->push_back(MakeUnique<ImplementationSelector>());
  }
  if (BOTH_NOT_OFF(function_optimization)) {
    optimizers->push_back(MakeUnique<FunctionOptimizer>(
        cfg_.function_optimization(),
        /*lower_control_flow=*/LowerControlFlow()));
  }
  if (BOTH_NOT_OFF(common_subgraph_elimination) &&
      BOTH_NOT_OFF(arithmetic_optimization)) {
    optimizers->push_back(MakeUnique<CommonSubgraphElimination>(
        cfg_.common_subgraph_elimination()));
  }
  if (BOTH_ARE_ON(debug_stripper)) {
    optimizers->push_back(MakeUnique<DebugStripper>());
  }
  if (BOTH_NOT_OFF(constant_folding)) {
    optimizers->push_back(MakeUnique<ConstantFolding>(
        cfg_.constant_folding(), cpu_device_,
        cfg_.experimental_disable_compressed_tensor_optimization(),
        !cfg_.experimental_disable_folding_quantization_emulation()));
  }
  if (BOTH_NOT_OFF(shape_optimization)) {
    optimizers->push_back(MakeUnique<ShapeOptimizer>());
  }
  if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision()) &&
      AutoMixedPrecisionEnabled(
          plugin_configs.toggle_config["auto_mixed_precision"])) {
    optimizers->push_back(
        MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::CUDA));
  }
#ifdef INTEL_MKL
  if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_mkl()) &&
      AutoMixedPrecisionEnabled(
          plugin_configs.toggle_config["auto_mixed_precision_mkl"]) &&
      IsMKLEnabled()) {
    optimizers->push_back(
        MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::MKL));
  }
#endif
  if (BOTH_ARE_ON(pin_to_host_optimization)) {
    optimizers->push_back(MakeUnique<PinToHostOptimizer>());
  }
  if (BOTH_NOT_OFF(arithmetic_optimization)) {
    optimizers->push_back(
        MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
  }
  if (BOTH_NOT_OFF(layout_optimizer)) {
    optimizers->push_back(MakeUnique<GenericLayoutOptimizer>(
        /*optimization level*/ cfg_.layout_optimizer(),
        /*CPU layout conversion*/ cfg_.cpu_layout_conversion()));
  }
  if (BOTH_NOT_OFF(remapping)) {
    optimizers->push_back(
        MakeUnique<Remapper>(cfg_.remapping(), xla_auto_clustering_on_));
  }
  if (BOTH_NOT_OFF(loop_optimization)) {
    optimizers->push_back(
        MakeUnique<LoopOptimizer>(cfg_.loop_optimization(), cpu_device_));
  }
  if (BOTH_NOT_OFF(dependency_optimization)) {
    optimizers->push_back(
        MakeUnique<DependencyOptimizer>(cfg_.dependency_optimization()));
  }
  if (MemoryOptimizerEnabled(cfg_.memory_optimization(),
                             xla_auto_clustering_on_) &&
      PLUGIN_NOT_OFF(memory_optimization)) {
    if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
      optimizers->push_back(
          // Use the default target node name prefix "gradients/"
          MakeUnique<MemoryOptimizer>(cfg_.memory_optimization()));
    } else {
      optimizers->push_back(MakeUnique<MemoryOptimizer>(
          cfg_.memory_optimization(),
          cfg_.memory_optimizer_target_node_name_scope()));
    }
  }
  if (cfg_.auto_parallel().enable() && PLUGIN_IS_ON(auto_parallel)) {
    optimizers->push_back(
        MakeUnique<AutoParallel>(cfg_.auto_parallel().num_replicas()));
  }

#ifndef ENABLE_MKL
  if (BOTH_ARE_ON(scoped_allocator_optimization)) {
    optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
        cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
  }
#endif

#undef USER_IS_ON
#undef USER_NOT_OFF
#undef PLUGIN_IS_ON
#undef PLUGIN_NOT_OFF
#undef BOTH_ARE_ON
#undef BOTH_NOT_OFF
  return InitializeCustomGraphOptimizers(device_types, std::set<string>(),
                                         optimizers);
}

Status MetaOptimizer::InitializeOptimizersByName(
    const std::set<string>& device_types,
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  std::set<string> initialized_custom_optimizers;
  for (const string& optimizer_name : cfg_.optimizers()) {
    auto optimizer = MakeNewOptimizer(optimizer_name, device_types);
    if (optimizer) {
      VLOG(2) << "Registered default graph optimizer: " << optimizer_name;
      optimizers->push_back(std::move(optimizer));
      continue;
    }

    auto custom_optimizer =
        CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);

    if (custom_optimizer) {
      VLOG(2) << "Registered custom graph optimizer: " << optimizer_name;
      TF_RETURN_IF_ERROR(custom_optimizer->InitWithConfig(
          config_proto_, GetCustomGraphOptimizerConfig(optimizer_name)));
      optimizers->push_back(std::move(custom_optimizer));
      initialized_custom_optimizers.insert(optimizer_name);
    } else {
      VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
    }
  }
  return InitializeCustomGraphOptimizers(
      device_types, initialized_custom_optimizers, optimizers);
}

Status MetaOptimizer::InitializeCustomGraphOptimizers(
    const std::set<string>& device_types,
    const std::set<string>& pre_initialized_optimizers,
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  for (const auto& optimizer_config : cfg_.custom_optimizers()) {
    if (pre_initialized_optimizers.find(optimizer_config.name()) !=
        pre_initialized_optimizers.end()) {
      continue;
    }

    auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
        optimizer_config.name());

    if (custom_optimizer) {
      VLOG(2) << "Registered custom configurable graph optimizer: "
              << optimizer_config.name();
      TF_RETURN_IF_ERROR(
          custom_optimizer->InitWithConfig(config_proto_, &optimizer_config));
      optimizers->push_back(std::move(custom_optimizer));
    } else {
      // If there is no custom optimizer with the given name, try to initialize
      // a default optimizer. This way, custom configurable optimizers can be
      // mixed with default optimizers in any order.
      auto optimizer = MakeNewOptimizer(optimizer_config.name(), device_types);
      if (optimizer) {
        VLOG(2) << "Registered default graph optimizer: "
                << optimizer_config.name();
        optimizers->push_back(std::move(optimizer));
        continue;
      }
      VLOG(2) << "Can't register an optimizer by name: "
              << optimizer_config.name();
    }
  }
  return InitializePluginGraphOptimizers(device_types, optimizers);
}

Status MetaOptimizer::InitializePluginGraphOptimizers(
    const std::set<string>& device_types,
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  if (cfg_.use_plugin_optimizers() == RewriterConfig::OFF) return Status::OK();
  auto plugin_optimizers =
      PluginGraphOptimizerRegistry::CreateOptimizers(device_types);
  for (auto& plugin_optimizer : plugin_optimizers) {
    optimizers->push_back(std::move(plugin_optimizer));
  }
  return Status::OK();
}

const RewriterConfig::CustomGraphOptimizer*
MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const {
  for (const auto& config : cfg_.custom_optimizers()) {
    if (config.name() == name) {
      return &config;
    }
  }
  return nullptr;
}

void MetaOptimizer::InitializeVerifiers(
    std::vector<std::unique_ptr<GraphVerifier>>* inter_optimizer_verifiers,
    std::vector<std::unique_ptr<GraphVerifier>>* post_optimization_verifiers)
    const {
  if (cfg_.inter_optimizer_verifier_config().structure_verifier() ==
      VerifierConfig::ON) {
    inter_optimizer_verifiers->push_back(MakeUnique<StructureVerifier>());
  }
  if (cfg_.post_optimization_verifier_config().structure_verifier() ==
      VerifierConfig::ON) {
    post_optimization_verifiers->push_back(MakeUnique<StructureVerifier>());
  }
}

void MetaOptimizer::PrintUserAndPluginConfigs(
    const std::set<string>& device_types) const {
  if (cfg_.use_plugin_optimizers() == RewriterConfig::OFF) return;
  ConfigList plugin_cfg = PluginGraphOptimizerRegistry::GetPluginConfigs(
      cfg_.use_plugin_optimizers() != RewriterConfig::OFF, device_types);
  PluginGraphOptimizerRegistry::PrintPluginConfigsIfConflict(device_types);

  ConfigList user_cfg;
  // Print the user's and the plugin's configs.
  if (cfg_.optimizers().empty()) {
    if (cfg_.disable_meta_optimizer()) {
      return;
    }
    user_cfg.disable_model_pruning = cfg_.disable_model_pruning();
#define PRINT_CFG(CFG) user_cfg.toggle_config[#CFG] = cfg_.CFG();
    PRINT_CFG(implementation_selector)
    PRINT_CFG(function_optimization)
    PRINT_CFG(common_subgraph_elimination)
    PRINT_CFG(arithmetic_optimization)
    PRINT_CFG(debug_stripper)
    PRINT_CFG(constant_folding)
    PRINT_CFG(shape_optimization)
    PRINT_CFG(pin_to_host_optimization)
    PRINT_CFG(layout_optimizer)
    PRINT_CFG(remapping)
    PRINT_CFG(loop_optimization)
    PRINT_CFG(dependency_optimization)
    PRINT_CFG(scoped_allocator_optimization)
#undef PRINT_CFG
    user_cfg.toggle_config["auto_mixed_precision"] =
        AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision())
            ? RewriterConfig::ON
            : RewriterConfig::OFF;
    user_cfg.toggle_config["auto_mixed_precision_mkl"] =
        AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_mkl())
            ? RewriterConfig::ON
            : RewriterConfig::OFF;
    user_cfg.toggle_config["memory_optimization"] =
        MemoryOptimizerEnabled(cfg_.memory_optimization(),
                               config_proto_.graph_options()
                                   .optimizer_options()
                                   .global_jit_level())
            ? RewriterConfig::ON
            : RewriterConfig::OFF;
    user_cfg.toggle_config["auto_parallel"] = cfg_.auto_parallel().enable()
                                                  ? RewriterConfig::ON
                                                  : RewriterConfig::OFF;
  } else {
    for (const string& optimizer_name : cfg_.optimizers()) {
      if (optimizer_name == "pruning") user_cfg.disable_model_pruning = true;

#define PRINT_CFG(NAME, CONFIG) \
  if (optimizer_name == NAME)   \
    user_cfg.toggle_config[CONFIG] = RewriterConfig::ON;

      PRINT_CFG("implementation_selector", "implementation_selector")
      PRINT_CFG("function", "function_optimization")
      PRINT_CFG("common_subgraph_elimination", "common_subgraph_elimination")
      PRINT_CFG("arithmetic", "arithmetic_optimization")
      PRINT_CFG("debug_stripper", "debug_stripper")
      PRINT_CFG("constfold", "constant_folding")
      PRINT_CFG("shape", "shape_optimization")
      PRINT_CFG("auto_mixed_precision", "auto_mixed_precision")
      PRINT_CFG("auto_mixed_precision_mkl", "auto_mixed_precision_mkl")
      PRINT_CFG("pin_to_host", "pin_to_host_optimization")
      PRINT_CFG("layout", "layout_optimizer")
      PRINT_CFG("remap", "remapping")
      PRINT_CFG("loop", "loop_optimization")
      PRINT_CFG("dependency", "dependency_optimization")
      PRINT_CFG("memory", "memory_optimization")
      PRINT_CFG("autoparallel", "auto_parallel")
      PRINT_CFG("scoped_allocator", "scoped_allocator_optimization")
#undef PRINT_CFG
    }
  }

  // Print logs only when the plugin config conflicts with the user config.
  if (!PluginGraphOptimizerRegistry::IsConfigsConflict(user_cfg, plugin_cfg))
    return;

  ConfigList final_cfg = user_cfg;
  // If the plugin turns on `disable_model_pruning`, then
  // `disable_model_pruning` should be true.
  if (plugin_cfg.disable_model_pruning == true)
    final_cfg.disable_model_pruning = true;
  // If the plugin turns off a certain optimizer, then that optimizer should be
  // turned off.
  for (auto& pair : plugin_cfg.toggle_config) {
    if (plugin_cfg.toggle_config[pair.first] == RewriterConfig::OFF)
      final_cfg.toggle_config[pair.first] = RewriterConfig::OFF;
  }

  string logs =
      "\nConfig of optimizers\t\tUser's config\tPlugin's config\tFinal "
      "config(User & Plugin)\n";
  strings::StrAppend(&logs, "disable_model_pruning\t\t",
                     user_cfg.disable_model_pruning, "\t\t",
                     plugin_cfg.disable_model_pruning, "\t\t",
                     final_cfg.disable_model_pruning, "\n");
  for (auto& pair : user_cfg.toggle_config) {
    if (pair.first == "debug_stripper" ||
        pair.first == "auto_mixed_precision" ||
        pair.first == "auto_mixed_precision_mkl" ||
        pair.first == "pin_to_host_optimization" ||
        pair.first == "scoped_allocator_optimization") {
      // These optimizers are turned off by default.
      strings::StrAppend(
          &logs, pair.first, string(32 - pair.first.size(), ' '),
          (pair.second == RewriterConfig::ON), "\t\t",
          (plugin_cfg.toggle_config[pair.first] == RewriterConfig::ON), "\t\t",
          (final_cfg.toggle_config[pair.first] == RewriterConfig::ON), "\n");
    } else {
      // These optimizers are turned on by default.
      strings::StrAppend(
          &logs, pair.first, string(32 - pair.first.size(), ' '),
          (pair.second != RewriterConfig::OFF), "\t\t",
          (plugin_cfg.toggle_config[pair.first] != RewriterConfig::OFF), "\t\t",
          (final_cfg.toggle_config[pair.first] != RewriterConfig::OFF), "\n");
    }
  }
  LOG(WARNING) << "User's config has been changed based on plugin's config.";
  LOG(WARNING) << logs;
}

Status MetaOptimizer::OptimizeGraph(Cluster* cluster, GrapplerItem&& item,
                                    GraphDef* optimized_graph) {
  int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
                                                    : cfg_.min_graph_nodes();
  if (item.graph.node_size() < min_graph_nodes) {
    VLOG(3) << "Skipping optimization, graph has fewer than "
            << min_graph_nodes << " nodes.";
    *optimized_graph = item.graph;
    return Status::OK();
  }

  const uint64 start_us = Env::Default()->NowMicros();

  std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
  std::set<std::string> device_types;
  TF_RETURN_IF_ERROR(GetGraphDevice(item.graph, &device_types));
  if (cfg_.optimizers().empty()) {
    TF_RETURN_IF_ERROR(InitializeOptimizers(device_types, &optimizers));
  } else {
    TF_RETURN_IF_ERROR(InitializeOptimizersByName(device_types, &optimizers));
  }
  PrintUserAndPluginConfigs(device_types);

  // Initialize the configured verifiers.
  std::vector<std::unique_ptr<GraphVerifier>> inter_optimizer_verifiers;
  std::vector<std::unique_ptr<GraphVerifier>> post_optimization_verifiers;
  InitializeVerifiers(&inter_optimizer_verifiers, &post_optimization_verifiers);
  if (inter_optimizer_verifiers.empty()) {
    VLOG(2) << "No inter optimizer verifiers have been configured";
  } else {
    VLOG(2) << inter_optimizer_verifiers.size()
            << " inter optimizer verifiers have been configured";
  }
  if (post_optimization_verifiers.empty()) {
    VLOG(2) << "No post optimization verifiers have been configured";
  } else {
    VLOG(2) << post_optimization_verifiers.size()
            << " post optimization verifiers have been configured";
  }

  VLOG(2) << "Optimize GrapplerItem: item.id=" << item.id
          << " num_optimizers=" << optimizers.size()
          << ", num nodes = " << item.graph.node_size();

  if (optimizers.empty()) {
    VLOG(3) << "Skipping graph optimization, no optimizers registered";
    *optimized_graph = item.graph;
    return Status::OK();
  }

  // Invariant: optimized_graph contains the most recently optimized version of
  // the graph.
  auto original_producer = item.graph.versions().producer();
  optimized_graph->Swap(&item.graph);

  GraphOptimizationResult optimization_result(item.id);
  GraphOptimizer* sa_optimizer = nullptr;

  // Constants in the graph are normally compressed after model_pruner.
  // Do it here if the model pruner is disabled.
  if (cfg_.disable_model_pruning()) {
    CompressConstants(optimized_graph);
  }

  for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) {
    // Don't bother optimizing further if the graph is already tiny.
    if (optimized_graph->node_size() < min_graph_nodes) {
      VLOG(3) << "Stopping after iteration " << iteration
              << ", graph is tiny (#nodes = " << optimized_graph->node_size()
              << "  < " << min_graph_nodes << ")";
      break;
    }

    VLOG(4) << "Starting optimization iteration " << iteration;
    if (VLOG_IS_ON(4)) {
      DumpGraphDefToFile(
          strings::StrCat("before_MetaOptimizer_iteration_", iteration, "_",
                          reinterpret_cast<uintptr_t>(optimized_graph)),
          *optimized_graph);
    }

    for (const auto& optimizer : optimizers) {
      GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
      // Some optimizers can run only once.
      if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
#ifndef ENABLE_MKL
      // Some must run only on the last iteration.
      if (optimizer->name() == "scoped_allocator_optimizer") {
        if (sa_optimizer == nullptr) sa_optimizer = optimizer.get();
        continue;
      }
#endif

      TF_RETURN_IF_ERROR(RunOptimizer(optimizer.get(), cluster, &item,
                                      optimized_graph, &optimization_result));

      if (iteration == 0 && optimizer->name() == "model_pruner") {
        CompressConstants(optimized_graph);
      }

      if (VLOG_IS_ON(4)) {
        DumpGraphDefToFile(
            strings::StrCat("after_MetaOptimizer_iteration_", iteration, "_",
                            optimizer->name(), "_",
                            reinterpret_cast<uintptr_t>(optimized_graph)),
            *optimized_graph);
      }
      for (const auto& verifier : inter_optimizer_verifiers) {
        // TODO(ashwinm): Need to enforce verification_deadline.
        TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
      }
    }
    if (VLOG_IS_ON(4)) {
      DumpGraphDefToFile(
          strings::StrCat("after_MetaOptimizer_iteration_", iteration, "_",
                          reinterpret_cast<uintptr_t>(optimized_graph)),
          *optimized_graph);
    }
    // TODO(ashwinm): Need to enforce verification_deadline.
    for (const auto& verifier : post_optimization_verifiers) {
      TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
    }
  }
#ifndef ENABLE_MKL
  // ScopedAllocatorOptimizer must run last.
  if (sa_optimizer != nullptr) {
    TF_RETURN_IF_ERROR(RunOptimizer(sa_optimizer, cluster, &item,
                                    optimized_graph, &optimization_result));
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
  }
#endif

  bool is_optimized = std::find_if(optimization_result.results.begin(),
                                   optimization_result.results.end(),
                                   [](const OptimizerResult& result) {
                                     return result.status.ok();
                                   }) != optimization_result.results.end();

  // Record the graph optimization result.
  optimization_results_.push_back(optimization_result);

  if (is_optimized) {
    TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
    ReassignColocation(optimized_graph);
    // Make sure that the optimizers preserved the graph version.
    DCHECK_EQ(optimized_graph->versions().producer(), original_producer);
  }

  const uint64 end_us = Env::Default()->NowMicros();
  metrics::UpdateGrapplerPassTime("OptimizeMainGraph", end_us - start_us);

  return Status::OK();
}

Status MetaOptimizer::RunOptimizer(
    GraphOptimizer* optimizer, Cluster* cluster, GrapplerItem* optimized_item,
    GraphDef* optimized_graph, GraphOptimizationResult* optimization_result) {
  const uint64 start_us = Env::Default()->NowMicros();

  // If the optimizer doesn't need a function library, we replace it with a
  // stub before running the optimization, and put it back at the end.
  FunctionDefLibrary optimized_graph_function_library;
  const bool is_function_library_aware = optimizer->UsesFunctionLibrary();

  // Replace the function library in the optimized graph with a stub.
  if (!is_function_library_aware) {
    VLOG(3) << "Replace function library with a stub for " << optimizer->name();
    optimized_graph_function_library.Swap(optimized_graph->mutable_library());
    *optimized_graph->mutable_library() =
        GetFunctionDefLibraryStub(optimized_graph_function_library);
  }

  // This swaps the current optimized_graph into the optimized item and
  // resets optimized_graph to an empty graph.
  optimized_graph->Swap(&optimized_item->graph);
  *optimized_graph = GraphDef();
  optimizer->set_deadline_usec(this->deadline_usec());
  Status status =
      optimizer->Optimize(cluster, *optimized_item, optimized_graph);
  const uint64 end_us = Env::Default()->NowMicros();
  const float duration_ms = (end_us - start_us) / 1000.0f;
  metrics::UpdateGrapplerPassTime(optimizer->name(), end_us - start_us);

  string message;
  if (!status.ok()) {
    optimized_graph->Swap(&optimized_item->graph);
    if (errors::IsAborted(status)) {
      // By convention we (ab-)use the Aborted error code to signal that the
      // optimizer returned without performing any changes to the graph.
      message = strings::StrCat(optimizer->name(),
                                " did nothing. time = ", duration_ms, "ms.");
      // Swallow the non-critical error.
      status = Status::OK();
    } else if (errors::IsDeadlineExceeded(status)) {
      message =
          strings::StrCat(status.ToString(), ", time = ", duration_ms, "ms.");
      LOG(WARNING) << optimizer->name() << " failed: " << message;
    } else {
      message = status.ToString();
      LOG(ERROR) << optimizer->name() << " failed: " << message;
    }
  } else {
    message = strings::StrCat(
        PrintSizesBeforeAfter(optimized_item->graph, *optimized_graph),
        ", time = ", duration_ms, "ms.");
    VLOG(1) << optimizer->name() << ": " << message;
  }

  // Swap the function library back into the main graph.
  if (!is_function_library_aware) {
    optimized_graph->mutable_library()->Swap(&optimized_graph_function_library);
  }

  OptimizerResult optimizer_result{optimizer->name(), message, status};
  optimization_result->results.push_back(optimizer_result);

  if (!status.ok() && cfg_.fail_on_optimizer_errors()) return status;

  return Status::OK();
}
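
// Illustrative sketch of the Aborted convention documented above: a custom
// optimizer that decides the graph needs no changes can end its Optimize()
// with
//
//   return errors::Aborted("Nothing to do.");
//
// RunOptimizer() then restores the input graph, records "<name> did nothing.",
// and converts the status to Status::OK() before the fail_on_optimizer_errors
// check, so the no-op is never treated as a failure.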

// Propagates `_tf_data_function` attributes from functions to their callees.
void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib,
                          FunctionDefLibrary& fdef_lib) {
  // Collect functions that need the attribute in this set.
  absl::flat_hash_set<std::string> tf_data_functions;
  std::function<void(const std::string&)> collect_tf_data_functions_dfs =
      [&](const std::string& func_name) -> void {
    const FunctionDef* func_def = flib.Find(func_name);
    // Skip functions that are not reachable from the optimized graph.
    if (func_def == nullptr) return;

    // Return if we already found and added this function.
    if (tf_data_functions.contains(func_name)) return;

    // We only get here if the function is (directly or indirectly) called from
    // a tf.data function, so add it to the set.
    tf_data_functions.insert(func_name);

    // Proceed with the DFS for functions called from the current function.
    for (const NodeDef& node : func_def->node_def()) {
      if (flib.Contains(node.op())) {
        // This is a function call node.
        collect_tf_data_functions_dfs(node.op());
      }
      // Check if there are functions in attributes.
      for (const auto& attr : node.attr()) {
        const AttrValue& attr_value = attr.second;
        if (attr_value.has_func()) {
          collect_tf_data_functions_dfs(attr_value.func().name());
        }
        if (attr_value.has_list()) {
          for (const auto& func : attr_value.list().func()) {
            collect_tf_data_functions_dfs(func.name());
          }
        }
      }
    }
  };
  // Perform the DFS for all tf.data functions in `fdef_lib`.
  for (const auto& func_def : fdef_lib.function()) {
    const std::string& func_name = func_def.signature().name();
    if (data::IsTFDataFunction(func_def))
      collect_tf_data_functions_dfs(func_name);
  }
  // Set the attribute for tf.data functions. We cannot do this in the DFS
  // directly because `FunctionLibraryDefinition` does not seem to provide
  // mutable access to a `FunctionDef`.
  for (FunctionDef& func_def : *fdef_lib.mutable_function()) {
    const std::string& func_name = func_def.signature().name();
    if (tf_data_functions.contains(func_name) &&
        !data::IsTFDataFunction(func_def)) {
      VLOG(2) << "Marking " << func_name << " as tf.data function";
      (*func_def.mutable_attr())[data::kTFDataFunction].set_b(true);
    }
  }
}
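
// Worked example for PropagateTFDataAttrs() (hypothetical function names): if
// `map_fn` carries `_tf_data_function` and its body calls `parse_fn`, which
// references `decode_fn` through a `func` attr, the DFS visits
// map_fn -> parse_fn -> decode_fn, and the final loop effectively runs
//   (*decode_fn.mutable_attr())[data::kTFDataFunction].set_b(true);
// for every callee that does not already carry the attribute.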

Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
                                          GraphDef* optimized_graph) {
  const uint64 start_us = Env::Default()->NowMicros();

  VLOG(1) << "Starting optimization for grappler item: " << item.id;
  optimization_results_.clear();

  // Constructs a FunctionLibraryDefinition with functions that are reachable
  // from the nodes of the graph.
  const auto minimized_flib =
      [](const GraphDef& graph) -> FunctionLibraryDefinition {
    return FunctionLibraryDefinition(OpRegistry::Global(), graph.library())
        .ReachableDefinitions(graph);
  };

  // 0. The original graph might contain a huge function library that is mostly
  // unused. This library is copied over by each individual Grappler optimizer,
  // which adds a huge overhead. Before starting optimization passes, we just
  // remove all the unreachable functions.
  // TODO(ezhulenev): Construct reachable function library definition directly
  // from the proto without constructing temporary FunctionLibraryDefinition.
  int old_library_size = item.graph.library().function_size();
  *item.graph.mutable_library() = minimized_flib(item.graph).ToProto();
  int new_library_size = item.graph.library().function_size();

  VLOG(1) << absl::Substitute(
      "Deleted $0 unreachable functions from the graph (library size = $1)",
      old_library_size - new_library_size, new_library_size);

  // Save a few small fields from the item before we move it.
  bool optimize_function_library =
      item.optimization_options().optimize_function_library;
  const auto producer = item.graph.versions().producer();

  // 1. Optimize the main graph.
  TF_RETURN_IF_ERROR(OptimizeGraph(cluster, std::move(item), optimized_graph));
  VLOG(1) << "Optimized main graph.";
  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

  // 2. Optimize functions reachable from the optimized graph.
  FunctionLibraryDefinition flib = minimized_flib(*optimized_graph);
  using NodeDefs = protobuf::RepeatedPtrField<NodeDef>;

  // Find functions for which we might need to compute a gradient at runtime.
  absl::flat_hash_set<string> differentiable_functions;

  const auto find_differentiable_functions =
      [&](const NodeDefs& nodes) -> void {
    for (const NodeDef& node : nodes) {
      if (IsSymbolicGradient(node)) {
        const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
        if (f_attr) differentiable_functions.insert(f_attr->func().name());
      }
    }
  };

  // SymbolicGradient nodes inside the main graph.
  find_differentiable_functions(optimized_graph->node());
  // SymbolicGradient nodes inside the function library.
  for (const FunctionDef& function : optimized_graph->library().function()) {
    find_differentiable_functions(function.node_def());
  }

  // Find functions that will be compiled by XLA later.
  // We do it by looking for XlaLaunch ops that call functions, then doing a
  // depth-first search down those functions to find transitively called
  // functions. Grappler rewrites can potentially add nodes that are not
  // supported by XLA, so we choose to skip such functions when we optimize
  // the function library.
  absl::flat_hash_set<string> xla_compiled_functions;
  std::function<void(const string&)> find_all_functions;
  find_all_functions = [&](const string& func) -> void {
    // Ignore call cycles in the graph.
    if (xla_compiled_functions.contains(func)) return;
    // Find func in the flib.
    const FunctionDef* func_def = flib.Find(func);
    CHECK(func_def) << "not found: " << func;
    // Mark the function to be ignored by Grappler.
    xla_compiled_functions.insert(func);
    // Depth-first search through the function for transitively called
    // functions.
    for (const NodeDef& node : func_def->node_def()) {
      for (const auto& attr : node.attr()) {
        const AttrValue& attr_value = attr.second;
        if (attr_value.has_func()) {
          find_all_functions(attr_value.func().name());
        }
      }
    }
  };

  auto find_xla_compiled_functions = [&](const NodeDefs& nodes) -> void {
    NameAttrList function;
    for (const NodeDef& node : nodes) {
      // Look only for XlaLaunch nodes that call a function.
      if (!IsXlaLaunch(node)) continue;
      if (!GetNodeAttr(node, "function", &function).ok()) continue;
      // Find all transitively called functions.
      find_all_functions(function.name());
    }
  };

  // XlaLaunch ops inside the main graph ...
  find_xla_compiled_functions(optimized_graph->node());
  // ... and inside the function library.
  for (const FunctionDef& function : optimized_graph->library().function()) {
    find_xla_compiled_functions(function.node_def());
  }
  // Propagate `_tf_data_function` attributes from functions to their callees.
  PropagateTFDataAttrs(flib, *optimized_graph->mutable_library());

  // Optimize each function only once.
  absl::flat_hash_set<string> optimized_funcs;
  while (optimize_function_library) {
    optimize_function_library = false;

    int function_idx = 0;
    for (const FunctionDef& func : optimized_graph->library().function()) {
      GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

      const string& func_name = func.signature().name();

      // Skip functions that are not reachable from the optimized graph.
      if (!flib.Contains(func_name)) continue;
      // Skip already optimized functions.
      if (optimized_funcs.contains(func_name)) continue;
      // Skip functions that will be compiled by XLA.
      if (xla_compiled_functions.contains(func_name)) continue;

      // Skip parametrized functions (the function type or body is defined only
      // at function call time, by caller node attributes). They should be
      // specialized to their instantiation type parameters by the function
      // optimizer before we can optimize the function body.
      if (IsParametrized(func)) continue;

      // Skip tf.data functions, as they are optimized by the tf.data meta
      // optimizer and in function instantiation.
      if (data::IsTFDataFunction(func)) continue;

      VLOG(3) << "Optimize function: function=" << func_name << " ["
              << function_idx++ << " of "
              << optimized_graph->library().function_size() << "]";

      // Function optimization might specialize nested function calls, so we
      // have to reset the flag and do at least one more pass over the library.
      optimize_function_library = true;
      optimized_funcs.insert(func_name);

      // Make a GrapplerItem from a FunctionDef.
      GrapplerFunctionItem func_item;
      TF_RETURN_IF_ERROR(
          MakeGrapplerFunctionItem(func, flib, producer, &func_item));

      // If we need to compute the gradient of the optimized function at
      // runtime, we can't perform non-differentiable rewrites.
      func_item.optimization_options().allow_non_differentiable_rewrites =
          !differentiable_functions.contains(func_name);

      // The device set available to the function is defined only by the
      // runtime, when we instantiate and execute the function. We can't use
      // all devices available to the main graph, because after partitioning
      // the function call node might execute on a remote worker.
      if (!func_item.devices().empty()) {
        return errors::Internal("GrapplerFunctionItem devices must be empty.");
      }

      // We are not allowed to prune certain types of ops from the graph
      // instantiated by the function definition, because we must guarantee
      // function execution semantics wrt side effects (see
      // function_optimizer.cc).
      func_item.optimization_options().allow_pruning_stateful_and_dataset_ops =
          false;

      // Optimize the function body graph.
      GraphDef optimized_func_graph;
      if (IsTPUGraphDef(*optimized_graph)) {
        // Skip optimizing functions if this is a TPU graph. Currently, Grappler
        // passes do not handle TPU functions correctly in a variety of ways
        // (note that due to the pre-placement TPU graph rewriting passes, the
        // TPU-related ops are encapsulated away into functions). For example,
        // TPU graphs contain a TPUReplicateMetadata node that carries relevant
        // TPU metadata, and Grappler passes could prune that away. Grappler
        // passes could also cause issues around shape inference. Since the
        // desired and existing behavior is to not optimize TPU functions with
        // Grappler, this check preserves that. The only exception is the
        // implementation selector, which is required to swap in some
        // TPU-specific lowering code and is verified to work correctly on
        // TPUs.
        ImplementationSelector implementation_selector;

        // The implementation selector needs access to a valid function
        // signature and attributes; it doesn't need the actual function body.
        FunctionDefLibrary func_item_function_library;
        func_item_function_library.Swap(func_item.graph.mutable_library());
        *func_item.graph.mutable_library() =
            GetFunctionDefLibraryStub(func_item_function_library);

        TF_RETURN_IF_ERROR(implementation_selector.Optimize(
            cluster, func_item, &optimized_func_graph));
      } else {
        GrapplerFunctionItem func_item_copy = func_item;
        TF_RETURN_IF_ERROR(OptimizeGraph(cluster, std::move(func_item_copy),
                                         &optimized_func_graph));
      }

      // Function body optimization might have created new specialized
      // functions for each instantiation context. Add them to the library.
      for (const FunctionDef& func_def :
           optimized_func_graph.library().function()) {
        if (flib.Find(func_def.signature().name()) == nullptr) {
          TF_RETURN_IF_ERROR(flib.AddFunctionDef(func_def));
        }
      }

      // Convert the optimized graph back to a FunctionDef.
      FunctionDef optimized_func;
      func_item.SwapFunctionBody(std::move(optimized_func_graph));
      TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func));

      // Replace the optimized function with a new FunctionDef.
      TF_RETURN_IF_ERROR(flib.ReplaceFunction(func_name, optimized_func));
    }

    // If we optimized at least one function, update the graph library.
    if (optimize_function_library) {
      *optimized_graph->mutable_library() = flib.ToProto();
    }
  }

  VLOG(1) << "Optimized " << optimized_funcs.size()
          << " functions: " << absl::StrJoin(optimized_funcs, ", ");
  VLOG(3) << "Optimized graph =\n" << optimized_graph->DebugString();
  if (VLOG_IS_ON(1)) {
    DumpGraphDefToFile(
        strings::StrCat("after_MetaOptimizer_",
                        reinterpret_cast<uintptr_t>(optimized_graph)),
        *optimized_graph);
  }

  const uint64 end_us = Env::Default()->NowMicros();
  metrics::UpdateGrapplerPassTime("*", end_us - start_us);

  return Status::OK();
}

string MetaOptimizer::GetResultString() const {
  std::string result_string;
  for (const GraphOptimizationResult& graph_result : optimization_results_) {
    absl::StrAppend(&result_string,
                    "Optimization results for grappler item: ", graph_result.id,
                    "\n");
    for (const OptimizerResult& result : graph_result.results) {
      absl::StrAppend(&result_string, "  ", result.optimizer_name, ": ",
                      result.message, "\n");
    }
  }
  return result_string;
}

void MetaOptimizer::PrintResult() { LOG(INFO) << GetResultString(); }

bool MetaOptimizerEnabled(const ConfigProto& cfg) {
  const auto& rewrite_cfg = cfg.graph_options().rewrite_options();
  if (rewrite_cfg.disable_meta_optimizer()) {
    return false;
  }
  return !rewrite_cfg.disable_model_pruning() ||
         rewrite_cfg.layout_optimizer() != RewriterConfig::OFF ||
         rewrite_cfg.function_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.constant_folding() != RewriterConfig::OFF ||
         rewrite_cfg.shape_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.remapping() != RewriterConfig::OFF ||
         rewrite_cfg.common_subgraph_elimination() != RewriterConfig::OFF ||
         rewrite_cfg.arithmetic_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.loop_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.dependency_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.auto_parallel().enable() ||
         rewrite_cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
         rewrite_cfg.debug_stripper() == RewriterConfig::ON ||
#ifndef ENABLE_MKL
         rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
#endif
         rewrite_cfg.pin_to_host_optimization() == RewriterConfig::ON ||
         AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision()) ||
         AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision_mkl()) ||
         !rewrite_cfg.optimizers().empty() ||
         !rewrite_cfg.custom_optimizers().empty();
}

Status RunMetaOptimizer(GrapplerItem&& item, const ConfigProto& cfg,
                        DeviceBase* cpu_device, Cluster* cluster,
                        GraphDef* optimized_graph) {
  MetaOptimizer optimizer(cpu_device, cfg);
  optimizer.set_deadline_usec(
      DeadlineMicroSeconds(cfg.graph_options().rewrite_options()));
  return optimizer.OptimizeConsumeItem(cluster, std::move(item),
                                       optimized_graph);
}
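
// Minimal usage sketch (illustrative; assumes a populated `item`, a `config`
// proto, a CPU `device`, and a `cluster` are available in the caller):
//
//   GraphDef optimized;
//   Status s = RunMetaOptimizer(std::move(item), config, device, &cluster,
//                               &optimized);
//   if (s.ok()) { /* `optimized` holds the rewritten graph */ }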

Status OptimizeGraph(
    std::vector<string> ret_node_names, std::vector<string> keep_node_names,
    FunctionLibraryDefinition* flib, const DeviceSet& device_set,
    Device* cpu_device, const ConfigProto& config_proto,
    const string& grappler_item_id,
    const GrapplerItem::OptimizationOptions& optimization_options,
    std::unique_ptr<tensorflow::Graph>* g) {
  if (!tensorflow::grappler::MetaOptimizerEnabled(config_proto)) {
    return Status::OK();
  }

  tensorflow::grappler::GrapplerItem item;
  item.id = grappler_item_id;
  item.optimization_options() = optimization_options;

  // Add all available devices so that inlined functions can be placed.
  for (const Device* d : device_set.devices()) {
    Status added_device = item.AddDevice(d->name());
    if (!added_device.ok()) VLOG(3) << added_device.error_message();
  }
  VLOG(3) << "Grappler available devices: "
          << absl::StrJoin(item.devices(), ", ");

  // Add fetches so that the graph can be pruned.
  item.fetch.swap(ret_node_names);

  // Add nodes that can't be removed from the graph.
  item.keep_ops = std::move(keep_node_names);

  (*g)->ToGraphDef(&item.graph);

  if (flib) {
    *item.graph.mutable_library() = flib->ToProto();
  }

  tensorflow::GraphDef out_graph;
  tensorflow::grappler::VirtualCluster cluster(&device_set);
  // TODO(nareshmodi): Consider adding and using the more generic GraphOptions
  // proto (which also contains the OptimizerOptions).
  TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
      std::move(item), config_proto, cpu_device, &cluster, &out_graph));

  std::unique_ptr<tensorflow::Graph> optimized_graph(
      new tensorflow::Graph(OpRegistry::Global()));

  // Copy the optimized functions back to the overlay lib.
  if (flib) {
    for (const FunctionDef& fdef : out_graph.library().function()) {
      const string& func_name = fdef.signature().name();
      if (flib->Contains(func_name)) {
        StackTracesMap stack_traces = flib->GetStackTraces(func_name);
        TF_RETURN_IF_ERROR(
            flib->ReplaceFunction(func_name, fdef, stack_traces));
      } else {
        TF_RETURN_IF_ERROR(
            flib->AddFunctionDef(fdef, flib->GetStackTraces(func_name)));
      }
    }
  }

  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
      GraphConstructorOptions(), std::move(out_graph), optimized_graph.get()));

  // The graph conversion sets the requested device names but not the
  // assigned device names. However, since at this point the graph is
  // placed, TF expects an assigned device name for every node. Therefore
  // we copy the requested device into the assigned device field.
  for (Node* node : optimized_graph->nodes()) {
    if (node->IsOp() && node->assigned_device_name().empty()) {
      if (node->requested_device().empty()) {
        return errors::Internal(
            "Either placer did not place the node or Grappler did not "
            "copy the assigned device. Contact Grappler team since latter "
            "is more likely. Node=",
            node->name(),
            " Graph: ", optimized_graph->ToGraphDefDebug().DebugString());
      }
      node->set_assigned_device_name(node->requested_device());
    }
  }

  *g = std::move(optimized_graph);
  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow