/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include "tensorflow/core/grappler/optimizers/implementation_selector.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
#include "tensorflow/core/grappler/utils/colocation.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/verifiers/structure_verifier.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace grappler {

namespace {

constexpr int kDefaultNumberOfIterations = 2;
constexpr int kDefaultMinGraphNodes = 4;

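// Returns the number of edges in the graph, counted as the total number of
// inputs (both data and control) across all nodes.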
int64 NumEdges(const GraphDef& graph) {
  int64 num_edges = 0;
  for (const auto& node : graph.node()) {
    num_edges += node.input_size();
  }
  return num_edges;
}

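// Formats a human-readable summary of the optimized graph size: node and edge
// counts after optimization, with the delta relative to the input graph.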
string PrintSizesBeforeAfter(const GraphDef& before, const GraphDef& after) {
  return strings::StrCat("Graph size after: ", after.node_size(), " nodes (",
                         after.node_size() - before.node_size(), "), ",
                         NumEdges(after), " edges (",
                         NumEdges(after) - NumEdges(before), ")");
}

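// Returns the configured number of meta optimizer iterations, falling back to
// kDefaultNumberOfIterations when DEFAULT_NUM_ITERS is requested.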
int NumIterations(const RewriterConfig& cfg) {
  return cfg.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS
             ? kDefaultNumberOfIterations
             : cfg.meta_optimizer_iterations();
}

// Check if optimizer is allowed to run only once.
bool IsRunOnceOptimizer(const string& name) {
  return name == "layout" || name == "memory_optimizer" ||
         name == "loop_optimizer";
}

// Check if the graphdef contains nodes that indicate TPU execution.
bool IsTPUGraphDef(const GraphDef& def) {
  for (const auto& node : def.node()) {
    if (node.op() == "TPUCompile" || node.op() == "TPUPartitionedCall") {
      return true;
    }
  }
  return false;
}

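// Converts the meta optimizer timeout from the RewriterConfig into an absolute
// deadline in microseconds: 0 (no deadline) for a negative timeout, now plus
// five minutes for the default timeout of 0, and now plus the configured
// timeout otherwise.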
uint64 DeadlineMicroSeconds(const RewriterConfig& cfg) {
  const uint64 kFiveMinutesInUsec = 5 * 60 * 1000 * 1000;
  if (cfg.meta_optimizer_timeout_ms() < 0) {
    return 0;
  } else {
    return cfg.meta_optimizer_timeout_ms() == 0
               ? Env::Default()->NowMicros() + kFiveMinutesInUsec
               : Env::Default()->NowMicros() +
                     cfg.meta_optimizer_timeout_ms() * 1000;
  }
}

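// Rewrites the 'value' attribute of Const and HostConst nodes in place to a
// more compact tensor proto encoding, shrinking the serialized graph.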
Status CompressConstants(GraphDef* graph) {
  for (int i = 0; i < graph->node_size(); ++i) {
    NodeDef* node = graph->mutable_node(i);
    if ((IsConstant(*node) || IsHostConstant(*node)) &&
        HasNodeAttr(*node, "value")) {
      AttrValue& attr_val = (*node->mutable_attr())["value"];
      tensor::CompressTensorProtoInPlace(attr_val.mutable_tensor());
    }
  }
  return Status::OK();
}

}  // namespace

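// MK_OPT instantiates a default optimizer when `optimizer` matches the given
// name; it is only meaningful inside MetaOptimizer::MakeNewOptimizer below.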
#define MK_OPT(NAME, VALUE) \
  if (optimizer == NAME) return std::unique_ptr<GraphOptimizer>(VALUE)

std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
    const string& optimizer) const {
  MK_OPT("pruning", new ModelPruner());
  MK_OPT("function", new FunctionOptimizer(cfg_.function_optimization()));
  MK_OPT("constfold", new ConstantFolding(cpu_device_));
  MK_OPT("shape", new ShapeOptimizer());
  MK_OPT("remap", new Remapper(cfg_.remapping()));
  MK_OPT("layout", new LayoutOptimizer());
  MK_OPT("memory", new MemoryOptimizer(RewriterConfig::MANUAL));
  MK_OPT("arithmetic", new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
  MK_OPT("autoparallel", new AutoParallel(cfg_.auto_parallel().num_replicas()));
  MK_OPT("loop", new LoopOptimizer(cfg_.loop_optimization(), cpu_device_));
  MK_OPT("dependency", new DependencyOptimizer(cfg_.dependency_optimization()));
  MK_OPT("debug_stripper", new DebugStripper());
  MK_OPT("scoped_allocator",
         new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
                                      cfg_.scoped_allocator_opts()));
  MK_OPT("pin_to_host",
         new PinToHostOptimizer(cfg_.pin_to_host_optimization()));

  return std::unique_ptr<GraphOptimizer>();
}

#undef MK_OPT

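// cpu_device is used by optimizers that evaluate nodes on the host (e.g.
// constant folding and loop optimization); it must be a CPU device or nullptr.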
MetaOptimizer::MetaOptimizer(DeviceBase* cpu_device, const ConfigProto& cfg)
    : cpu_device_(cpu_device),
      config_proto_(cfg),
      cfg_(*config_proto_.mutable_graph_options()->mutable_rewrite_options()) {
  DCHECK(cpu_device_ == nullptr ||
         cpu_device_->attributes().device_type() == "CPU");
}

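// Builds the default optimizer pipeline from the RewriterConfig, in a fixed
// order (pruning, function inlining, constant folding, ..., memory,
// auto-parallel, scoped allocators), then appends any configured custom
// optimizers.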
Status MetaOptimizer::InitializeOptimizers(
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  if (cfg_.disable_meta_optimizer()) {
    return Status::OK();
  }
  if (!cfg_.disable_model_pruning()) {
    optimizers->push_back(MakeUnique<ModelPruner>());
  }
  if (cfg_.implementation_selector() != RewriterConfig::OFF) {
    optimizers->push_back(MakeUnique<ImplementationSelector>());
  }
  if (cfg_.function_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(
        MakeUnique<FunctionOptimizer>(cfg_.function_optimization()));
  }
  if (cfg_.debug_stripper() == RewriterConfig::ON) {
    optimizers->push_back(MakeUnique<DebugStripper>());
  }
  if (cfg_.constant_folding() != RewriterConfig::OFF) {
    optimizers->push_back(
        MakeUnique<ConstantFolding>(cfg_.constant_folding(), cpu_device_));
  }
  if (cfg_.shape_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(MakeUnique<ShapeOptimizer>());
  }
  if (cfg_.remapping() != RewriterConfig::OFF) {
    optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
  }
  if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
    optimizers->push_back(MakeUnique<PinToHostOptimizer>());
  }
  if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(
        MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
  }
  if (cfg_.loop_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(
        MakeUnique<LoopOptimizer>(cfg_.loop_optimization(), cpu_device_));
  }
  if (cfg_.dependency_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(
        MakeUnique<DependencyOptimizer>(cfg_.dependency_optimization()));
  }
  if (cfg_.layout_optimizer() != RewriterConfig::OFF) {
    optimizers->push_back(MakeUnique<LayoutOptimizer>());
  }
  if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
    if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
      optimizers->push_back(
          // Use the default target node name prefix "gradients/".
          MakeUnique<MemoryOptimizer>(cfg_.memory_optimization()));
    } else {
      optimizers->push_back(MakeUnique<MemoryOptimizer>(
          cfg_.memory_optimization(),
          cfg_.memory_optimizer_target_node_name_scope()));
    }
  }
  if (cfg_.auto_parallel().enable()) {
    optimizers->push_back(
        MakeUnique<AutoParallel>(cfg_.auto_parallel().num_replicas()));
  }
  if (cfg_.scoped_allocator_optimization()) {
    optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
        cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
  }
  return InitializeCustomGraphOptimizers(std::set<string>(), optimizers);
}

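// Builds the optimizer pipeline from the names listed in cfg_.optimizers().
// Names that match a default optimizer are created directly; everything else
// is looked up in the custom graph optimizer registry.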
Status MetaOptimizer::InitializeOptimizersByName(
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  std::set<string> initialized_custom_optimizers;
  for (const string& optimizer_name : cfg_.optimizers()) {
    auto optimizer = MakeNewOptimizer(optimizer_name);
    if (optimizer) {
      VLOG(2) << "Registered default graph optimizer: " << optimizer_name;
      optimizers->push_back(std::move(optimizer));
      continue;
    }

    auto custom_optimizer =
        CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);

    if (custom_optimizer) {
      VLOG(2) << "Registered custom graph optimizer: " << optimizer_name;
      TF_RETURN_IF_ERROR(custom_optimizer->Init(
          GetCustomGraphOptimizerConfig(optimizer_name)));
      optimizers->push_back(std::move(custom_optimizer));
      initialized_custom_optimizers.insert(optimizer_name);
    } else {
      VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
    }
  }
  return InitializeCustomGraphOptimizers(initialized_custom_optimizers,
                                         optimizers);
}

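// Instantiates the optimizers listed in cfg_.custom_optimizers(), skipping any
// that were already initialized by name above. Unknown names fall back to the
// corresponding default optimizer, so custom and default optimizers can be
// interleaved in any order.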
Status MetaOptimizer::InitializeCustomGraphOptimizers(
    const std::set<string>& pre_initialized_optimizers,
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  for (const auto& optimizer_config : cfg_.custom_optimizers()) {
    if (pre_initialized_optimizers.find(optimizer_config.name()) !=
        pre_initialized_optimizers.end()) {
      continue;
    }

    auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
        optimizer_config.name());

    if (custom_optimizer) {
      VLOG(2) << "Registered custom configurable graph optimizer: "
              << optimizer_config.name();
      TF_RETURN_IF_ERROR(custom_optimizer->Init(&optimizer_config));
      optimizers->push_back(std::move(custom_optimizer));
    } else {
      // If there is no custom optimizer with the given name, try to initialize
      // a default optimizer. This way, custom configurable optimizers can be
      // mixed with default optimizers in any order.
      auto optimizer = MakeNewOptimizer(optimizer_config.name());
      if (optimizer) {
        VLOG(2) << "Registered default graph optimizer: "
                << optimizer_config.name();
        optimizers->push_back(std::move(optimizer));
        continue;
      }
      VLOG(2) << "Can't register an optimizer by name: "
              << optimizer_config.name();
    }
  }
  return Status::OK();
}

const RewriterConfig::CustomGraphOptimizer*
MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const {
  for (const auto& config : cfg_.custom_optimizers()) {
    if (config.name() == name) {
      return &config;
    }
  }
  return nullptr;
}

void MetaOptimizer::InitializeVerifiers(
    std::vector<std::unique_ptr<GraphVerifier>>* inter_optimizer_verifiers,
    std::vector<std::unique_ptr<GraphVerifier>>* post_optimization_verifiers)
    const {
  if (cfg_.inter_optimizer_verifier_config().structure_verifier() ==
      VerifierConfig::ON) {
    inter_optimizer_verifiers->push_back(MakeUnique<StructureVerifier>());
  }
  if (cfg_.post_optimization_verifier_config().structure_verifier() ==
      VerifierConfig::ON) {
    post_optimization_verifiers->push_back(MakeUnique<StructureVerifier>());
  }
}

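// Runs `optimizer` on the current graph. On success, marks the graph as
// optimized. On failure, the error is returned only if
// fail_on_optimizer_errors is set; otherwise the pass is skipped and its
// status is kept in optimization_result by RunOptimizer.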
#define RUN_OPTIMIZER_OR_RETURN_IF_ERROR(optimizer)                            \
  {                                                                             \
    const Status status = RunOptimizer(optimizer, cluster, &optimized_item,     \
                                       optimized_graph, &optimization_result);  \
    if (status.ok()) {                                                          \
      is_optimized = true;                                                      \
    } else if (cfg_.fail_on_optimizer_errors()) {                               \
      VLOG(2) << "Optimizer '" << optimizer->name() << "' failed: " << status;  \
      TF_RETURN_IF_ERROR(status);                                               \
    }                                                                           \
  }

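// Optimizes a single graph (the main graph or a function body): builds the
// optimizer and verifier pipelines, then runs up to NumIterations(cfg_) passes
// over the graph. The xla-fusion and scoped allocator optimizers are deferred
// so that each runs at most once, after all other passes.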
Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
                                    GraphDef* optimized_graph) {
  int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
                                                    : cfg_.min_graph_nodes();
  if (item.graph.node_size() < min_graph_nodes) {
    VLOG(3) << "Skipping optimization, graph has fewer than "
            << min_graph_nodes << " nodes.";
    *optimized_graph = item.graph;
    return Status::OK();
  }

  std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
  if (cfg_.optimizers().empty()) {
    TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers));
  } else {
    TF_RETURN_IF_ERROR(InitializeOptimizersByName(&optimizers));
  }

  // Initialize the configured verifiers.
  std::vector<std::unique_ptr<GraphVerifier>> inter_optimizer_verifiers;
  std::vector<std::unique_ptr<GraphVerifier>> post_optimization_verifiers;
  InitializeVerifiers(&inter_optimizer_verifiers, &post_optimization_verifiers);
  if (inter_optimizer_verifiers.empty()) {
    VLOG(2) << "No inter optimizer verifiers have been configured";
  } else {
    VLOG(2) << inter_optimizer_verifiers.size()
            << " inter optimizer verifiers have been configured";
  }
  if (post_optimization_verifiers.empty()) {
    VLOG(2) << "No post optimization verifiers have been configured";
  } else {
    VLOG(2) << post_optimization_verifiers.size()
            << " post optimization verifiers have been configured";
  }

  VLOG(2) << "Optimize GrapplerItem: item.id=" << item.id
          << " num_optimizers=" << optimizers.size()
          << ", num nodes = " << item.graph.node_size();

  if (optimizers.empty()) {
    VLOG(3) << "Skipping graph optimization, no optimizers registered";
    *optimized_graph = item.graph;
    return Status::OK();
  }

  // Invariant: optimized_graph contains the most recently optimized version of
  // the graph.
  GrapplerItem optimized_item = item;
  optimized_graph->Swap(&optimized_item.graph);

  bool is_optimized = false;
  GraphOptimizationResult optimization_result(item.id);
  GraphOptimizer* fusion_optimizer = nullptr;
  GraphOptimizer* sa_optimizer = nullptr;

  for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) {
    // Don't bother optimizing further if the graph is already tiny.
    if (optimized_graph->node_size() < min_graph_nodes) {
      VLOG(3) << "Stopping after iteration " << iteration
              << ", graph is tiny (#nodes = " << optimized_graph->node_size()
              << " < " << min_graph_nodes << ")";
      break;
    }

    VLOG(4) << "Starting optimization iteration " << iteration;
    if (VLOG_IS_ON(4)) {
      DumpGraphDefToFile(
          strings::StrCat("before_MetaOptimizer_iteration_", iteration, "_",
                          reinterpret_cast<uintptr_t>(optimized_graph)),
          *optimized_graph);
    }
    for (const auto& optimizer : optimizers) {
      GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
      // Some optimizers can run only once.
      if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
      // Some must run only on the last iteration.
      if (optimizer->name() == "scoped_allocator_optimizer") {
        if (sa_optimizer == nullptr) sa_optimizer = optimizer.get();
        continue;
      }
      if (optimizer->name() == "xla-fusion") {
        if (fusion_optimizer == nullptr) fusion_optimizer = optimizer.get();
        continue;
      }
      RUN_OPTIMIZER_OR_RETURN_IF_ERROR(optimizer.get());

      if (VLOG_IS_ON(4)) {
        DumpGraphDefToFile(
            strings::StrCat("after_MetaOptimizer_iteration_", iteration, "_",
                            optimizer->name(), "_",
                            reinterpret_cast<uintptr_t>(optimized_graph)),
            *optimized_graph);
      }
      for (const auto& verifier : inter_optimizer_verifiers) {
        // TODO(ashwinm): Need to enforce verification_deadline.
        TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
      }
    }
    if (VLOG_IS_ON(4)) {
      DumpGraphDefToFile(
          strings::StrCat("after_MetaOptimizer_iteration_", iteration, "_",
                          reinterpret_cast<uintptr_t>(optimized_graph)),
          *optimized_graph);
    }
    // TODO(ashwinm): Need to enforce verification_deadline.
    for (const auto& verifier : post_optimization_verifiers) {
      TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
    }
  }

  // Run the fusion optimizer, if requested, after all other optimizers, since
  // 1) it doesn't need to be called more than once, and 2) we don't want
  // subsequent optimization passes to break the fusion clusters. We could
  // potentially encapsulate the fusion clusters right away, but that would
  // prevent a lot of optimizations from taking place, since we don't have
  // shape inference for functions and we can't optimize across function
  // boundaries.
  if (fusion_optimizer != nullptr) {
    RUN_OPTIMIZER_OR_RETURN_IF_ERROR(fusion_optimizer);
  }

  // ScopedAllocatorOptimizer must run last.
  if (sa_optimizer != nullptr) {
    RUN_OPTIMIZER_OR_RETURN_IF_ERROR(sa_optimizer);
  }

  // Compress the constants in the final graph.
  TF_RETURN_IF_ERROR(CompressConstants(optimized_graph));

  // Record graph optimization result.
  optimization_results_.push_back(optimization_result);

  if (is_optimized) {
    TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
    ReassignColocation(optimized_graph);
    // Make sure that the optimizers preserved the graph version.
    DCHECK_EQ(optimized_graph->versions().producer(),
              item.graph.versions().producer());
  }

  return Status::OK();
}

#undef RUN_OPTIMIZER_OR_RETURN_IF_ERROR

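// Runs a single optimizer pass: swaps the current graph into optimized_item,
// invokes the optimizer with the meta optimizer's deadline, and records the
// resulting size delta and wall time (or the error) in optimization_result.
// On failure the input graph is restored into optimized_graph so that later
// passes can continue from the last good graph.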
Status MetaOptimizer::RunOptimizer(
    GraphOptimizer* optimizer, Cluster* cluster, GrapplerItem* optimized_item,
    GraphDef* optimized_graph, GraphOptimizationResult* optimization_result) {
  uint64 start_us = Env::Default()->NowMicros();
  // This swaps the current optimized_graph into optimized item and
  // resets optimized_graph to an empty graph.
  optimized_graph->Swap(&optimized_item->graph);
  *optimized_graph = GraphDef();
  optimizer->set_deadline_usec(this->deadline_usec());
  Status status =
      optimizer->Optimize(cluster, *optimized_item, optimized_graph);
  uint64 end_us = Env::Default()->NowMicros();

  string result;
  if (!status.ok()) {
    optimized_graph->Swap(&optimized_item->graph);
    result = status.ToString();
  } else {
    float duration_ms = (end_us - start_us) / 1000.0f;
    result = strings::StrCat(
        PrintSizesBeforeAfter(optimized_item->graph, *optimized_graph),
        ", time = ", duration_ms, "ms.");
  }
  VLOG(1) << optimizer->name() << ": " << result;

  OptimizerResult optimizer_result{optimizer->name(), result};
  optimization_result->results.push_back(optimizer_result);
  return status;
}

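// Top-level entry point for a GrapplerItem. First trims unreachable functions
// from the function library and optimizes the main graph; then, unless the
// graph targets TPU, optimizes every reachable, non-parametrized function body
// (repeating until no new functions are specialized) and writes the updated
// function library back into the optimized graph.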
Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                               GraphDef* optimized_graph) {
  VLOG(1) << "Starting optimization for grappler item: " << item.id;
  optimization_results_.clear();

  // Constructs a FunctionLibraryDefinition with functions that are reachable
  // from the nodes of the graph.
  const auto minimized_flib =
      [](const GraphDef& graph) -> FunctionLibraryDefinition {
    return FunctionLibraryDefinition(OpRegistry::Global(), graph.library())
        .ReachableDefinitions(graph);
  };

  // 0. The original graph might contain a huge function library that is mostly
  // unused. This library is copied over by each individual Grappler optimizer,
  // which adds a huge overhead. Before starting optimization passes we just
  // remove all the unreachable functions.
  // TODO(ezhulenev): Construct reachable function library definition directly
  // from the proto without constructing temporary FunctionLibraryDefinition.
  GraphDef trimmed_graph;  // do not copy graph with a potentially huge library
  *trimmed_graph.mutable_node() = item.graph.node();
  *trimmed_graph.mutable_versions() = item.graph.versions();
  *trimmed_graph.mutable_library() = minimized_flib(item.graph).ToProto();

  GrapplerItem trimmed_item = item.WithGraph(std::move(trimmed_graph));

  VLOG(1) << absl::Substitute(
      "Deleted $0 unreachable functions from the graph (library size = $1)",
      item.graph.library().function_size() -
          trimmed_item.graph.library().function_size(),
      trimmed_item.graph.library().function_size());

  // 1. Optimize main graph
  TF_RETURN_IF_ERROR(OptimizeGraph(cluster, trimmed_item, optimized_graph));
  VLOG(1) << "Optimized main graph.";
  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

  // Skip optimizing functions if this is a TPU graph. Currently, Grappler
  // passes do not handle TPU functions correctly in a variety of ways (note
  // that due to the pre-placement TPU graph rewriting passes, the TPU-related
  // ops are encapsulated away into functions). For example, TPU graphs contain
  // a TPUReplicateMetadata node that carries relevant TPU metadata, which
  // Grappler passes could prune away. Grappler passes could also cause issues
  // around shape inference. Since the desired and existing behavior is to not
  // optimize TPU functions with Grappler, this check preserves that.
  if (IsTPUGraphDef(*optimized_graph)) {
    VLOG(2) << "Skipping optimizing funcs for TPU graphs";
    if (VLOG_IS_ON(1)) {
      DumpGraphDefToFile(
          strings::StrCat("after_MetaOptimizer_",
                          reinterpret_cast<uintptr_t>(optimized_graph)),
          *optimized_graph);
    }
    return Status::OK();
  }

  // 2. Optimize functions reachable from the optimized graph.
  FunctionLibraryDefinition flib = minimized_flib(*optimized_graph);

  // Find functions for which we might need to compute a gradient at runtime.
  absl::flat_hash_set<string> differentiable_functions;
  for (const NodeDef& node : optimized_graph->node()) {
    if (IsSymbolicGradient(node)) {
      const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
      if (f_attr) differentiable_functions.insert(f_attr->func().name());
    }
  }

  // Optimize each function only once.
  absl::flat_hash_set<string> optimized_funcs;
  bool optimize_function_library =
      item.optimization_options().optimize_function_library;

  while (optimize_function_library) {
    optimize_function_library = false;

    for (const FunctionDef& func : optimized_graph->library().function()) {
      GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

      const string& func_name = func.signature().name();

      // Skip functions that are not reachable from the optimized graph.
      if (!flib.Contains(func_name)) continue;

      // Skip already optimized functions.
      if (optimized_funcs.find(func_name) != optimized_funcs.end()) continue;

      // Skip parametrized functions (the function type or body is defined only
      // at function call time by caller node attributes). They should be
      // specialized to their instantiation type parameters by the function
      // optimizer before we can optimize the function body.
      if (IsParametrized(func)) continue;

      VLOG(3) << "Optimize function: function=" << func_name;

      // Function optimization might specialize nested function calls, so we
      // have to reset the flag and do at least one more pass over the library.
      optimize_function_library = true;
      optimized_funcs.insert(func_name);

      // Make a GrapplerItem from a FunctionDef.
      GrapplerFunctionItem func_item;
      TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
          func, flib, trimmed_item.graph.versions().producer(), &func_item));

      // If we need to compute the gradient of the optimized function at
      // runtime, we can't perform non-differentiable rewrites.
      if (differentiable_functions.find(func_name) !=
          differentiable_functions.end()) {
        func_item.optimization_options().allow_non_differentiable_rewrites =
            false;
      }

      // Function item is allowed to use all devices from the main graph.
      Status added_devices = func_item.AddDevices(item);
      if (!added_devices.ok()) {
        VLOG(3) << added_devices.error_message();
      }

      // We are not allowed to prune certain types of ops from the graph
      // instantiated by the function definition, because we must guarantee
      // function execution semantics wrt side effects (see
      // function_optimizer.cc).
      func_item.optimization_options().allow_pruning_stateful_and_dataset_ops =
          false;

      // Optimize function body graph.
      GraphDef optimized_func_graph;
      TF_RETURN_IF_ERROR(
          OptimizeGraph(cluster, func_item, &optimized_func_graph));

      // Function body optimization might have created new specialized
      // functions for each instantiation context. Add them to the library.
      for (const FunctionDef& func_def :
           optimized_func_graph.library().function()) {
        if (flib.Find(func_def.signature().name()) == nullptr) {
          TF_RETURN_IF_ERROR(flib.AddFunctionDef(func_def));
        }
      }

      // Convert optimized graph back to FunctionDef.
      FunctionDef optimized_func;
      func_item.SwapFunctionBody(std::move(optimized_func_graph));
      TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func));

      // Replace optimized function with a new FunctionDef.
      TF_RETURN_IF_ERROR(flib.ReplaceFunction(func_name, optimized_func));
    }

    // If optimized at least one function, update the graph library.
    if (optimize_function_library) {
      *optimized_graph->mutable_library() = flib.ToProto();
    }
  }

  VLOG(1) << "Optimized " << optimized_funcs.size()
          << " functions: " << str_util::Join(optimized_funcs, ", ");

  if (VLOG_IS_ON(1)) {
    DumpGraphDefToFile(
        strings::StrCat("after_MetaOptimizer_",
                        reinterpret_cast<uintptr_t>(optimized_graph)),
        *optimized_graph);
  }
  return Status::OK();
}

void MetaOptimizer::PrintResult() {
  for (const GraphOptimizationResult& graph_result : optimization_results_) {
    LOG(INFO) << "Optimization results for grappler item: " << graph_result.id;
    for (const OptimizerResult& result : graph_result.results) {
      LOG(INFO) << "  " << result.optimizer_name << ": " << result.result;
    }
  }
}

void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
                             const GraphDef& pruned_graph, double result) {
  // Nothing to do for MetaOptimizer.
}

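// Returns true if the config enables at least one rewriter (or names any
// optimizers explicitly) and the meta optimizer is not globally disabled.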
bool MetaOptimizerEnabled(const ConfigProto& cfg) {
  const auto& rewrite_cfg = cfg.graph_options().rewrite_options();
  if (rewrite_cfg.disable_meta_optimizer()) {
    return false;
  }
  return !rewrite_cfg.disable_model_pruning() ||
         rewrite_cfg.layout_optimizer() != RewriterConfig::OFF ||
         rewrite_cfg.function_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.constant_folding() != RewriterConfig::OFF ||
         rewrite_cfg.shape_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.remapping() != RewriterConfig::OFF ||
         rewrite_cfg.arithmetic_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.loop_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.dependency_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.auto_parallel().enable() ||
         rewrite_cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
         rewrite_cfg.debug_stripper() == RewriterConfig::ON ||
         rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
         rewrite_cfg.pin_to_host_optimization() == RewriterConfig::ON ||
         !rewrite_cfg.optimizers().empty() ||
         !rewrite_cfg.custom_optimizers().empty();
}

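// Convenience entry point: builds a MetaOptimizer from the ConfigProto, sets
// its deadline from the RewriterConfig timeout, and runs it. If optimization
// fails, the original item graph is returned unchanged (along with the error).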
Status RunMetaOptimizer(const GrapplerItem& item, const ConfigProto& cfg,
                        DeviceBase* cpu_device, Cluster* cluster,
                        GraphDef* optimized_graph) {
  MetaOptimizer optimizer(cpu_device, cfg);
  optimizer.set_deadline_usec(
      DeadlineMicroSeconds(cfg.graph_options().rewrite_options()));
  Status status = optimizer.Optimize(cluster, item, optimized_graph);
  if (!status.ok()) {
    *optimized_graph = item.graph;
  }
  return status;
}

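// Runs the meta optimizer over a tensorflow::Graph: converts it (plus the
// optional function library) into a GrapplerItem with the given fetch and keep
// nodes and the available devices, optimizes it, converts the result back into
// *g, copies optimized functions into flib, and fills in assigned device names
// from the requested devices.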
Status OptimizeGraph(
    std::vector<string> ret_node_names, std::vector<string> keep_node_names,
    FunctionLibraryDefinition* flib, const DeviceSet& device_set,
    Device* cpu_device, const ConfigProto& config_proto,
    const string& grappler_item_id,
    const GrapplerItem::OptimizationOptions& optimization_options,
    std::unique_ptr<tensorflow::Graph>* g) {
  if (!tensorflow::grappler::MetaOptimizerEnabled(config_proto)) {
    return Status::OK();
  }

  tensorflow::grappler::GrapplerItem item;
  item.id = grappler_item_id;
  item.optimization_options() = optimization_options;

  // Add all available devices so that inlined functions can be placed.
  for (const Device* d : device_set.devices()) {
    Status added_device = item.AddDevice(d->name());
    if (!added_device.ok()) VLOG(3) << added_device.error_message();
  }

  // Add fetches so that the graph can be pruned.
  item.fetch.swap(ret_node_names);

  // Add nodes that can't be removed from the graph.
  item.keep_ops = std::move(keep_node_names);

  (*g)->ToGraphDef(&item.graph);

  if (flib) {
    *item.graph.mutable_library() = flib->ToProto();
  }

  tensorflow::GraphDef out_graph;

  tensorflow::grappler::VirtualCluster cluster(&device_set);

  // TODO(nareshmodi): Consider adding and using the more generic GraphOptions
  // proto (which also contains the OptimizerOptions).
  TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
      item, config_proto, cpu_device, &cluster, &out_graph));

  std::unique_ptr<tensorflow::Graph> optimized_graph(
      new tensorflow::Graph(OpRegistry::Global()));
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                            out_graph, optimized_graph.get()));

  // Copy optimized functions back to the overlay lib.
  if (flib) {
    for (const FunctionDef& fdef : out_graph.library().function()) {
      const string& func_name = fdef.signature().name();
      if (flib->Contains(func_name)) {
        TF_RETURN_IF_ERROR(flib->ReplaceFunction(func_name, fdef));
      } else {
        TF_RETURN_IF_ERROR(flib->AddFunctionDef(fdef));
      }
    }
  }

  *g = std::move(optimized_graph);

  // The graph conversion sets the requested device names but not the assigned
  // device names. However, since at this point the graph is placed, TF expects
  // an assigned device name for every node. Therefore we copy the requested
  // device into the assigned device field.
  for (Node* node : (*g)->nodes()) {
    if (node->IsOp() && node->assigned_device_name().empty()) {
      if (node->requested_device().empty()) {
        return errors::Internal(
            "Either placer did not place the node or Grappler did not "
            "copy the assigned device. Contact Grappler team since latter "
            "is more likely. Node=",
            node->name(), " Graph: ", (*g)->ToGraphDefDebug().DebugString());
      }
      node->set_assigned_device_name(node->requested_device());
    }
  }

  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow