/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include "tensorflow/core/grappler/optimizers/implementation_selector.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
#include "tensorflow/core/grappler/utils/colocation.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/verifiers/structure_verifier.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace grappler {

namespace {

constexpr int kDefaultNumberOfIterations = 2;
constexpr int kDefaultMinGraphNodes = 4;

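// Returns the number of edges in the graph, computed as the sum of the input
// counts over all nodes (data and control inputs alike).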
int64 NumEdges(const GraphDef& graph) {
  int64 num_edges = 0;
  for (const auto& node : graph.node()) {
    num_edges += node.input_size();
  }
  return num_edges;
}

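// Formats a human-readable summary of the graph size after an optimizer pass,
// with node and edge deltas relative to the graph before the pass.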
string PrintSizesBeforeAfter(const GraphDef& before, const GraphDef& after) {
  return strings::StrCat("Graph size after: ", after.node_size(), " nodes (",
                         after.node_size() - before.node_size(), "), ",
                         NumEdges(after), " edges (",
                         NumEdges(after) - NumEdges(before), ")");
}

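// Returns the number of meta-optimizer iterations to run, falling back to
// kDefaultNumberOfIterations when the config requests the default.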
int NumIterations(const RewriterConfig& cfg) {
  return cfg.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS
             ? kDefaultNumberOfIterations
             : cfg.meta_optimizer_iterations();
}

// Check if optimizer is allowed to run only once.
bool IsRunOnceOptimizer(const string& name) {
  return name == "layout" || name == "memory_optimizer" ||
         name == "loop_optimizer";
}

// Check if the graphdef contains nodes that indicate TPU execution.
bool IsTPUGraphDef(const GraphDef& def) {
  for (auto node : def.node()) {
    if (node.op() == "TPUCompile" || node.op() == "TPUPartitionedCall") {
      return true;
    }
  }
  return false;
}

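// Translates meta_optimizer_timeout_ms into an absolute deadline in
// microseconds since the epoch: a negative timeout yields 0 (treated as no
// deadline), zero selects a five-minute default, and a positive value is
// added to the current time.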
uint64 DeadlineMicroSeconds(const RewriterConfig& cfg) {
  const uint64 kFiveMinutesInUsec = 5 * 60 * 1000 * 1000;
  if (cfg.meta_optimizer_timeout_ms() < 0) {
    return 0;
  } else {
    return cfg.meta_optimizer_timeout_ms() == 0
               ? Env::Default()->NowMicros() + kFiveMinutesInUsec
               : Env::Default()->NowMicros() +
                     cfg.meta_optimizer_timeout_ms() * 1000;
  }
}

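// Compresses the tensor protos stored in the "value" attribute of Const and
// HostConst nodes to reduce the size of the final GraphDef.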
Status CompressConstants(GraphDef* graph) {
  for (int i = 0; i < graph->node_size(); ++i) {
    NodeDef* node = graph->mutable_node(i);
    if ((IsConstant(*node) || IsHostConstant(*node)) &&
        HasNodeAttr(*node, "value")) {
      AttrValue& attr_val = (*node->mutable_attr())["value"];
      tensor::CompressTensorProtoInPlace(attr_val.mutable_tensor());
    }
  }
  return Status::OK();
}

}  // namespace

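// MK_OPT maps an optimizer name (as it appears in RewriterConfig.optimizers)
// to a freshly constructed GraphOptimizer instance.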
#define MK_OPT(NAME, VALUE) \
  if (optimizer == NAME) return std::unique_ptr<GraphOptimizer>(VALUE)

std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
    const string& optimizer) const {
  MK_OPT("pruning", new ModelPruner());
  MK_OPT("function", new FunctionOptimizer(cfg_.function_optimization()));
  MK_OPT("constfold", new ConstantFolding(cpu_device_));
  MK_OPT("shape", new ShapeOptimizer());
  MK_OPT("remap", new Remapper(cfg_.remapping()));
  MK_OPT("layout", new LayoutOptimizer());
  MK_OPT("memory", new MemoryOptimizer(RewriterConfig::MANUAL));
  MK_OPT("arithmetic", new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
  MK_OPT("autoparallel", new AutoParallel(cfg_.auto_parallel().num_replicas()));
  MK_OPT("loop", new LoopOptimizer(cfg_.loop_optimization(), cpu_device_));
  MK_OPT("dependency", new DependencyOptimizer(cfg_.dependency_optimization()));
  MK_OPT("debug_stripper", new DebugStripper());
  MK_OPT("scoped_allocator",
         new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
                                      cfg_.scoped_allocator_opts()));
  MK_OPT("pin_to_host",
         new PinToHostOptimizer(cfg_.pin_to_host_optimization()));

  return std::unique_ptr<GraphOptimizer>();
}

#undef MK_OPT

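// cpu_device may be null; when provided it must be a CPU device (enforced by
// the DCHECK below). It is forwarded to optimizers that take a host device,
// such as constant folding and the loop optimizer.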
MetaOptimizer::MetaOptimizer(DeviceBase* cpu_device, const ConfigProto& cfg)
    : cpu_device_(cpu_device),
      config_proto_(cfg),
      cfg_(*config_proto_.mutable_graph_options()->mutable_rewrite_options()) {
  DCHECK(cpu_device_ == nullptr ||
         cpu_device_->attributes().device_type() == "CPU");
}

Status MetaOptimizer::InitializeOptimizers(
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  if (cfg_.disable_meta_optimizer()) {
    return Status::OK();
  }
  if (!cfg_.disable_model_pruning()) {
    optimizers->push_back(MakeUnique<ModelPruner>());
  }
  if (cfg_.implementation_selector() != RewriterConfig::OFF) {
    optimizers->push_back(MakeUnique<ImplementationSelector>());
  }
  if (cfg_.function_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(
        MakeUnique<FunctionOptimizer>(cfg_.function_optimization()));
  }
  if (cfg_.debug_stripper() == RewriterConfig::ON) {
    optimizers->push_back(MakeUnique<DebugStripper>());
  }
  if (cfg_.constant_folding() != RewriterConfig::OFF) {
    optimizers->push_back(
        MakeUnique<ConstantFolding>(cfg_.constant_folding(), cpu_device_));
  }
  if (cfg_.shape_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(MakeUnique<ShapeOptimizer>());
  }
  if (cfg_.remapping() != RewriterConfig::OFF) {
    optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
  }
  if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
    optimizers->push_back(MakeUnique<PinToHostOptimizer>());
  }
  if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(
        MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
  }
  if (cfg_.loop_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(
        MakeUnique<LoopOptimizer>(cfg_.loop_optimization(), cpu_device_));
  }
  if (cfg_.dependency_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(
        MakeUnique<DependencyOptimizer>(cfg_.dependency_optimization()));
  }
  if (cfg_.layout_optimizer() != RewriterConfig::OFF) {
    optimizers->push_back(MakeUnique<LayoutOptimizer>());
  }
  if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
    if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
      optimizers->push_back(
          // Use the default target node name prefix "gradients/"
          MakeUnique<MemoryOptimizer>(cfg_.memory_optimization()));
    } else {
      optimizers->push_back(MakeUnique<MemoryOptimizer>(
          cfg_.memory_optimization(),
          cfg_.memory_optimizer_target_node_name_scope()));
    }
  }
  if (cfg_.auto_parallel().enable()) {
    optimizers->push_back(
        MakeUnique<AutoParallel>(cfg_.auto_parallel().num_replicas()));
  }
  if (cfg_.scoped_allocator_optimization()) {
    optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
        cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
  }
  return InitializeCustomGraphOptimizers(std::set<string>(), optimizers);
}

Status MetaOptimizer::InitializeOptimizersByName(
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  std::set<string> initialized_custom_optimizers;
  for (const string& optimizer_name : cfg_.optimizers()) {
    auto optimizer = MakeNewOptimizer(optimizer_name);
    if (optimizer) {
      VLOG(2) << "Registered default graph optimizer: " << optimizer_name;
      optimizers->push_back(std::move(optimizer));
      continue;
    }

    auto custom_optimizer =
        CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);

    if (custom_optimizer) {
      VLOG(2) << "Registered custom graph optimizer: " << optimizer_name;
      TF_RETURN_IF_ERROR(custom_optimizer->Init(
          GetCustomGraphOptimizerConfig(optimizer_name)));
      optimizers->push_back(std::move(custom_optimizer));
      initialized_custom_optimizers.insert(optimizer_name);
    } else {
      VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
    }
  }
  return InitializeCustomGraphOptimizers(initialized_custom_optimizers,
                                         optimizers);
}

Status MetaOptimizer::InitializeCustomGraphOptimizers(
    const std::set<string>& pre_initialized_optimizers,
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  for (const auto& optimizer_config : cfg_.custom_optimizers()) {
    if (pre_initialized_optimizers.find(optimizer_config.name()) !=
        pre_initialized_optimizers.end()) {
      continue;
    }

    auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
        optimizer_config.name());

    if (custom_optimizer) {
      VLOG(2) << "Registered custom configurable graph optimizer: "
              << optimizer_config.name();
      TF_RETURN_IF_ERROR(custom_optimizer->Init(&optimizer_config));
      optimizers->push_back(std::move(custom_optimizer));
    } else {
      // If there are no custom optimizers with the given name, try to
      // initialize a default optimizer. This way, custom configurable
      // optimizers can be mixed with default optimizers in any order.
      auto optimizer = MakeNewOptimizer(optimizer_config.name());
      if (optimizer) {
        VLOG(2) << "Registered default graph optimizer: "
                << optimizer_config.name();
        optimizers->push_back(std::move(optimizer));
        continue;
      }
      VLOG(2) << "Can't register an optimizer by name: "
              << optimizer_config.name();
    }
  }
  return Status::OK();
}

const RewriterConfig::CustomGraphOptimizer*
MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const {
  for (const auto& config : cfg_.custom_optimizers()) {
    if (config.name() == name) {
      return &config;
    }
  }
  return nullptr;
}

void MetaOptimizer::InitializeVerifiers(
    std::vector<std::unique_ptr<GraphVerifier>>* inter_optimizer_verifiers,
    std::vector<std::unique_ptr<GraphVerifier>>* post_optimization_verifiers)
    const {
  if (cfg_.inter_optimizer_verifier_config().structure_verifier() ==
      VerifierConfig::ON) {
    inter_optimizer_verifiers->push_back(MakeUnique<StructureVerifier>());
  }
  if (cfg_.post_optimization_verifier_config().structure_verifier() ==
      VerifierConfig::ON) {
    post_optimization_verifiers->push_back(MakeUnique<StructureVerifier>());
  }
}

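// Runs a single optimizer pass and records whether it succeeded. A failing
// optimizer aborts meta-optimization only when fail_on_optimizer_errors is
// set; otherwise the failure is ignored and later passes continue from the
// last good graph (RunOptimizer restores the pre-pass graph on error).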
#define RUN_OPTIMIZER_OR_RETURN_IF_ERROR(optimizer)                            \
  {                                                                            \
    const Status status = RunOptimizer(optimizer, cluster, &optimized_item,    \
                                       optimized_graph, &optimization_result); \
    if (status.ok()) {                                                         \
      is_optimized = true;                                                     \
    } else if (cfg_.fail_on_optimizer_errors()) {                              \
      VLOG(2) << "Optimizer '" << optimizer->name() << "' failed: " << status; \
      TF_RETURN_IF_ERROR(status);                                              \
    }                                                                          \
  }

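// Runs every registered optimizer over the graph in `item` for up to
// NumIterations(cfg_) passes, invoking the configured verifiers after each
// pass, and deferring the "xla-fusion" and "scoped_allocator_optimizer"
// passes until all other optimizers have finished.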
Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
                                    GraphDef* optimized_graph) {
  int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
                                                    : cfg_.min_graph_nodes();
  if (item.graph.node_size() < min_graph_nodes) {
    VLOG(3) << "Skipping optimization, graph has less than " << min_graph_nodes
            << " nodes.";
    *optimized_graph = item.graph;
    return Status::OK();
  }

  std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
  if (cfg_.optimizers().empty()) {
    TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers));
  } else {
    TF_RETURN_IF_ERROR(InitializeOptimizersByName(&optimizers));
  }

  // Initialize the configured verifiers.
  std::vector<std::unique_ptr<GraphVerifier>> inter_optimizer_verifiers;
  std::vector<std::unique_ptr<GraphVerifier>> post_optimization_verifiers;
  InitializeVerifiers(&inter_optimizer_verifiers, &post_optimization_verifiers);
  if (inter_optimizer_verifiers.empty()) {
    VLOG(2) << "No inter optimizer verifiers have been configured";
  } else {
    VLOG(2) << inter_optimizer_verifiers.size()
            << " inter optimizer verifiers have been configured";
  }
  if (post_optimization_verifiers.empty()) {
    VLOG(2) << "No post optimization verifiers have been configured";
  } else {
    VLOG(2) << post_optimization_verifiers.size()
            << " post optimization verifiers have been configured";
  }

  VLOG(2) << "Optimize GrapplerItem: item.id=" << item.id
          << " num_optimizers=" << optimizers.size()
          << ", num nodes = " << item.graph.node_size();

  if (optimizers.empty()) {
    VLOG(3) << "Skipping graph optimization, no optimizers registered";
    *optimized_graph = item.graph;
    return Status::OK();
  }

  // Invariant: optimized_graph contains the most recently optimized version of
  // the graph.
  GrapplerItem optimized_item = item;
  optimized_graph->Swap(&optimized_item.graph);

  bool is_optimized = false;
  GraphOptimizationResult optimization_result(item.id);
  GraphOptimizer* fusion_optimizer = nullptr;
  GraphOptimizer* sa_optimizer = nullptr;

  for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) {
    // Don't bother optimizing further if the graph is already tiny.
    if (optimized_graph->node_size() < min_graph_nodes) {
      VLOG(3) << "Stopping after iteration " << iteration
              << ", graph is tiny (#nodes = " << optimized_graph->node_size()
              << "  < " << min_graph_nodes << ")";
      break;
    }

    VLOG(4) << "Starting optimization iteration " << iteration;
    if (VLOG_IS_ON(4)) {
      DumpGraphDefToFile(
          strings::StrCat("before_MetaOptimizer_iteration_", iteration, "_",
                          reinterpret_cast<uintptr_t>(optimized_graph)),
          *optimized_graph);
    }
    for (const auto& optimizer : optimizers) {
      GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
      // Some optimizers can run only once.
      if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
      // Some must run only on the last iteration.
      if (optimizer->name() == "scoped_allocator_optimizer") {
        if (sa_optimizer == nullptr) sa_optimizer = optimizer.get();
        continue;
      }
      if (optimizer->name() == "xla-fusion") {
        if (fusion_optimizer == nullptr) fusion_optimizer = optimizer.get();
        continue;
      }
      RUN_OPTIMIZER_OR_RETURN_IF_ERROR(optimizer.get());

      if (VLOG_IS_ON(4)) {
        DumpGraphDefToFile(
            strings::StrCat("after_MetaOptimizer_iteration_", iteration, "_",
                            optimizer->name(), "_",
                            reinterpret_cast<uintptr_t>(optimized_graph)),
            *optimized_graph);
      }
      for (const auto& verifier : inter_optimizer_verifiers) {
        // TODO(ashwinm): Need to enforce verification_deadline.
        TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
      }
    }
    if (VLOG_IS_ON(4)) {
      DumpGraphDefToFile(
          strings::StrCat("after_MetaOptimizer_iteration_", iteration, "_",
                          reinterpret_cast<uintptr_t>(optimized_graph)),
          *optimized_graph);
    }
    // TODO(ashwinm): Need to enforce verification_deadline.
    for (const auto& verifier : post_optimization_verifiers) {
      TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
    }
  }

  // Run fusion optimizer if requested after all other optimizers since: 1) it
  // doesn't need to be called more than once. 2) we don't want subsequent
  // optimization passes to break the fusion clusters. We could potentially
  // encapsulate the fusion clusters right away, but that will prevent a lot of
  // optimizations from taking place since we don't have shape inference for
  // functions, and we can't optimize across function boundaries.
  if (fusion_optimizer != nullptr) {
    RUN_OPTIMIZER_OR_RETURN_IF_ERROR(fusion_optimizer);
  }

  // ScopedAllocatorOptimizer must run last.
  if (sa_optimizer != nullptr) {
    RUN_OPTIMIZER_OR_RETURN_IF_ERROR(sa_optimizer);
  }

  // Compress the constants in the final graph.
  TF_RETURN_IF_ERROR(CompressConstants(optimized_graph));

  // Record graph optimization result.
  optimization_results_.push_back(optimization_result);

  if (is_optimized) {
    TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
    ReassignColocation(optimized_graph);
    // Make sure that the optimizers preserved the graph version.
    DCHECK_EQ(optimized_graph->versions().producer(),
              item.graph.versions().producer());
  }

  return Status::OK();
}

#undef RUN_OPTIMIZER_OR_RETURN_IF_ERROR

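// Runs a single optimizer over optimized_item, timing the pass and recording
// the outcome in optimization_result. On failure the pre-pass graph is
// restored so that later optimizers still see a valid graph.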
Status MetaOptimizer::RunOptimizer(
    GraphOptimizer* optimizer, Cluster* cluster, GrapplerItem* optimized_item,
    GraphDef* optimized_graph, GraphOptimizationResult* optimization_result) {
  uint64 start_us = Env::Default()->NowMicros();
  // This swaps the current optimized_graph into optimized item and
  // resets optimized_graph to an empty graph.
  optimized_graph->Swap(&optimized_item->graph);
  *optimized_graph = GraphDef();
  optimizer->set_deadline_usec(this->deadline_usec());
  Status status =
      optimizer->Optimize(cluster, *optimized_item, optimized_graph);
  uint64 end_us = Env::Default()->NowMicros();

  string result;
  if (!status.ok()) {
    optimized_graph->Swap(&optimized_item->graph);
    result = status.ToString();
  } else {
    float duration_ms = (end_us - start_us) / 1000.0f;
    result = strings::StrCat(
        PrintSizesBeforeAfter(optimized_item->graph, *optimized_graph),
        ", time = ", duration_ms, "ms.");
  }
  VLOG(1) << optimizer->name() << ": " << result;

  OptimizerResult optimizer_result{optimizer->name(), result};
  optimization_result->results.push_back(optimizer_result);
  return status;
}

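// Top-level entry point. Trims the function library down to functions
// reachable from the graph, optimizes the main graph, and then (unless this
// is a TPU graph) optimizes every reachable, non-parametrized function in the
// library.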
Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                               GraphDef* optimized_graph) {
  VLOG(1) << "Starting optimization for grappler item: " << item.id;
  optimization_results_.clear();

  // Constructs a FunctionLibraryDefinition with functions that are reachable
  // from the nodes of the graph.
  const auto minimized_flib =
      [](const GraphDef& graph) -> FunctionLibraryDefinition {
    return FunctionLibraryDefinition(OpRegistry::Global(), graph.library())
        .ReachableDefinitions(graph);
  };

  // 0. The original graph might contain a huge function library that is
  // mostly unused. This library is copied over by each individual Grappler
  // optimizer, which adds a huge overhead. Before starting optimization
  // passes we just remove all the unreachable functions.
  // TODO(ezhulenev): Construct reachable function library definition directly
  // from the proto without constructing temporary FunctionLibraryDefinition.
  GraphDef trimmed_graph;  // do not copy graph with a potentially huge library
  *trimmed_graph.mutable_node() = item.graph.node();
  *trimmed_graph.mutable_versions() = item.graph.versions();
  *trimmed_graph.mutable_library() = minimized_flib(item.graph).ToProto();

  GrapplerItem trimmed_item = item.WithGraph(std::move(trimmed_graph));

  VLOG(1) << absl::Substitute(
      "Deleted $0 unreachable functions from the graph (library size = $1)",
      item.graph.library().function_size() -
          trimmed_item.graph.library().function_size(),
      trimmed_item.graph.library().function_size());

  // 1. Optimize main graph
  TF_RETURN_IF_ERROR(OptimizeGraph(cluster, trimmed_item, optimized_graph));
  VLOG(1) << "Optimized main graph.";
  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

  // Skip optimizing functions if this is a TPU graph. Currently, Grappler
  // passes do not handle TPU functions correctly in a variety of ways (Note
  // that due to the pre-placement TPU graph rewriting passes, the TPU-related
  // ops are encapsulated away into functions). For example, TPU graphs contain
  // TPUReplicateMetadata node that carries relevant TPU metadata and Grappler
  // passes could prune that away. Grappler passes could also cause issues
  // around shape inference. Since the desired and existing behavior is to not
  // optimize TPU functions with Grappler, this check preserves that.
  if (IsTPUGraphDef(*optimized_graph)) {
    VLOG(2) << "Skipping optimizing funcs for TPU graphs";
    if (VLOG_IS_ON(1)) {
      DumpGraphDefToFile(
          strings::StrCat("after_MetaOptimizer_",
                          reinterpret_cast<uintptr_t>(optimized_graph)),
          *optimized_graph);
    }
    return Status::OK();
  }

  // 2. Optimize functions reachable from the optimized graph.
  FunctionLibraryDefinition flib = minimized_flib(*optimized_graph);

  // Find functions for which we might need to compute a gradient at runtime.
  absl::flat_hash_set<string> differentiable_functions;
  for (const NodeDef& node : optimized_graph->node()) {
    if (IsSymbolicGradient(node)) {
      const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
      if (f_attr) differentiable_functions.insert(f_attr->func().name());
    }
  }

  // Optimize each function only once.
  absl::flat_hash_set<string> optimized_funcs;
  bool optimize_function_library =
      item.optimization_options().optimize_function_library;

  while (optimize_function_library) {
    optimize_function_library = false;

    for (const FunctionDef& func : optimized_graph->library().function()) {
      GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

      const string& func_name = func.signature().name();

      // Skip functions that are not reachable from the optimized graph.
      if (!flib.Contains(func_name)) continue;

      // Skip already optimized functions.
      if (optimized_funcs.find(func_name) != optimized_funcs.end()) continue;

      // Skip parametrized functions (function type or body is defined only at
      // function call time by caller node attributes).
      // They should be specialized to their instantiation type parameters by
      // the function optimizer, before we can optimize function body.
      if (IsParametrized(func)) continue;

      VLOG(3) << "Optimize function: function=" << func_name;

      // Function optimization might specialize nested function calls, so we
      // have to reset the flag and do at least one more pass over the library.
      optimize_function_library = true;
      optimized_funcs.insert(func_name);

      // Make a GrapplerItem from a FunctionDef.
      GrapplerFunctionItem func_item;
      TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
          func, flib, trimmed_item.graph.versions().producer(), &func_item));

      // If we need to compute the gradient of the optimized function at
      // runtime, we can't perform non-differentiable rewrites.
      if (differentiable_functions.find(func_name) !=
          differentiable_functions.end()) {
        func_item.optimization_options().allow_non_differentiable_rewrites =
            false;
      }

      // Function item is allowed to use all devices from the main graph.
      Status added_devices = func_item.AddDevices(item);
      if (!added_devices.ok()) {
        VLOG(3) << added_devices.error_message();
      }

      // We are not allowed to prune certain types of ops from the graph
      // instantiated by the function definition, because we must guarantee
      // function execution semantics wrt side effects (see
      // function_optimizer.cc).
      func_item.optimization_options().allow_pruning_stateful_and_dataset_ops =
          false;

      // Optimize function body graph.
      GraphDef optimized_func_graph;
      TF_RETURN_IF_ERROR(
          OptimizeGraph(cluster, func_item, &optimized_func_graph));

      // Function body optimization might have created new specialized
      // functions for each instantiation context. Add them to the library.
      for (const FunctionDef& func_def :
           optimized_func_graph.library().function()) {
        if (flib.Find(func_def.signature().name()) == nullptr) {
          TF_RETURN_IF_ERROR(flib.AddFunctionDef(func_def));
        }
      }

      // Convert optimized graph back to FunctionDef.
      FunctionDef optimized_func;
      func_item.SwapFunctionBody(std::move(optimized_func_graph));
      TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func));

      // Replace optimized function with a new FunctionDef.
      TF_RETURN_IF_ERROR(flib.ReplaceFunction(func_name, optimized_func));
    }

    // If optimized at least one function, update the graph library.
    if (optimize_function_library) {
      *optimized_graph->mutable_library() = flib.ToProto();
    }
  }

  VLOG(1) << "Optimized " << optimized_funcs.size()
          << " functions: " << str_util::Join(optimized_funcs, ", ");

  if (VLOG_IS_ON(1)) {
    DumpGraphDefToFile(
        strings::StrCat("after_MetaOptimizer_",
                        reinterpret_cast<uintptr_t>(optimized_graph)),
        *optimized_graph);
  }
  return Status::OK();
}

void MetaOptimizer::PrintResult() {
  for (const GraphOptimizationResult& graph_result : optimization_results_) {
    LOG(INFO) << "Optimization results for grappler item: " << graph_result.id;
    for (const OptimizerResult& result : graph_result.results) {
      LOG(INFO) << "  " << result.optimizer_name << ": " << result.result;
    }
  }
}

void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
                             const GraphDef& pruned_graph, double result) {
  // Nothing to do for MetaOptimizer.
}

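// Returns true if the meta optimizer has not been disabled and at least one
// Grappler rewrite is enabled by the given ConfigProto.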
bool MetaOptimizerEnabled(const ConfigProto& cfg) {
  const auto& rewrite_cfg = cfg.graph_options().rewrite_options();
  if (rewrite_cfg.disable_meta_optimizer()) {
    return false;
  }
  return !rewrite_cfg.disable_model_pruning() ||
         rewrite_cfg.layout_optimizer() != RewriterConfig::OFF ||
         rewrite_cfg.function_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.constant_folding() != RewriterConfig::OFF ||
         rewrite_cfg.shape_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.remapping() != RewriterConfig::OFF ||
         rewrite_cfg.arithmetic_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.loop_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.dependency_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.auto_parallel().enable() ||
         rewrite_cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
         rewrite_cfg.debug_stripper() == RewriterConfig::ON ||
         rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
         rewrite_cfg.pin_to_host_optimization() == RewriterConfig::ON ||
         !rewrite_cfg.optimizers().empty() ||
         !rewrite_cfg.custom_optimizers().empty();
}

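// Convenience wrapper: constructs a MetaOptimizer from the ConfigProto, sets
// its deadline, and runs it; on failure the original graph is returned
// unchanged.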
Status RunMetaOptimizer(const GrapplerItem& item, const ConfigProto& cfg,
                        DeviceBase* cpu_device, Cluster* cluster,
                        GraphDef* optimized_graph) {
  MetaOptimizer optimizer(cpu_device, cfg);
  optimizer.set_deadline_usec(
      DeadlineMicroSeconds(cfg.graph_options().rewrite_options()));
  Status status = optimizer.Optimize(cluster, item, optimized_graph);
  if (!status.ok()) {
    *optimized_graph = item.graph;
  }
  return status;
}

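// Applies Grappler to a tensorflow::Graph in place: builds a GrapplerItem
// from the graph, the fetch/keep node names, and the function library; runs
// the meta optimizer on a virtual cluster for device_set; and converts the
// result back into a Graph, copying optimized functions back into flib.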
Status OptimizeGraph(
    std::vector<string> ret_node_names, std::vector<string> keep_node_names,
    FunctionLibraryDefinition* flib, const DeviceSet& device_set,
    Device* cpu_device, const ConfigProto& config_proto,
    const string& grappler_item_id,
    const GrapplerItem::OptimizationOptions& optimization_options,
    std::unique_ptr<tensorflow::Graph>* g) {
  if (!tensorflow::grappler::MetaOptimizerEnabled(config_proto)) {
    return Status::OK();
  }

  tensorflow::grappler::GrapplerItem item;
  item.id = grappler_item_id;
  item.optimization_options() = optimization_options;

  // Add all available devices so that inlined functions can be placed.
  for (const Device* d : device_set.devices()) {
    Status added_device = item.AddDevice(d->name());
    if (!added_device.ok()) VLOG(3) << added_device.error_message();
  }

  // Add fetches so that the graph can be pruned.
  item.fetch.swap(ret_node_names);

  // Add nodes that can't be removed from the graph.
  item.keep_ops = std::move(keep_node_names);

  (*g)->ToGraphDef(&item.graph);

  if (flib) {
    *item.graph.mutable_library() = flib->ToProto();
  }

  tensorflow::GraphDef out_graph;

  tensorflow::grappler::VirtualCluster cluster(&device_set);

  // TODO(nareshmodi): Consider adding and using the more generic GraphOptions
  // proto (which also contains the OptimizerOptions).
  TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
      item, config_proto, cpu_device, &cluster, &out_graph));

  std::unique_ptr<tensorflow::Graph> optimized_graph(
      new tensorflow::Graph(OpRegistry::Global()));
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                            out_graph, optimized_graph.get()));

  // Copy optimized functions back to the overlay lib.
  if (flib) {
    for (const FunctionDef& fdef : out_graph.library().function()) {
      const string& func_name = fdef.signature().name();
      if (flib->Contains(func_name)) {
        TF_RETURN_IF_ERROR(flib->ReplaceFunction(func_name, fdef));
      } else {
        TF_RETURN_IF_ERROR(flib->AddFunctionDef(fdef));
      }
    }
  }

  *g = std::move(optimized_graph);

  // The graph conversion sets the requested device names but not the
  // assigned device names. However, since at this point the graph is
  // placed, TF expects an assigned device name for every node. Therefore
  // we copy the requested device into the assigned device field.
  for (Node* node : (*g)->nodes()) {
    if (node->IsOp() && node->assigned_device_name().empty()) {
      if (node->requested_device().empty()) {
        return errors::Internal(
            "Either placer did not place the node or Grappler did not "
            "copy the assigned device. Contact Grappler team since latter "
            "is more likely. Node=",
            node->name(), " Graph: ", (*g)->ToGraphDefDebug().DebugString());
      }
      node->set_assigned_device_name(node->requested_device());
    }
  }

  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow