1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
17
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/str_join.h"
20 #include "absl/strings/substitute.h"
21 #include "tensorflow/core/common_runtime/function.h"
22 #include "tensorflow/core/common_runtime/graph_constructor.h"
23 #include "tensorflow/core/common_runtime/metrics.h"
24 #include "tensorflow/core/framework/dataset.h"
25 #include "tensorflow/core/framework/function.pb.h"
26 #include "tensorflow/core/framework/tensor_shape.pb.h"
27 #include "tensorflow/core/framework/tensor_util.h"
28 #include "tensorflow/core/framework/versions.pb.h"
29 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
30 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
31 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
32 #include "tensorflow/core/grappler/optimizers/auto_parallel.h"
33 #include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
34 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
35 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
36 #include "tensorflow/core/grappler/optimizers/debug_stripper.h"
37 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
38 #include "tensorflow/core/grappler/optimizers/function_optimizer.h"
39 #include "tensorflow/core/grappler/optimizers/generic_layout_optimizer.h"
40 #include "tensorflow/core/grappler/optimizers/implementation_selector.h"
41 #include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
42 #include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
43 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
44 #include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
45 #include "tensorflow/core/grappler/optimizers/remapper.h"
46 #include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
47 #include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
48 #include "tensorflow/core/grappler/utils/canonicalizer.h"
49 #include "tensorflow/core/grappler/utils/colocation.h"
50 #include "tensorflow/core/grappler/utils/functions.h"
51 #include "tensorflow/core/grappler/utils/topological_sort.h"
52 #include "tensorflow/core/grappler/utils/tpu.h"
53 #include "tensorflow/core/grappler/verifiers/structure_verifier.h"
54 #include "tensorflow/core/lib/core/status.h"
55 #include "tensorflow/core/lib/gtl/map_util.h"
56 #include "tensorflow/core/util/dump_graph.h"
57 #include "tensorflow/core/util/ptr_util.h"
58 #include "tensorflow/core/util/util.h"
59 #include "tensorflow/core/util/xla_config_registry.h"
60
61 namespace tensorflow {
62 namespace grappler {
63
64 namespace {
65
66 constexpr int kDefaultNumberOfIterations = 2;
67 constexpr int kDefaultMinGraphNodes = 4;
68
NumEdges(const GraphDef & graph)69 int64 NumEdges(const GraphDef& graph) {
70 int64_t num_edges = 0;
71 for (const auto& node : graph.node()) {
72 num_edges += node.input_size();
73 }
74 return num_edges;
75 }
76
PrintSizesBeforeAfter(const GraphDef & before,const GraphDef & after)77 string PrintSizesBeforeAfter(const GraphDef& before, const GraphDef& after) {
78 return strings::StrCat("Graph size after: ", after.node_size(), " nodes (",
79 after.node_size() - before.node_size(), "), ",
80 NumEdges(after), " edges (",
81 NumEdges(after) - NumEdges(before), ")");
82 }
83
NumIterations(const RewriterConfig & cfg)84 int NumIterations(const RewriterConfig& cfg) {
85 return cfg.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS
86 ? kDefaultNumberOfIterations
87 : cfg.meta_optimizer_iterations();
88 }
89
90 // Check if optimizer is allowed to run only once.
IsRunOnceOptimizer(const string & name)91 bool IsRunOnceOptimizer(const string& name) {
92 return name == "layout" || name == "memory_optimizer" ||
93 name == "loop_optimizer" || name == "auto_mixed_precision" ||
94 name == "auto_mixed_precision_mkl";
95 }
96
97 // Creates a function library stub from a real function library: copy only
98 // signatures and attributes of all the function defined in fdef_lib. This stub
99 // can be swapped with real function library in a graph, before passing it to
100 // optimizer, if optimizer doesn't instantiate functions.
GetFunctionDefLibraryStub(const FunctionDefLibrary & fdef_lib)101 FunctionDefLibrary GetFunctionDefLibraryStub(
102 const FunctionDefLibrary& fdef_lib) {
103 FunctionDefLibrary stub;
104 for (const FunctionDef& fn : fdef_lib.function()) {
105 FunctionDef* fn_stub = stub.mutable_function()->Add();
106 *(fn_stub->mutable_signature()) = fn.signature();
107 *(fn_stub->mutable_attr()) = fn.attr();
108 *(fn_stub->mutable_arg_attr()) = fn.arg_attr();
109 *(fn_stub->mutable_resource_arg_unique_id()) = fn.resource_arg_unique_id();
110 }
111 *stub.mutable_gradient() = fdef_lib.gradient();
112 return stub;
113 }
114
DeadlineMicroSeconds(const RewriterConfig & cfg)115 uint64 DeadlineMicroSeconds(const RewriterConfig& cfg) {
116 if (cfg.meta_optimizer_timeout_ms() <= 0) return 0; // no deadline
117 return Env::Default()->NowMicros() + cfg.meta_optimizer_timeout_ms() * 1000;
118 }
119
120 // A helper function to decide whether to enable the automatic mixed precision
121 // optimizer.
AutoMixedPrecisionEnabled(RewriterConfig::Toggle opt_level)122 bool AutoMixedPrecisionEnabled(RewriterConfig::Toggle opt_level) {
123 if (opt_level == RewriterConfig::ON ||
124 opt_level == RewriterConfig::AGGRESSIVE) {
125 return true;
126 }
127 return false;
128 }
129
IsXlaGlobalJitOn(const OptimizerOptions::GlobalJitLevel & jit_level_in_session_opts)130 bool IsXlaGlobalJitOn(
131 const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
132 xla_config_registry::XlaGlobalJitLevel xla_global_jit_level =
133 xla_config_registry::GetGlobalJitLevel(jit_level_in_session_opts);
134 // Return true only if XLA JIT is ON for both single-gpu and multi-gpu
135 // graphs. This is a conservative approach that turns off the memory optimizer
136 // when we are sure that all graphs will be processed by XLA JIT.
137 bool is_on = (xla_global_jit_level.single_gpu == OptimizerOptions::ON_1 ||
138 xla_global_jit_level.single_gpu == OptimizerOptions::ON_2) &&
139 (xla_global_jit_level.general == OptimizerOptions::ON_1 ||
140 xla_global_jit_level.general == OptimizerOptions::ON_2);
141 return is_on;
142 }
143
144 // A helper function to decide whether to enable the memory optimizer.
MemoryOptimizerEnabled(RewriterConfig::MemOptType mem_opt_type,bool xla_auto_clustering_on)145 bool MemoryOptimizerEnabled(RewriterConfig::MemOptType mem_opt_type,
146 bool xla_auto_clustering_on) {
147 // Disable the default memory optimizer when XLA JIT is ON as it hurts the
148 // XLA JIT performance. The (current) XLA clustering can result in loss of
149 // concurrency between kernel compute and memory copies. As such, it usually
150 // loses the concurrency needed to hide the latencies of the inserted swap-ins
151 // and swap-outs and incurs great performance overhead. Remove this check when
152 // the XLA JIT can better deal with the concurrency.
153 if (mem_opt_type == RewriterConfig::DEFAULT_MEM_OPT &&
154 xla_auto_clustering_on) {
155 return false;
156 }
157
158 return mem_opt_type != RewriterConfig::NO_MEM_OPT;
159 }
160
GetGraphDevice(const GraphDef & g_def,std::set<std::string> * devices)161 Status GetGraphDevice(const GraphDef& g_def, std::set<std::string>* devices) {
162 for (auto& node : g_def.node()) {
163 DeviceNameUtils::ParsedName parsed_name;
164 if (!DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
165 return errors::InvalidArgument("Unable to parse ", node.device(),
166 " as a device name");
167 }
168 devices->insert(parsed_name.type);
169 }
170 return Status::OK();
171 }
172
173 } // namespace
174
175 #define MK_OPT(NAME, CONFIG, VALUE) \
176 if (optimizer == NAME) { \
177 if (plugin_configs.toggle_config[CONFIG] != RewriterConfig::OFF) { \
178 return std::unique_ptr<GraphOptimizer>(VALUE); \
179 } \
180 }
181
LowerControlFlow() const182 bool MetaOptimizer::LowerControlFlow() const {
183 if (config_proto_.experimental().executor_type() ==
184 "SINGLE_THREADED_EXECUTOR")
185 return false;
186
187 if (config_proto_.experimental().use_tfrt()) return false;
188
189 return true;
190 }
191
// Creates a single built-in optimizer identified by its short name (e.g.
// "constfold", "layout", "memory"). Returns an empty unique_ptr when the
// name is unknown or when the plugin config turned the corresponding toggle
// OFF. Each MK_OPT expansion (macro defined above) returns early on a name
// match, so order here only affects lookup, not semantics.
std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
    const string& optimizer, const std::set<string>& device_types) const {
  ConfigList plugin_configs = PluginGraphOptimizerRegistry::GetPluginConfigs(
      cfg_.use_plugin_optimizers() != RewriterConfig::OFF, device_types);
  // Model pruning is handled outside MK_OPT: it is gated by a dedicated
  // `disable_model_pruning` flag rather than a toggle_config entry.
  if (optimizer == "pruning" && !plugin_configs.disable_model_pruning)
    return std::unique_ptr<GraphOptimizer>(new ModelPruner());
  MK_OPT("function", "function_optimization",
         new FunctionOptimizer(cfg_.function_optimization(),
                               /*lower_control_flow=*/LowerControlFlow()));
  MK_OPT("constfold", "constant_folding",
         new ConstantFolding(
             cpu_device_,
             cfg_.experimental_disable_compressed_tensor_optimization(),
             !cfg_.experimental_disable_folding_quantization_emulation()));
  MK_OPT("shape", "shape_optimization", new ShapeOptimizer());
  MK_OPT("remap", "remapping",
         new Remapper(cfg_.remapping(), xla_auto_clustering_on_));
  MK_OPT("layout", "layout_optimizer",
         new GenericLayoutOptimizer(
             /*optimization level*/ cfg_.layout_optimizer(),
             /*CPU layout conversion*/ cfg_.cpu_layout_conversion()));
  MK_OPT("auto_mixed_precision", "auto_mixed_precision",
         new AutoMixedPrecision(AutoMixedPrecisionMode::CUDA));
#ifdef INTEL_MKL
  // The MKL variant only exists in MKL builds and when MKL is enabled at
  // runtime.
  if (IsMKLEnabled()) {
    MK_OPT("auto_mixed_precision_mkl", "auto_mixed_precision_mkl",
           new AutoMixedPrecision(AutoMixedPrecisionMode::MKL));
  }
#endif
  // Name-based construction always uses MANUAL memory optimization; the
  // configured MemOptType is only honored by InitializeOptimizers().
  MK_OPT("memory", "memory_optimization",
         new MemoryOptimizer(RewriterConfig::MANUAL));
  MK_OPT("common_subgraph_elimination", "common_subgraph_elimination",
         new CommonSubgraphElimination(cfg_.common_subgraph_elimination()));
  MK_OPT("arithmetic", "arithmetic_optimization",
         new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
  MK_OPT("autoparallel", "auto_parallel",
         new AutoParallel(cfg_.auto_parallel().num_replicas()));
  MK_OPT("loop", "loop_optimization",
         new LoopOptimizer(cfg_.loop_optimization(), cpu_device_));
  MK_OPT("dependency", "dependency_optimization",
         new DependencyOptimizer(cfg_.dependency_optimization()));
  MK_OPT("debug_stripper", "debug_stripper", new DebugStripper());
  MK_OPT("scoped_allocator", "scoped_allocator_optimization",
         new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
                                      cfg_.scoped_allocator_opts()));
  MK_OPT("pin_to_host", "pin_to_host_optimization",
         new PinToHostOptimizer(cfg_.pin_to_host_optimization()));

  // Unknown name: empty pointer lets callers fall back to the custom
  // optimizer registry.
  return std::unique_ptr<GraphOptimizer>();
}
242
243 #undef MK_OPT
244
// Constructs a MetaOptimizer. `cpu_device` may be null; when non-null it must
// be a CPU device (it is handed to optimizers such as constant folding that
// evaluate nodes on the host). Note that cfg_ is a reference into the copied
// config_proto_ member, not into the caller's `cfg`.
MetaOptimizer::MetaOptimizer(DeviceBase* cpu_device, const ConfigProto& cfg)
    : cpu_device_(cpu_device),
      config_proto_(cfg),
      cfg_(*config_proto_.mutable_graph_options()->mutable_rewrite_options()) {
  DCHECK(cpu_device_ == nullptr ||
         cpu_device_->attributes().device_type() == "CPU");
  // Cache whether XLA auto-clustering is globally on; several decisions
  // (memory optimizer, remapper) depend on it.
  auto global_jit_level =
      cfg.graph_options().optimizer_options().global_jit_level();
  xla_auto_clustering_on_ = IsXlaGlobalJitOn(global_jit_level);
}
255
// Builds the default optimization pipeline in a fixed order. An optimizer is
// added only if both the user config (cfg_) and the plugin config allow it;
// the BOTH_* macros below encode that conjunction. Optimizers that are OFF
// by default require an explicit ON from both sides (BOTH_ARE_ON); the rest
// only need to not be OFF (BOTH_NOT_OFF). Finishes by appending custom and
// plugin optimizers.
Status MetaOptimizer::InitializeOptimizers(
    const std::set<string>& device_types,
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  if (cfg_.disable_meta_optimizer()) {
    return Status::OK();
  }

  ConfigList plugin_configs = PluginGraphOptimizerRegistry::GetPluginConfigs(
      cfg_.use_plugin_optimizers() != RewriterConfig::OFF, device_types);
  if (!cfg_.disable_model_pruning() && !plugin_configs.disable_model_pruning) {
    optimizers->push_back(MakeUnique<ModelPruner>());
  }

// Scoped helper macros; undefined at the end of this function.
#define USER_IS_ON(CFG) cfg_.CFG() == RewriterConfig::ON
#define USER_NOT_OFF(CFG) cfg_.CFG() != RewriterConfig::OFF
#define PLUGIN_IS_ON(CFG) \
  plugin_configs.toggle_config[#CFG] == RewriterConfig::ON
#define PLUGIN_NOT_OFF(CFG) \
  plugin_configs.toggle_config[#CFG] != RewriterConfig::OFF
#define BOTH_ARE_ON(CFG) USER_IS_ON(CFG) && PLUGIN_IS_ON(CFG)
#define BOTH_NOT_OFF(CFG) USER_NOT_OFF(CFG) && PLUGIN_NOT_OFF(CFG)
  if (BOTH_NOT_OFF(implementation_selector)) {
    optimizers->push_back(MakeUnique<ImplementationSelector>());
  }
  if (BOTH_NOT_OFF(function_optimization)) {
    optimizers->push_back(MakeUnique<FunctionOptimizer>(
        cfg_.function_optimization(),
        /*lower_control_flow=*/LowerControlFlow()));
  }
  // CSE only runs when arithmetic optimization is also enabled.
  if (BOTH_NOT_OFF(common_subgraph_elimination) &&
      BOTH_NOT_OFF(arithmetic_optimization)) {
    optimizers->push_back(MakeUnique<CommonSubgraphElimination>(
        cfg_.common_subgraph_elimination()));
  }
  if (BOTH_ARE_ON(debug_stripper)) {
    optimizers->push_back(MakeUnique<DebugStripper>());
  }
  if (BOTH_NOT_OFF(constant_folding)) {
    optimizers->push_back(MakeUnique<ConstantFolding>(
        cfg_.constant_folding(), cpu_device_,
        cfg_.experimental_disable_compressed_tensor_optimization(),
        !cfg_.experimental_disable_folding_quantization_emulation()));
  }
  if (BOTH_NOT_OFF(shape_optimization)) {
    optimizers->push_back(MakeUnique<ShapeOptimizer>());
  }
  if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision()) &&
      AutoMixedPrecisionEnabled(
          plugin_configs.toggle_config["auto_mixed_precision"])) {
    optimizers->push_back(
        MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::CUDA));
  }
#ifdef INTEL_MKL
  if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_mkl()) &&
      AutoMixedPrecisionEnabled(
          plugin_configs.toggle_config["auto_mixed_precision_mkl"]) &&
      IsMKLEnabled()) {
    optimizers->push_back(
        MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::MKL));
  }
#endif
  if (BOTH_ARE_ON(pin_to_host_optimization)) {
    optimizers->push_back(MakeUnique<PinToHostOptimizer>());
  }
  if (BOTH_NOT_OFF(arithmetic_optimization)) {
    optimizers->push_back(
        MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
  }
  if (BOTH_NOT_OFF(layout_optimizer)) {
    optimizers->push_back(MakeUnique<GenericLayoutOptimizer>(
        /*optimization level*/ cfg_.layout_optimizer(),
        /*CPU layout conversion*/ cfg_.cpu_layout_conversion()));
  }
  if (BOTH_NOT_OFF(remapping)) {
    optimizers->push_back(
        MakeUnique<Remapper>(cfg_.remapping(), xla_auto_clustering_on_));
  }
  if (BOTH_NOT_OFF(loop_optimization)) {
    optimizers->push_back(
        MakeUnique<LoopOptimizer>(cfg_.loop_optimization(), cpu_device_));
  }
  if (BOTH_NOT_OFF(dependency_optimization)) {
    optimizers->push_back(
        MakeUnique<DependencyOptimizer>(cfg_.dependency_optimization()));
  }
  // The memory optimizer has its own enable logic (it is also implicitly
  // disabled under XLA auto-clustering; see MemoryOptimizerEnabled).
  if (MemoryOptimizerEnabled(cfg_.memory_optimization(),
                             xla_auto_clustering_on_) &&
      PLUGIN_NOT_OFF(memory_optimization)) {
    if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
      optimizers->push_back(
          // Use the default target node name prefix "gradients/"
          MakeUnique<MemoryOptimizer>(cfg_.memory_optimization()));
    } else {
      optimizers->push_back(MakeUnique<MemoryOptimizer>(
          cfg_.memory_optimization(),
          cfg_.memory_optimizer_target_node_name_scope()));
    }
  }
  if (cfg_.auto_parallel().enable() && PLUGIN_IS_ON(auto_parallel)) {
    optimizers->push_back(
        MakeUnique<AutoParallel>(cfg_.auto_parallel().num_replicas()));
  }

// Scoped allocator optimization is excluded from MKL builds.
#ifndef ENABLE_MKL
  if (BOTH_ARE_ON(scoped_allocator_optimization)) {
    optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
        cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
  }
#endif

#undef USER_IS_ON
#undef USER_NOT_OFF
#undef PLUGIN_IS_ON
#undef PLUGIN_NOT_OFF
#undef BOTH_ARE_ON
#undef BOTH_NOT_OFF
  return InitializeCustomGraphOptimizers(device_types, std::set<string>(),
                                         optimizers);
}
375
InitializeOptimizersByName(const std::set<string> & device_types,std::vector<std::unique_ptr<GraphOptimizer>> * optimizers) const376 Status MetaOptimizer::InitializeOptimizersByName(
377 const std::set<string>& device_types,
378 std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
379 std::set<string> initialized_custom_optimizers;
380 for (const string& optimizer_name : cfg_.optimizers()) {
381 auto optimizer = MakeNewOptimizer(optimizer_name, device_types);
382 if (optimizer) {
383 VLOG(2) << "Registered default graph optimizer: " << optimizer_name;
384 optimizers->push_back(std::move(optimizer));
385 continue;
386 }
387
388 auto custom_optimizer =
389 CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);
390
391 if (custom_optimizer) {
392 VLOG(2) << "Registered custom graph optimizer: " << optimizer_name;
393 TF_RETURN_IF_ERROR(custom_optimizer->InitWithConfig(
394 config_proto_, GetCustomGraphOptimizerConfig(optimizer_name)));
395 optimizers->push_back(std::move(custom_optimizer));
396 initialized_custom_optimizers.insert(optimizer_name);
397 } else {
398 VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
399 }
400 }
401 return InitializeCustomGraphOptimizers(
402 device_types, initialized_custom_optimizers, optimizers);
403 }
404
// Appends the configurable custom optimizers from cfg_.custom_optimizers(),
// skipping any whose name is in `pre_initialized_optimizers` (those were
// already created by InitializeOptimizersByName). A name with no registered
// custom optimizer falls back to the built-in optimizer of the same name, so
// custom and default optimizers can be interleaved in any order. Finishes by
// appending plugin optimizers.
Status MetaOptimizer::InitializeCustomGraphOptimizers(
    const std::set<string>& device_types,
    const std::set<string>& pre_initialized_optimizers,
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  for (const auto& optimizer_config : cfg_.custom_optimizers()) {
    // Already instantiated by name; don't add it twice.
    if (pre_initialized_optimizers.find(optimizer_config.name()) !=
        pre_initialized_optimizers.end()) {
      continue;
    }

    auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
        optimizer_config.name());

    if (custom_optimizer) {
      VLOG(2) << "Registered custom configurable graph optimizer: "
              << optimizer_config.name();
      TF_RETURN_IF_ERROR(
          custom_optimizer->InitWithConfig(config_proto_, &optimizer_config));
      optimizers->push_back(std::move(custom_optimizer));
    } else {
      // If there are no custom optimizers with given name, try to initialize a
      // default optimizer. This way, custom configurable optimizers can be
      // mixed with default optimizers in any order.
      auto optimizer = MakeNewOptimizer(optimizer_config.name(), device_types);
      if (optimizer) {
        VLOG(2) << "Registered default graph optimizer: "
                << optimizer_config.name();
        optimizers->push_back(std::move(optimizer));
        continue;
      }
      VLOG(2) << "Can't register an optimizer by name: "
              << optimizer_config.name();
    }
  }
  return InitializePluginGraphOptimizers(device_types, optimizers);
}
441
InitializePluginGraphOptimizers(const std::set<string> & device_types,std::vector<std::unique_ptr<GraphOptimizer>> * optimizers) const442 Status MetaOptimizer::InitializePluginGraphOptimizers(
443 const std::set<string>& device_types,
444 std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
445 if (cfg_.use_plugin_optimizers() == RewriterConfig::OFF) return Status::OK();
446 auto plugin_optimizers =
447 PluginGraphOptimizerRegistry::CreateOptimizers(device_types);
448 for (auto& plugin_optimizer : plugin_optimizers) {
449 optimizers->push_back(std::move(plugin_optimizer));
450 }
451 return Status::OK();
452 }
453
454 const RewriterConfig::CustomGraphOptimizer*
GetCustomGraphOptimizerConfig(const string & name) const455 MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const {
456 for (const auto& config : cfg_.custom_optimizers()) {
457 if (config.name() == name) {
458 return &config;
459 }
460 }
461 return nullptr;
462 }
463
InitializeVerifiers(std::vector<std::unique_ptr<GraphVerifier>> * inter_optimizer_verifiers,std::vector<std::unique_ptr<GraphVerifier>> * post_optimization_verifiers) const464 void MetaOptimizer::InitializeVerifiers(
465 std::vector<std::unique_ptr<GraphVerifier>>* inter_optimizer_verifiers,
466 std::vector<std::unique_ptr<GraphVerifier>>* post_optimization_verifiers)
467 const {
468 if (cfg_.inter_optimizer_verifier_config().structure_verifier() ==
469 VerifierConfig::ON) {
470 inter_optimizer_verifiers->push_back(MakeUnique<StructureVerifier>());
471 }
472 if (cfg_.post_optimization_verifier_config().structure_verifier() ==
473 VerifierConfig::ON) {
474 post_optimization_verifiers->push_back(MakeUnique<StructureVerifier>());
475 }
476 }
477
PrintUserAndPluginConfigs(const std::set<string> & device_types) const478 void MetaOptimizer::PrintUserAndPluginConfigs(
479 const std::set<string>& device_types) const {
480 if (cfg_.use_plugin_optimizers() == RewriterConfig::OFF) return;
481 ConfigList plugin_cfg = PluginGraphOptimizerRegistry::GetPluginConfigs(
482 cfg_.use_plugin_optimizers() != RewriterConfig::OFF, device_types);
483 PluginGraphOptimizerRegistry::PrintPluginConfigsIfConflict(device_types);
484
485 ConfigList user_cfg;
486 // Print user's and plugin's configs.
487 if (cfg_.optimizers().empty()) {
488 if (cfg_.disable_meta_optimizer()) {
489 return;
490 }
491 user_cfg.disable_model_pruning = cfg_.disable_model_pruning();
492 #define PRINT_CFG(CFG) user_cfg.toggle_config[#CFG] = cfg_.CFG();
493 PRINT_CFG(implementation_selector)
494 PRINT_CFG(function_optimization)
495 PRINT_CFG(common_subgraph_elimination)
496 PRINT_CFG(arithmetic_optimization)
497 PRINT_CFG(debug_stripper)
498 PRINT_CFG(constant_folding)
499 PRINT_CFG(shape_optimization)
500 PRINT_CFG(pin_to_host_optimization)
501 PRINT_CFG(layout_optimizer)
502 PRINT_CFG(remapping)
503 PRINT_CFG(loop_optimization)
504 PRINT_CFG(dependency_optimization)
505 PRINT_CFG(scoped_allocator_optimization)
506 #undef PRINT_CFG
507 user_cfg.toggle_config["auto_mixed_precision"] =
508 AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision())
509 ? RewriterConfig::ON
510 : RewriterConfig::OFF;
511 user_cfg.toggle_config["auto_mixed_precision_mkl"] =
512 AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_mkl())
513 ? RewriterConfig::ON
514 : RewriterConfig::OFF;
515 user_cfg.toggle_config["memory_optimization"] =
516 MemoryOptimizerEnabled(cfg_.memory_optimization(),
517 config_proto_.graph_options()
518 .optimizer_options()
519 .global_jit_level())
520 ? RewriterConfig::ON
521 : RewriterConfig::OFF;
522 user_cfg.toggle_config["auto_parallel"] = cfg_.auto_parallel().enable()
523 ? RewriterConfig::ON
524 : RewriterConfig::OFF;
525 } else {
526 for (const string& optimizer_name : cfg_.optimizers()) {
527 if (optimizer_name == "pruning") user_cfg.disable_model_pruning = true;
528
529 #define PRINT_CFG(NAME, CONFIG) \
530 if (optimizer_name == NAME) \
531 user_cfg.toggle_config[CONFIG] = RewriterConfig::ON;
532
533 PRINT_CFG("implementation_selector", "implementation_selector")
534 PRINT_CFG("function", "function_optimization")
535 PRINT_CFG("common_subgraph_elimination", "common_subgraph_elimination")
536 PRINT_CFG("arithmetic", "arithmetic_optimization")
537 PRINT_CFG("debug_stripper", "debug_stripper")
538 PRINT_CFG("constfold", "constant_folding")
539 PRINT_CFG("shape", "shape_optimization")
540 PRINT_CFG("auto_mixed_precision", "auto_mixed_precision")
541 PRINT_CFG("auto_mixed_precision_mkl", "auto_mixed_precision_mkl")
542 PRINT_CFG("pin_to_host", "pin_to_host_optimization")
543 PRINT_CFG("layout", "layout_optimizer")
544 PRINT_CFG("remap", "remapping")
545 PRINT_CFG("loop", "loop_optimization")
546 PRINT_CFG("dependency", "dependency_optimization")
547 PRINT_CFG("memory", "memory_optimization")
548 PRINT_CFG("autoparallel", "auto_parallel")
549 PRINT_CFG("scoped_allocator", "scoped_allocator_optimization")
550 #undef PRINT_CFG
551 }
552 }
553
554 // Print logs only when plugin config has conflict with user config.
555 if (!PluginGraphOptimizerRegistry::IsConfigsConflict(user_cfg, plugin_cfg))
556 return;
557
558 ConfigList final_cfg = user_cfg;
559 // If plugin turns on `disable_model_pruning`, then `disable_model_pruning`
560 // should be true;
561 if (plugin_cfg.disable_model_pruning == true)
562 final_cfg.disable_model_pruning = true;
563 // If plugin turns off a certain optimizer, then the optimizer should be
564 // turned off;
565 for (auto& pair : plugin_cfg.toggle_config) {
566 if (plugin_cfg.toggle_config[pair.first] == RewriterConfig::OFF)
567 final_cfg.toggle_config[pair.first] = RewriterConfig::OFF;
568 }
569
570 string logs =
571 "\nConfig of optimizers\t\tUser's config\tPlugin's config\tFinal "
572 "config(User & Plugin)\n";
573 strings::StrAppend(&logs, "disable_model_pruning\t\t",
574 user_cfg.disable_model_pruning, "\t\t",
575 plugin_cfg.disable_model_pruning, "\t\t",
576 final_cfg.disable_model_pruning, "\n");
577 for (auto& pair : user_cfg.toggle_config) {
578 if (pair.first == "debug_stripper" ||
579 pair.first == "auto_mixed_precision" ||
580 pair.first == "auto_mixed_precision_mkl" ||
581 pair.first == "pin_to_host_optimization" ||
582 pair.first == "scoped_allocator_optimization") {
583 // These optimizers are turned off by default.
584 strings::StrAppend(
585 &logs, pair.first, string(32 - pair.first.size(), ' '),
586 (pair.second == RewriterConfig::ON), "\t\t",
587 (plugin_cfg.toggle_config[pair.first] == RewriterConfig::ON), "\t\t",
588 (final_cfg.toggle_config[pair.first] == RewriterConfig::ON), "\n");
589 } else {
590 // These optimizers are turned on by default.
591 strings::StrAppend(
592 &logs, pair.first, string(32 - pair.first.size(), ' '),
593 (pair.second != RewriterConfig::OFF), "\t\t",
594 (plugin_cfg.toggle_config[pair.first] != RewriterConfig::OFF), "\t\t",
595 (final_cfg.toggle_config[pair.first] != RewriterConfig::OFF), "\n");
596 }
597 }
598 LOG(WARNING) << "User's config has been changed based on plugin's config.";
599 LOG(WARNING) << logs;
600 }
601
// Runs the full optimization pipeline over a single GrapplerItem's main
// graph. The item is consumed (taken by rvalue reference); the result is
// written to `optimized_graph`. Small graphs (< min_graph_nodes) and empty
// pipelines are returned unmodified. The pipeline runs for up to
// NumIterations(cfg_) rounds, honoring the global deadline, run-once
// optimizer constraints, and the rule that the scoped allocator optimizer
// (when present) runs exactly once, after all other passes.
Status MetaOptimizer::OptimizeGraph(Cluster* cluster, GrapplerItem&& item,
                                    GraphDef* optimized_graph) {
  int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
                                                    : cfg_.min_graph_nodes();
  if (item.graph.node_size() < min_graph_nodes) {
    VLOG(3) << "Skipping optimization, graph has less than " << min_graph_nodes
            << " nodes.";
    *optimized_graph = item.graph;
    return Status::OK();
  }

  const uint64 start_us = Env::Default()->NowMicros();

  // Build the optimizer pipeline: the default order, or the user's explicit
  // list when cfg_.optimizers() is non-empty.
  std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
  std::set<std::string> device_types;
  TF_RETURN_IF_ERROR(GetGraphDevice(item.graph, &device_types));
  if (cfg_.optimizers().empty()) {
    TF_RETURN_IF_ERROR(InitializeOptimizers(device_types, &optimizers));
  } else {
    TF_RETURN_IF_ERROR(InitializeOptimizersByName(device_types, &optimizers));
  }
  PrintUserAndPluginConfigs(device_types);

  // Initialize the configured verifiers.
  std::vector<std::unique_ptr<GraphVerifier>> inter_optimizer_verifiers;
  std::vector<std::unique_ptr<GraphVerifier>> post_optimization_verifiers;
  InitializeVerifiers(&inter_optimizer_verifiers, &post_optimization_verifiers);
  if (inter_optimizer_verifiers.empty()) {
    VLOG(2) << "No inter optimizer verifiers have been configured";
  } else {
    VLOG(2) << inter_optimizer_verifiers.size()
            << " inter optimizer verifiers have been configured";
  }
  if (post_optimization_verifiers.empty()) {
    VLOG(2) << "No post optimization verifiers have been configured";
  } else {
    VLOG(2) << post_optimization_verifiers.size()
            << " post optimization verifiers have been configured";
  }

  VLOG(2) << "Optimize GrapplerItem: item.id=" << item.id
          << " num_optimizers=" << optimizers.size()
          << ", num nodes = " << item.graph.node_size();

  if (optimizers.empty()) {
    VLOG(3) << "Skipping graph optimization, no optimizers registered";
    *optimized_graph = item.graph;
    return Status::OK();
  }

  // Invariant: optimized_graph contains the most recently optimized version of
  // the graph. The swap moves the item's graph in without a copy.
  auto original_producer = item.graph.versions().producer();
  optimized_graph->Swap(&item.graph);

  GraphOptimizationResult optimization_result(item.id);
  // Deferred scoped allocator optimizer, collected below so it can run last.
  GraphOptimizer* sa_optimizer = nullptr;

  // Constants in the graph are normally compressed after model_pruner.
  // Do it here if model pruner is disabled.
  if (cfg_.disable_model_pruning()) {
    CompressConstants(optimized_graph);
  }

  for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) {
    // Don't bother optimizing further if the graph is already tiny.
    if (optimized_graph->node_size() < min_graph_nodes) {
      VLOG(3) << "Stopping after iteration " << iteration
              << ", graph is tiny (#nodes = " << optimized_graph->node_size()
              << " < " << min_graph_nodes << ")";
      break;
    }

    VLOG(4) << "Starting optimization iteration " << iteration;
    if (VLOG_IS_ON(4)) {
      // The graph's address disambiguates dump files from concurrent runs.
      DumpGraphDefToFile(
          strings::StrCat("before_MetaOptimizer_iteration_", iteration, "_",
                          reinterpret_cast<uintptr_t>(optimized_graph)),
          *optimized_graph);
    }

    for (const auto& optimizer : optimizers) {
      GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
      // Some optimizers can run only once.
      if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
#ifndef ENABLE_MKL
      // Some must run only on the last iteration.
      if (optimizer->name() == "scoped_allocator_optimizer") {
        if (sa_optimizer == nullptr) sa_optimizer = optimizer.get();
        continue;
      }
#endif

      TF_RETURN_IF_ERROR(RunOptimizer(optimizer.get(), cluster, &item,
                                      optimized_graph, &optimization_result));

      // Compress large constants once, right after the first pruning pass.
      if (iteration == 0 && optimizer->name() == "model_pruner") {
        CompressConstants(optimized_graph);
      }

      if (VLOG_IS_ON(4)) {
        DumpGraphDefToFile(
            strings::StrCat("after_MetaOptimizer_iteration_", iteration, "_",
                            optimizer->name(), "_",
                            reinterpret_cast<uintptr_t>(optimized_graph)),
            *optimized_graph);
      }
      for (const auto& verifier : inter_optimizer_verifiers) {
        // TODO(ashwinm): Need to enforce verification_deadline.
        TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
      }
    }
    if (VLOG_IS_ON(4)) {
      DumpGraphDefToFile(
          strings::StrCat("after_MetaOptimizer_iteration_", iteration, "_",
                          reinterpret_cast<uintptr_t>(optimized_graph)),
          *optimized_graph);
    }
    // TODO(ashwinm): Need to enforce verification_deadline.
    for (const auto& verifier : post_optimization_verifiers) {
      TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
    }
  }
#ifndef ENABLE_MKL
  // ScopedAllocatorOptimizer must run last.
  if (sa_optimizer != nullptr) {
    TF_RETURN_IF_ERROR(RunOptimizer(sa_optimizer, cluster, &item,
                                    optimized_graph, &optimization_result));
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
  }
#endif

  // The graph counts as optimized if at least one optimizer pass succeeded.
  bool is_optimized = std::find_if(optimization_result.results.begin(),
                                   optimization_result.results.end(),
                                   [](const OptimizerResult& result) {
                                     return result.status.ok();
                                   }) != optimization_result.results.end();

  // Record graph optimization result.
  optimization_results_.push_back(optimization_result);

  if (is_optimized) {
    TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
    ReassignColocation(optimized_graph);
    // Make sure that the optimizers preserved the graph version.
    DCHECK_EQ(optimized_graph->versions().producer(), original_producer);
  }

  const uint64 end_us = Env::Default()->NowMicros();
  metrics::UpdateGrapplerPassTime("OptimizeMainGraph", end_us - start_us);

  return Status::OK();
}
755
// Runs a single `optimizer` pass over the graph held in `optimized_item`,
// leaving the result in `optimized_graph` and appending a per-optimizer
// record (name, human-readable message, status) to `optimization_result`.
// Non-fatal optimizer failures are recorded but swallowed unless the config
// sets fail_on_optimizer_errors.
Status MetaOptimizer::RunOptimizer(
    GraphOptimizer* optimizer, Cluster* cluster, GrapplerItem* optimized_item,
    GraphDef* optimized_graph, GraphOptimizationResult* optimization_result) {
  const uint64 start_us = Env::Default()->NowMicros();

  // If optimizer doesn't need a function library, we will replace it with a
  // stub before running optimization, and will put it back at the end.
  FunctionDefLibrary optimized_graph_function_library;
  const bool is_function_library_aware = optimizer->UsesFunctionLibrary();

  // Replace function library in optimized graph with a stub.
  if (!is_function_library_aware) {
    VLOG(3) << "Replace function library with a stub for " << optimizer->name();
    optimized_graph_function_library.Swap(optimized_graph->mutable_library());
    *optimized_graph->mutable_library() =
        GetFunctionDefLibraryStub(optimized_graph_function_library);
  }

  // This swaps the current optimized_graph into optimized item and
  // resets optimized_graph to an empty graph.
  optimized_graph->Swap(&optimized_item->graph);
  *optimized_graph = GraphDef();
  optimizer->set_deadline_usec(this->deadline_usec());
  Status status =
      optimizer->Optimize(cluster, *optimized_item, optimized_graph);
  const uint64 end_us = Env::Default()->NowMicros();
  const float duration_ms = (end_us - start_us) / 1000.0f;
  metrics::UpdateGrapplerPassTime(optimizer->name(), end_us - start_us);

  string message;
  if (!status.ok()) {
    // On failure, swap the (untouched) input graph back so callers still see
    // a valid graph in `optimized_graph`.
    optimized_graph->Swap(&optimized_item->graph);
    if (errors::IsAborted(status)) {
      // By convention we (ab-)use the Aborted error code to signal that the
      // optimizer returned without performing any changes to the graph.
      message = strings::StrCat(optimizer->name(),
                                " did nothing. time = ", duration_ms, "ms.");
      // Swallow the non-critical error.
      status = Status::OK();
    } else if (errors::IsDeadlineExceeded(status)) {
      message =
          strings::StrCat(status.ToString(), ", time = ", duration_ms, "ms.");
      LOG(WARNING) << optimizer->name() << " failed: " << message;
    } else {
      message = status.ToString();
      LOG(ERROR) << optimizer->name() << " failed: " << message;
    }
  } else {
    message = strings::StrCat(
        PrintSizesBeforeAfter(optimized_item->graph, *optimized_graph),
        ", time = ", duration_ms, "ms.");
    VLOG(1) << optimizer->name() << ": " << message;
  }

  // Swap function library back into the main graph.
  if (!is_function_library_aware) {
    optimized_graph->mutable_library()->Swap(&optimized_graph_function_library);
  }

  OptimizerResult optimizer_result{optimizer->name(), message, status};
  optimization_result->results.push_back(optimizer_result);

  // Optimizer errors are fatal only if the config explicitly requests it;
  // otherwise the error was recorded above and optimization continues.
  if (!status.ok() && cfg_.fail_on_optimizer_errors()) return status;

  return Status::OK();
}
822
823 // Propagates `_tf_data_function` attributes from functions to their callees.
PropagateTFDataAttrs(const FunctionLibraryDefinition & flib,FunctionDefLibrary & fdef_lib)824 void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib,
825 FunctionDefLibrary& fdef_lib) {
826 // Collect functions that need the attribute in this set.
827 absl::flat_hash_set<std::string> tf_data_functions;
828 std::function<void(const std::string&)> collect_tf_data_functions_dfs =
829 [&](const std::string& func_name) -> void {
830 const FunctionDef* func_def = flib.Find(func_name);
831 // Skip functions that are not reachable from the optimized graph.
832 if (func_def == nullptr) return;
833
834 // Return if we already found and added this function.
835 if (tf_data_functions.contains(func_name)) return;
836
837 // We only get here if the function is (directly or indirectly) called from
838 // a tf.data function, so add it to the set.
839 tf_data_functions.insert(func_name);
840
841 // Proceed with DFS for functions called from current function.
842 for (const NodeDef& node : func_def->node_def()) {
843 if (flib.Contains(node.op())) {
844 // This is a function call node.
845 collect_tf_data_functions_dfs(node.op());
846 }
847 // Check if there are functions in attributes.
848 for (const auto& attr : node.attr()) {
849 const AttrValue& attr_value = attr.second;
850 if (attr_value.has_func()) {
851 collect_tf_data_functions_dfs(attr_value.func().name());
852 }
853 if (attr_value.has_list()) {
854 for (const auto& func : attr_value.list().func()) {
855 collect_tf_data_functions_dfs(func.name());
856 }
857 }
858 }
859 }
860 };
861 // Perform DFS for all tf.data functions in `fdef_lib`.
862 for (const auto& func_def : fdef_lib.function()) {
863 const std::string& func_name = func_def.signature().name();
864 if (data::IsTFDataFunction(func_def))
865 collect_tf_data_functions_dfs(func_name);
866 }
867 // Set attribute for tf.data functions. We cannot do this in the DFS directly
868 // because `FunctionLibraryDefinition` does not seem to provide mutable access
869 // to a `FunctionDef`.
870 for (FunctionDef& func_def : *fdef_lib.mutable_function()) {
871 const std::string& func_name = func_def.signature().name();
872 if (tf_data_functions.contains(func_name) &&
873 !data::IsTFDataFunction(func_def)) {
874 VLOG(2) << "Marking " << func_name << " as tf.data function";
875 (*func_def.mutable_attr())[data::kTFDataFunction].set_b(true);
876 }
877 }
878 }
879
OptimizeConsumeItem(Cluster * cluster,GrapplerItem && item,GraphDef * optimized_graph)880 Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
881 GraphDef* optimized_graph) {
882 const uint64 start_us = Env::Default()->NowMicros();
883
884 VLOG(1) << "Starting optimization for grappler item: " << item.id;
885 optimization_results_.clear();
886
887 // Constructs a FunctionLibraryDefinition with functions that are reachable
888 // from the nodes of the graph.
889 const auto minimized_flib =
890 [](const GraphDef& graph) -> FunctionLibraryDefinition {
891 return FunctionLibraryDefinition(OpRegistry::Global(), graph.library())
892 .ReachableDefinitions(graph);
893 };
894
895 // 0. Original graph might contain a huge function library, that is mostly
896 // unused. This library copied over by each individual Grappler optimizer,
897 // which adds a huge overhead. Before starting optimization passes we just
898 // remove all the unreachable functions.
899 // TODO(ezhulenev): Construct reachable function library definition directly
900 // from the proto without constructing temporary FunctionLibraryDefinition.
901 int old_library_size = item.graph.library().function_size();
902 *item.graph.mutable_library() = minimized_flib(item.graph).ToProto();
903 int new_library_size = item.graph.library().function_size();
904
905 VLOG(1) << absl::Substitute(
906 "Deleted $0 unreachable functions from the graph (library size = $1)",
907 old_library_size - new_library_size, new_library_size);
908
909 // Save a few small fields from item before we move it.
910 bool optimize_function_library =
911 item.optimization_options().optimize_function_library;
912 const auto producer = item.graph.versions().producer();
913
914 // 1. Optimize main graph
915 TF_RETURN_IF_ERROR(OptimizeGraph(cluster, std::move(item), optimized_graph));
916 VLOG(1) << "Optimized main graph.";
917 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
918
919 // 2. Optimize functions reachable from the optimized graph.
920 FunctionLibraryDefinition flib = minimized_flib(*optimized_graph);
921 using NodeDefs = protobuf::RepeatedPtrField<NodeDef>;
922
923 // Find functions for which we might need to compute a gradient at runtime.
924 absl::flat_hash_set<string> differentiable_functions;
925
926 const auto find_differentiable_functions =
927 [&](const NodeDefs& nodes) -> void {
928 for (const NodeDef& node : nodes) {
929 if (IsSymbolicGradient(node)) {
930 const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
931 if (f_attr) differentiable_functions.insert(f_attr->func().name());
932 }
933 }
934 };
935
936 // SymbolicGradient nodes inside the main graph.
937 find_differentiable_functions(optimized_graph->node());
938 // SymbolicGradient nodes inside the function library.
939 for (const FunctionDef& function : optimized_graph->library().function()) {
940 find_differentiable_functions(function.node_def());
941 }
942
943 // Find functions that will be compiled by XLA later
944 // We do it by looking for XlaLaunch ops that call functions,
945 // then depth first search down those functions to find transitive functions.
946 // Grappler rewrites can potentially add nodes that are
947 // not supported by XLA, so we choose to skip such functions when we optimize
948 // the function library.
949 absl::flat_hash_set<string> xla_compiled_functions;
950 std::function<void(const string&)> find_all_functions;
951 find_all_functions = [&](const string& func) -> void {
952 // Ignore call cycles in the graph
953 if (xla_compiled_functions.contains(func)) return;
954 // Find func in the flib
955 const FunctionDef* func_def = flib.Find(func);
956 CHECK(func_def) << "not found: " << func;
957 // Mark function to be ignored by grappler
958 xla_compiled_functions.insert(func);
959 // Depth first search through the func for transitively called funcs
960 for (const NodeDef& node : func_def->node_def()) {
961 for (const auto attr : node.attr()) {
962 const AttrValue& attr_value = attr.second;
963 if (attr_value.has_func()) {
964 find_all_functions(attr_value.func().name());
965 }
966 }
967 }
968 };
969
970 auto find_xla_compiled_functions = [&](const NodeDefs& nodes) -> void {
971 NameAttrList function;
972 for (const NodeDef& node : nodes) {
973 // Look only for XlaLaunch nodes that call a function
974 if (!IsXlaLaunch(node)) continue;
975 if (!GetNodeAttr(node, "function", &function).ok()) continue;
976 // Find all transitively called functions
977 find_all_functions(function.name());
978 }
979 };
980
981 // XlaLaunch ops inside the main graph ...
982 find_xla_compiled_functions(optimized_graph->node());
983 // ... and inside the function library.
984 for (const FunctionDef& function : optimized_graph->library().function()) {
985 find_xla_compiled_functions(function.node_def());
986 }
987 // Propagate `_tf_data_function` attributes from functions to their callees.
988 PropagateTFDataAttrs(flib, *optimized_graph->mutable_library());
989
990 // Optimize each function only once.
991 absl::flat_hash_set<string> optimized_funcs;
992 while (optimize_function_library) {
993 optimize_function_library = false;
994
995 int function_idx = 0;
996 for (const FunctionDef& func : optimized_graph->library().function()) {
997 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
998
999 const string& func_name = func.signature().name();
1000
1001 // Skip functions that are not reachable from the optimized graph.
1002 if (!flib.Contains(func_name)) continue;
1003 // Skip already optimized functions.
1004 if (optimized_funcs.contains(func_name)) continue;
1005 // Skip functions that will be compiled by XLA.
1006 if (xla_compiled_functions.contains(func_name)) continue;
1007
1008 // Skip parametrized functions (function type or body is defined only at
1009 // function call time by caller node attributes).
1010 // They should be specialized to their instantiation type parameters by
1011 // the function optimizer, before we can optimize function body.
1012 if (IsParametrized(func)) continue;
1013
1014 // Skip tf.data functions as they are optimized by tf.data meta optimizer
1015 // and in function instantiation.
1016 if (data::IsTFDataFunction(func)) continue;
1017
1018 VLOG(3) << "Optimize function: function=" << func_name << " ["
1019 << function_idx++ << " of "
1020 << optimized_graph->library().function_size() << "]";
1021
1022 // Function optimization might specialize nested function calls, so we
1023 // have to reset the flag and do at least one more pass over the library.
1024 optimize_function_library = true;
1025 optimized_funcs.insert(func_name);
1026
1027 // Make a GrapplerItem from a FunctionDef.
1028 GrapplerFunctionItem func_item;
1029 TF_RETURN_IF_ERROR(
1030 MakeGrapplerFunctionItem(func, flib, producer, &func_item));
1031
1032 // If we need to compute the gradient of optimized function at runtime, we
1033 // can't perform non-differentiable rewrites.
1034 func_item.optimization_options().allow_non_differentiable_rewrites =
1035 !differentiable_functions.contains(func_name);
1036
1037 // Device set available to the function is defined only by the runtime,
1038 // when we instantiate and execute the function. We can't use all devices
1039 // available to the main graph, because after partitioning the function
1040 // call node might execute on a remote worker.
1041 if (!func_item.devices().empty()) {
1042 return errors::Internal("GrapplerFunctionItem devices must be empty.");
1043 }
1044
1045 // We are not allowed to prune certain types of ops from the graph
1046 // instantiated by the function definition, because we must guarantee
1047 // function execution semantics wrt side effects (see
1048 // function_optimizer.cc).
1049 func_item.optimization_options().allow_pruning_stateful_and_dataset_ops =
1050 false;
1051
1052 // Optimize function body graph.
1053 GraphDef optimized_func_graph;
1054 if (IsTPUGraphDef(*optimized_graph)) {
1055 // Skip optimizing functions if this is a TPU graph. Currently, Grappler
1056 // passes do not handle TPU functions correctly in a variety of ways
1057 // (Note that due to the pre-placement TPU graph rewriting passes, the
1058 // TPU-related ops are encapsulated away into functions). For example,
1059 // TPU graphs contain TPUReplicateMetadata node that carries relevant
1060 // TPU metadata and Grappler passes could prune that away. Grappler
1061 // passes could also cause issues around shape inference. Since the
1062 // desired and existing behavior is to not optimize TPU functions with
1063 // Grappler, this check preserves that. The only exception is
1064 // implementation selector what is required to swap in some TPU specific
1065 // lowering code and is verified the work correctly on TPUs.
1066 ImplementationSelector implementation_selector;
1067
1068 // Implementation selector needs to have access to valid function
1069 // signature and attributes, and it doesn't need actual function body.
1070 FunctionDefLibrary func_item_function_library;
1071 func_item_function_library.Swap(func_item.graph.mutable_library());
1072 *func_item.graph.mutable_library() =
1073 GetFunctionDefLibraryStub(func_item_function_library);
1074
1075 TF_RETURN_IF_ERROR(implementation_selector.Optimize(
1076 cluster, func_item, &optimized_func_graph));
1077 } else {
1078 GrapplerFunctionItem func_item_copy = func_item;
1079 TF_RETURN_IF_ERROR(OptimizeGraph(cluster, std::move(func_item_copy),
1080 &optimized_func_graph));
1081 }
1082
1083 // Function body optimization might have created new specialized
1084 // functions for each instantiation context. Add them to the library.
1085 for (const FunctionDef& func_def :
1086 optimized_func_graph.library().function()) {
1087 if (flib.Find(func_def.signature().name()) == nullptr) {
1088 TF_RETURN_IF_ERROR(flib.AddFunctionDef(func_def));
1089 }
1090 }
1091
1092 // Convert optimized graph back to FunctionDef.
1093 FunctionDef optimized_func;
1094 func_item.SwapFunctionBody(std::move(optimized_func_graph));
1095 TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func));
1096
1097 // Replace optimized function with a new FunctionDef.
1098 TF_RETURN_IF_ERROR(flib.ReplaceFunction(func_name, optimized_func));
1099 }
1100
1101 // If optimized at least one function, update the graph library.
1102 if (optimize_function_library) {
1103 *optimized_graph->mutable_library() = flib.ToProto();
1104 }
1105 }
1106
1107 VLOG(1) << "Optimized " << optimized_funcs.size()
1108 << " functions: " << absl::StrJoin(optimized_funcs, ", ");
1109 VLOG(3) << "Optimized graph =\n" << optimized_graph->DebugString();
1110 if (VLOG_IS_ON(1)) {
1111 DumpGraphDefToFile(
1112 strings::StrCat("after_MetaOptimizer_",
1113 reinterpret_cast<uintptr_t>(optimized_graph)),
1114 *optimized_graph);
1115 }
1116
1117 const uint64 end_us = Env::Default()->NowMicros();
1118 metrics::UpdateGrapplerPassTime("*", end_us - start_us);
1119
1120 return Status::OK();
1121 }
1122
GetResultString() const1123 string MetaOptimizer::GetResultString() const {
1124 std::string result_string;
1125 for (const GraphOptimizationResult& graph_result : optimization_results_) {
1126 absl::StrAppend(&result_string,
1127 "Optimization results for grappler item: ", graph_result.id,
1128 "\n");
1129 for (const OptimizerResult& result : graph_result.results) {
1130 absl::StrAppend(&result_string, " ", result.optimizer_name, ": ",
1131 result.message, "\n");
1132 }
1133 }
1134 return result_string;
1135 }
1136
PrintResult()1137 void MetaOptimizer::PrintResult() { LOG(INFO) << GetResultString(); }
1138
MetaOptimizerEnabled(const ConfigProto & cfg)1139 bool MetaOptimizerEnabled(const ConfigProto& cfg) {
1140 const auto& rewrite_cfg = cfg.graph_options().rewrite_options();
1141 if (rewrite_cfg.disable_meta_optimizer()) {
1142 return false;
1143 }
1144 return !rewrite_cfg.disable_model_pruning() ||
1145 rewrite_cfg.layout_optimizer() != RewriterConfig::OFF ||
1146 rewrite_cfg.function_optimization() != RewriterConfig::OFF ||
1147 rewrite_cfg.constant_folding() != RewriterConfig::OFF ||
1148 rewrite_cfg.shape_optimization() != RewriterConfig::OFF ||
1149 rewrite_cfg.remapping() != RewriterConfig::OFF ||
1150 rewrite_cfg.common_subgraph_elimination() != RewriterConfig::OFF ||
1151 rewrite_cfg.arithmetic_optimization() != RewriterConfig::OFF ||
1152 rewrite_cfg.loop_optimization() != RewriterConfig::OFF ||
1153 rewrite_cfg.dependency_optimization() != RewriterConfig::OFF ||
1154 rewrite_cfg.auto_parallel().enable() ||
1155 rewrite_cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
1156 rewrite_cfg.debug_stripper() == RewriterConfig::ON ||
1157 #ifndef ENABLE_MKL
1158 rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
1159 #endif
1160 rewrite_cfg.pin_to_host_optimization() == RewriterConfig::ON ||
1161 AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision()) ||
1162 AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision_mkl()) ||
1163 !rewrite_cfg.optimizers().empty() ||
1164 !rewrite_cfg.custom_optimizers().empty();
1165 }
1166
RunMetaOptimizer(GrapplerItem && item,const ConfigProto & cfg,DeviceBase * cpu_device,Cluster * cluster,GraphDef * optimized_graph)1167 Status RunMetaOptimizer(GrapplerItem&& item, const ConfigProto& cfg,
1168 DeviceBase* cpu_device, Cluster* cluster,
1169 GraphDef* optimized_graph) {
1170 MetaOptimizer optimizer(cpu_device, cfg);
1171 optimizer.set_deadline_usec(
1172 DeadlineMicroSeconds(cfg.graph_options().rewrite_options()));
1173 return optimizer.OptimizeConsumeItem(cluster, std::move(item),
1174 optimized_graph);
1175 }
1176
// Applies the Grappler meta optimizer to the graph in `*g`, replacing it with
// the optimized graph. `ret_node_names` become the fetch set (enabling
// pruning) and `keep_node_names` are protected from removal. When `flib` is
// non-null its functions are made visible to the optimizer and any optimized
// functions are copied back into it. Returns early (no-op) when the meta
// optimizer is disabled in `config_proto`.
Status OptimizeGraph(
    std::vector<string> ret_node_names, std::vector<string> keep_node_names,
    FunctionLibraryDefinition* flib, const DeviceSet& device_set,
    Device* cpu_device, const ConfigProto& config_proto,
    const string& grappler_item_id,
    const GrapplerItem::OptimizationOptions& optimization_options,
    std::unique_ptr<tensorflow::Graph>* g) {
  if (!tensorflow::grappler::MetaOptimizerEnabled(config_proto)) {
    return Status::OK();
  }

  tensorflow::grappler::GrapplerItem item;
  item.id = grappler_item_id;
  item.optimization_options() = optimization_options;

  // Add all available devices so that inlined function can be placed.
  for (const Device* d : device_set.devices()) {
    Status added_device = item.AddDevice(d->name());
    if (!added_device.ok()) VLOG(3) << added_device.error_message();
  }
  VLOG(3) << "Grappler available devices: "
          << absl::StrJoin(item.devices(), ", ");

  // Add fetches so that the graph can be pruned.
  item.fetch.swap(ret_node_names);

  // Add nodes that can't be removed from the graph.
  item.keep_ops = std::move(keep_node_names);

  (*g)->ToGraphDef(&item.graph);

  if (flib) {
    *item.graph.mutable_library() = flib->ToProto();
  }

  tensorflow::GraphDef out_graph;
  tensorflow::grappler::VirtualCluster cluster(&device_set);
  // TODO(nareshmodi): Consider adding and using the more generic GraphOptions
  // proto (which also contain the OptimizerOptions).
  TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
      std::move(item), config_proto, cpu_device, &cluster, &out_graph));

  std::unique_ptr<tensorflow::Graph> optimized_graph(
      new tensorflow::Graph(OpRegistry::Global()));

  // Copy optimized functions back to the overlay lib.
  if (flib) {
    for (const FunctionDef& fdef : out_graph.library().function()) {
      const string& func_name = fdef.signature().name();
      if (flib->Contains(func_name)) {
        // Preserve the stack traces of the pre-existing definition.
        StackTracesMap stack_traces = flib->GetStackTraces(func_name);
        TF_RETURN_IF_ERROR(
            flib->ReplaceFunction(func_name, fdef, stack_traces));
      } else {
        TF_RETURN_IF_ERROR(
            flib->AddFunctionDef(fdef, flib->GetStackTraces(func_name)));
      }
    }
  }

  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
      GraphConstructorOptions(), std::move(out_graph), optimized_graph.get()));

  // The graph conversion sets the requested device names but not the
  // assigned device names. However, since at this point the graph is
  // placed TF expects an assigned device name for every node. Therefore
  // we copy the requested device into the assigned device field.
  for (Node* node : optimized_graph->nodes()) {
    if (node->IsOp() && node->assigned_device_name().empty()) {
      if (node->requested_device().empty()) {
        return errors::Internal(
            "Either placer did not place the node or Grappler did not "
            "copy the assigned device. Contact Grappler team since latter "
            "is more likely. Node=",
            node->name(),
            " Graph: ", optimized_graph->ToGraphDefDebug().DebugString());
      }
      node->set_assigned_device_name(node->requested_device());
    }
  }

  *g = std::move(optimized_graph);
  return Status::OK();
}
1261
1262 } // namespace grappler
1263 } // namespace tensorflow
1264