1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
17
18 #include <unordered_set>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/grappler/costs/graph_properties.h"
25 #include "tensorflow/core/grappler/grappler_item.h"
26 #include "tensorflow/core/grappler/op_types.h"
27 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
28 #include "tensorflow/core/grappler/utils.h"
29 #include "tensorflow/core/grappler/utils/topological_sort.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/stringpiece.h"
32 #include "tensorflow/core/lib/gtl/inlined_vector.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/util/device_name_utils.h"
36
37 namespace tensorflow {
38 namespace grappler {
39
40 namespace {
41
RemoveControlInput(NodeDef * node,const string & control_input_to_remove,NodeMap * node_map)42 bool RemoveControlInput(NodeDef* node, const string& control_input_to_remove,
43 NodeMap* node_map) {
44 for (int pos = node->input_size() - 1; pos >= 0; --pos) {
45 const string& input = node->input(pos);
46 if (input[0] != '^') break;
47 if (input == control_input_to_remove) {
48 node->mutable_input()->SwapElements(pos, node->input_size() - 1);
49 node->mutable_input()->RemoveLast();
50 node_map->RemoveOutput(NodeName(input), node->name());
51 return true;
52 }
53 }
54 return false;
55 }
56
57 } // namespace
58
SafeToRemoveIdentity(const NodeDef & node) const59 bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
60 if (!IsIdentity(node) && !IsIdentityN(node)) {
61 return true;
62 }
63
64 if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
65 return false;
66 }
67 if (!fetch_nodes_known_) {
68 // The output values of this node may be needed.
69 return false;
70 }
71
72 if (node.input_size() < 1) {
73 // Node lacks input, is invalid
74 return false;
75 }
76
77 const NodeDef* input = node_map_->GetNode(NodeName(node.input(0)));
78 CHECK(input != nullptr) << "node = " << node.name()
79 << " input = " << node.input(0);
80 // Don't remove Identity nodes corresponding to Variable reads or following
81 // Recv.
82 if (IsVariable(*input) || IsRecv(*input)) {
83 return false;
84 }
85 for (const auto& consumer : node_map_->GetOutputs(node.name())) {
86 if (node.input_size() > 1 && (IsRetval(*consumer) || IsMerge(*consumer))) {
87 return false;
88 }
89 if (IsSwitch(*input)) {
90 for (const string& consumer_input : consumer->input()) {
91 if (consumer_input == AsControlDependency(node.name())) {
92 return false;
93 }
94 }
95 }
96 }
97 return true;
98 }
99
SafeToConvertToNoOp(const NodeDef & node) const100 bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
101 if (HasRegularOutputs(node, *node_map_)) {
102 // The output values of this node may be needed.
103 VLOG(3) << "Not safe to convert '" << node.name()
104 << " to NoOp. Node has outputs.";
105 return false;
106 }
107 if (!fetch_nodes_known_) {
108 VLOG(3) << "Not safe to convert '" << node.name()
109 << " to NoOp. Fetches unknown.";
110 return false;
111 }
112 if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
113 VLOG(3) << "Not safe to convert to NoOp: " << node.name()
114 << " is in preserve set.";
115 return false;
116 }
117 if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node)) {
118 VLOG(3) << "Not safe to convert '" << node.name()
119 << " to NoOp. Node modifies frame info.";
120 return false;
121 }
122 // Ops reading variables are marked as stateful, but are safe to remove if
123 // redundant.
124 static const absl::flat_hash_set<string>* gather_ops =
125 new absl::flat_hash_set<string>{"Gather", "GatherV2", "GatherNd",
126 "ResourceGather", "ResourceGatherNd"};
127 const bool is_variable_read =
128 IsReadVariableOp(node) || IsReadVariablesOp(node) ||
129 gather_ops->find(node.op()) != gather_ops->end();
130 if (!is_variable_read && !IsFreeOfSideEffect(node)) {
131 VLOG(3) << "Not safe to convert '" << node.name()
132 << " to NoOp. Node has side effect.";
133 return false;
134 }
135 if (node.op().rfind("Submodel", 0) == 0) {
136 return false;
137 }
138 const OpDef* op_def = nullptr;
139 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
140 if (!status.ok() || op_def->output_arg_size() == 0) {
141 return false;
142 }
143 const std::unordered_set<string> do_not_rewrite_ops{
144 "Assert", "CheckNumerics", "_Retval",
145 "_Arg", "_ParallelConcatUpdate", "TPUExecute",
146 "TPUCompile", "ControlTrigger"};
147 if (do_not_rewrite_ops.find(node.op()) != do_not_rewrite_ops.end()) {
148 return false;
149 }
150 if (!SafeToRemoveIdentity(node)) {
151 return false;
152 }
153 return true;
154 }
155
NumEdgesIfBypassed(const NodeDef & node,const std::vector<NodeDef * > & output_nodes) const156 int DependencyOptimizer::NumEdgesIfBypassed(
157 const NodeDef& node, const std::vector<NodeDef*>& output_nodes) const {
158 const bool is_multi_input_identity_n =
159 IsIdentityN(node) && !IsIdentityNSingleInput(node);
160 const int num_outputs = output_nodes.size();
161 const int num_inputs = node.input_size();
162
163 if (is_multi_input_identity_n) {
164 // multi-input identity_n with input/output control dependencies will likely
165 // increase number of edges after optimization.
166 int num_edges_if_bypassed(0);
167 for (const string& input_node_name : node.input()) {
168 if (IsControlInput(input_node_name)) {
169 num_edges_if_bypassed += num_outputs;
170 } else {
171 ++num_edges_if_bypassed;
172 }
173 }
174
175 for (auto consumer : output_nodes) {
176 for (int j = 0; j < consumer->input_size(); ++j) {
177 const TensorId consumer_input = ParseTensorName(consumer->input(j));
178 if (consumer_input.node() == node.name()) {
179 if (IsControlInput(consumer_input)) {
180 num_edges_if_bypassed += num_inputs;
181 } else {
182 ++num_edges_if_bypassed;
183 }
184 }
185 }
186 }
187 return num_edges_if_bypassed;
188 } else {
189 return num_inputs * num_outputs;
190 }
191 }
192
BypassingNodeIsBeneficial(const NodeDef & node,const std::vector<NodeDef * > & input_nodes,const std::vector<NodeDef * > & output_nodes) const193 bool DependencyOptimizer::BypassingNodeIsBeneficial(
194 const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
195 const std::vector<NodeDef*>& output_nodes) const {
196 const bool is_identity = IsIdentity(node) || IsIdentityNSingleInput(node);
197 const bool is_multi_input_identity_n =
198 IsIdentityN(node) && !IsIdentityNSingleInput(node);
199 const int num_outputs = output_nodes.size();
200 const int num_inputs = node.input_size();
201
202 if (NumEdgesIfBypassed(node, output_nodes) > num_inputs + num_outputs) {
203 return false;
204 }
205
206 // Make sure that we don't increase the number of edges that cross
207 // device boundaries.
208 if ((num_inputs == 1 && num_outputs > 1 &&
209 input_nodes[0]->device() != node.device()) ||
210 (num_inputs > 1 && num_outputs == 1 &&
211 output_nodes[0]->device() != node.device())) {
212 return false;
213 }
214
215 // TODO(rmlarsen): Not all device crossings are equally expensive.
216 // Assign a cost to each based on device affinity and compute a
217 // cost before and after.
218 const string& node_dev = node.device();
219 int num_cross_in = 0;
220 for (NodeDef* input_node : input_nodes) {
221 num_cross_in += static_cast<int>(input_node->device() != node_dev);
222 }
223 int num_cross_out = 0;
224 for (NodeDef* output_node : output_nodes) {
225 num_cross_out += static_cast<int>(output_node->device() != node_dev);
226 }
227
228 // Make sure we do not increase the number of device crossings.
229 const int num_cross_before = num_cross_in + num_cross_out;
230 int num_cross_after = 0;
231 for (NodeDef* input_node : input_nodes) {
232 for (NodeDef* output_node : output_nodes) {
233 num_cross_after +=
234 static_cast<int>(input_node->device() != output_node->device());
235 }
236 }
237 if (num_cross_after > num_cross_before) {
238 return false;
239 }
240
241 if ((is_identity || is_multi_input_identity_n) && num_cross_in > 0 &&
242 num_cross_out > 0 && num_cross_after > 0) {
243 // This identity node follows a device crossing, so it might be
244 // following a _Recv node after partitioning. Do not remove such nodes,
245 // unless they only have consumers on the same device as themselves.
246 return false;
247 }
248
249 return true;
250 }
251
// Applies one round of local simplifications to the node at `node_idx`.
// At most one of the following paths is taken, checked in this order:
//   1. A Constant node with no inputs has every control edge to its fanouts
//      removed; it is marked for deletion if it ends up with no outputs,
//      fetch nodes are known and it is not in the preserve set.
//   2. A non-NoOp node for which SafeToConvertToNoOp() holds has its regular
//      inputs turned into control inputs and its op set to "NoOp".
//   3. A NoOp, or an Identity/IdentityN for which SafeToRemoveIdentity()
//      holds, is bypassed by wiring its inputs to its consumers, provided
//      BypassingNodeIsBeneficial() approves.
// Indices of nodes touched along the way are pushed onto `nodes_to_simplify`;
// indices of nodes that can be erased are inserted into `nodes_to_delete`.
OptimizeNode(int node_idx,SetVector<int> * nodes_to_simplify,std::set<int> * nodes_to_delete)252 void DependencyOptimizer::OptimizeNode(int node_idx,
253 SetVector<int>* nodes_to_simplify,
254 std::set<int>* nodes_to_delete) {
255 NodeDef* node = optimized_graph_->mutable_node(node_idx);
256 const bool is_noop = IsNoOp(*node);
257 const bool is_identity = IsIdentity(*node) || IsIdentityNSingleInput(*node);
258 const bool is_multi_input_identity =
259 IsIdentityN(*node) && !IsIdentityN​SingleInput(*node);
260 const string node_name = node->name();
261 // Constant nodes with no input control dependency are always executed early,
262 // so we can prune all their output control dependencies.
263 if (IsConstant(*node) && node->input_size() == 0) {
// `output_nodes` is declared with plain `auto`, i.e. it is a copy of the
// fanout set, so the RemoveOutput() call further down does not disturb the
// iteration.
264 const auto output_nodes = node_map_->GetOutputs(node_name);
265 for (NodeDef* fanout : output_nodes) {
266 bool optimize_fanout = false;
267 bool data_connection = false;
// Walk backwards so that the swap-with-last removal only moves entries that
// have already been visited.
268 for (int i = fanout->input_size() - 1; i >= 0; --i) {
269 const TensorId input_tensor = ParseTensorName(fanout->input(i));
270 if (input_tensor.node() == node_name) {
271 if (input_tensor.index() < 0) {
272 fanout->mutable_input()->SwapElements(i, fanout->input_size() - 1);
273 fanout->mutable_input()->RemoveLast();
274 optimize_fanout = true;
275 } else {
276 data_connection = true;
277 }
278 }
279 }
280 if (optimize_fanout) {
281 nodes_to_simplify->PushBack(node_to_idx_[fanout]);
// Keep the fanout entry in the node map while a data edge still exists.
282 if (!data_connection) {
283 node_map_->RemoveOutput(node_name, fanout->name());
284 }
285 }
286 }
287 if (node_map_->GetOutputs(node_name).empty() && fetch_nodes_known_ &&
288 nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) {
289 // Mark the node for deletion.
290 nodes_to_delete->insert(node_to_idx_[node]);
291 }
292 return;
293 }
294
295 // Change ops that only have control dependencies as outputs to NoOps.
296 if (!is_noop && SafeToConvertToNoOp(*node)) {
297 VLOG(2) << "***** Replacing " << node_name << " (" << node->op()
298 << ") with NoOp.";
299 // The outputs of this node are not consumed. Replace its inputs with
300 // control dependencies and replace the op itself with the NoOp op.
301 std::unordered_set<string> ctrl_inputs;
302 int pos = 0;
303 while (pos < node->input_size()) {
// Taken by value: set_input(pos, ...) below overwrites this slot while the
// old name is still needed.
304 const string old_input = node->input(pos);
305 if (IsControlInput(old_input)) {
306 if (!ctrl_inputs.insert(old_input).second) {
307 // We found a duplicate control input. Remove it.
308 node->mutable_input()->SwapElements(pos, node->input_size() - 1);
309 node->mutable_input()->RemoveLast();
310 } else {
311 ++pos;
312 }
313 continue;
314 }
315 // Replace a normal input with a control input.
316 const string ctrl_input = ConstantFolding::AddControlDependency(
317 old_input, optimized_graph_, node_map_.get());
318 ctrl_inputs.insert(ctrl_input);
319 node->set_input(pos, ctrl_input);
320 node_map_->UpdateInput(node_name, old_input, ctrl_input);
321 const NodeDef* old_input_node = node_map_->GetNode(old_input);
322 nodes_to_simplify->PushBack(node_to_idx_[old_input_node]);
323 ++pos;
324 }
325 node->set_op("NoOp");
326 EraseRegularNodeAttributes(node);
327 DedupControlInputs(node);
328 nodes_to_simplify->PushBack(node_to_idx_[node]);
329 return;
330 }
331
332 // Remove NoOp nodes if the product of their fan-in and fan-out is less than
333 // or equal to the sum of the fan-in and fan-out. The non-trivial rewrites
334 // take the following form:
335 //
336 // Case a)
337 // x --^> +------+ x --^> +---+
338 // y --^> | NoOp | --^> a ==> y --^> | a |
339 // ... | | ... | |
340 // z --^> +------+ z --^> +---+
341 //
342 // Case b)
343 // +------+ --^> a +---+ --^> a
344 // x --^> | NoOp | --^> b ==> | x | --^> b
345 // | | ... | | ...
346 // +------+ --^> c +---+ --^> c
347 // Case c)
348 // +------+ x ---^> a
349 // x --^> | NoOp | --^> a ==> \/
350 // y --^> | | --^> b /\
351 // +------+ y ---^> b
352 //
353 // We only apply this optimization if we don't increase the number of control
354 // edges across device boundaries, e.g. in cases a) and b) if NoOp and
355 // a and x, respectively, are on the same device. Control edges across device
356 // boundaries require inter-device communication (Send/Recv pairs to be
357 // inserted in the graph), which is very costly.
358 //
359 // We also remove identity nodes, subject to the same constraints on number of
360 // resulting control edges and device boundary crossings:
361 //
362 // Case a)
363 // +----------+ ---> a +---+ ---> a
364 // x --> | Identity | --^> b ==> | x | --^> b
365 // | | ... | | ...
366 // +----------+ --^> c +---+ --^> c
367 //
368 // Case b)
369 // x ---> +----------+ ---> a x ---> +---+
370 // y --^> | Identity | ==> y --^> | a |
371 // ... | | ... | |
372 // z --^> +----------+ z --^> +---+
373 //
374 // Case c)
375 // +----------+ x ---> +---+
376 // x ---> | Identity | ---> a ==> \--^> | a |
377 // y --^> | | --^> b /\ +---+
378 // +----------+ y --^> b
379
380 if (is_noop || ((is_identity || is_multi_input_identity) &&
381 SafeToRemoveIdentity(*node))) {
382 const int num_inputs = node->input_size();
// Resolve every input up front; an unresolvable input aborts the rewrite
// before anything has been modified.
383 std::vector<NodeDef*> input_nodes;
384 for (int i = 0; i < num_inputs; ++i) {
385 NodeDef* input_node = node_map_->GetNode(node->input(i));
386 if (input_node == nullptr) {
387 LOG(ERROR) << "Invalid input " << node->input(i);
388 return;
389 }
390 input_nodes.push_back(input_node);
391 }
// Snapshot the consumers into a vector; node_map_ is updated inside the
// loop below.
392 const auto& output_node_set = node_map_->GetOutputs(node_name);
393 const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
394 output_node_set.end());
395
396 if (!BypassingNodeIsBeneficial(*node, input_nodes, output_nodes)) {
397 return;
398 }
399
400 VLOG(2) << "***** Rerouting input around\n" << node->DebugString();
401 // Now remove the node and re-wire its inputs to its outputs.
402 for (auto consumer : output_nodes) {
403 bool updated_consumer = false;
404 VLOG(2) << "consumer before:\n" << consumer->DebugString();
405 // Remove dependency on node from consumer.
406 for (int i = 0; i < num_inputs; ++i) {
407 const NodeDef* input = input_nodes[i];
408 // Forward dependency from input to consumer if it doesn't already
409 // depend on it.
410 if ((is_identity && i == 0) ||
411 (is_multi_input_identity && !IsControlInput(node->input(i)))) {
412 // Replace regular input from Identity node.
413 string new_input;
414 const string& input_to_forward = node->input(i);
415 CHECK(!IsControlInput(input_to_forward));
416 for (int j = 0; j < consumer->input_size(); ++j) {
417 const TensorId old_input = ParseTensorName(consumer->input(j));
418 if (old_input.node() == node_name) {
419 if (old_input.index() == i) {
420 // Regular input
421 new_input = input_to_forward;
422 node_map_->UpdateInput(consumer->name(),
423 string(old_input.node()), new_input);
424 consumer->set_input(j, new_input);
425 } else if (old_input.index() == -1) {
426 // Control dependency
427 new_input = AsControlDependency(NodeName(input_to_forward));
428 node_map_->UpdateInput(consumer->name(),
429 string(old_input.node()), new_input);
430 consumer->set_input(j, new_input);
431 }
432 }
433 }
434 updated_consumer = true;
435 } else {
436 // Forward dependency from input to consumer if it doesn't already
437 // depend on it.
438 if (node_map_->GetOutputs(input->name()).count(consumer) == 0) {
439 consumer->add_input(AsControlDependency(input->name()));
440 node_map_->AddOutput(input->name(), consumer->name());
441 nodes_to_simplify->PushBack(node_to_idx_[input]);
442 updated_consumer = true;
443 }
444 }
445 }
// Drop the consumer's control edge on the bypassed node, if one is left.
446 updated_consumer |= RemoveControlInput(
447 consumer, AsControlDependency(node_name), node_map_.get());
448 if (updated_consumer) {
449 nodes_to_simplify->PushBack(node_to_idx_[consumer]);
450 }
451 VLOG(2) << "consumer after:\n" << consumer->DebugString();
452 }
// Every consumer has been processed; clear the node's fanout in the map.
453 node_map_->RemoveOutputs(node_name);
454 if (fetch_nodes_known_ &&
455 nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) {
456 // Mark the node for deletion.
457 nodes_to_delete->insert(node_idx);
458
459 // Disconnect the node from its inputs to enable further optimizations.
460 node_map_->RemoveInputs(node_name);
461 node->clear_input();
462 }
463 }
464 }
465
CleanControlInputs()466 void DependencyOptimizer::CleanControlInputs() {
467 for (int i = 0; i < optimized_graph_->node_size(); ++i) {
468 DedupControlInputs(optimized_graph_->mutable_node(i));
469 }
470 }
471
OptimizeDependencies()472 Status DependencyOptimizer::OptimizeDependencies() {
473 SetVector<int> nodes_to_simplify;
474 std::set<int> nodes_to_delete;
475 for (int i = 0; i < optimized_graph_->node_size(); ++i) {
476 const NodeDef& node = optimized_graph_->node(i);
477 if (IsNoOp(node) || IsIdentity(node) || IsIdentityN(node) ||
478 IsConstant(node) || SafeToConvertToNoOp(node)) {
479 nodes_to_simplify.PushBack(i);
480 }
481 }
482 while (!nodes_to_simplify.Empty()) {
483 int node_to_simplify = nodes_to_simplify.PopBack();
484 // Discard nodes that were marked for deletion already.
485 while (nodes_to_delete.find(node_to_simplify) != nodes_to_delete.end()) {
486 node_to_simplify = nodes_to_simplify.PopBack();
487 }
488 OptimizeNode(node_to_simplify, &nodes_to_simplify, &nodes_to_delete);
489 }
490
491 if (fetch_nodes_known_) {
492 VLOG(1) << "Deleted " << nodes_to_delete.size() << " out of "
493 << optimized_graph_->node_size() << " nodes.";
494 EraseNodesFromGraph(nodes_to_delete, optimized_graph_);
495 node_map_.reset(new NodeMap(optimized_graph_));
496 BuildNodeToIdx();
497 }
498 return Status::OK();
499 }
500
501 namespace {
502
503 enum DistanceFromSource : uint8 { ZERO = 0, ONE = 1, TWO_OR_GREATER = 2 };
504
LongestPathsLowerBounds(int source,const std::pair<int,int> & target_range,const std::vector<std::vector<int>> & outputs,std::vector<DistanceFromSource> * longest_distance)505 void LongestPathsLowerBounds(
506 int source, const std::pair<int, int>& target_range,
507 const std::vector<std::vector<int>>& outputs,
508 std::vector<DistanceFromSource>* longest_distance) {
509 std::deque<int> queue;
510 queue.emplace_front(source);
511 while (!queue.empty()) {
512 int node = queue.front();
513 queue.pop_front();
514 for (int fanout : outputs[node]) {
515 // 1) Only nodes in the target range can be on paths from source to one of
516 // its control outputs.
517 // 2) Since we only need a lower bound on the longest distance, we can
518 // skip nodes for which we have already proven have a path of
519 // length > 1 from the source.
520 if (fanout >= target_range.first && fanout <= target_range.second &&
521 (*longest_distance)[fanout] != TWO_OR_GREATER) {
522 (*longest_distance)[fanout] =
523 (*longest_distance)[fanout] == ZERO ? ONE : TWO_OR_GREATER;
524 queue.emplace_front(fanout);
525 }
526 }
527 }
528 }
529
530 } // namespace
531
TransitiveReduction()532 Status DependencyOptimizer::TransitiveReduction() {
533 // PRECONDITION: optimized_graph_ must be sorted topologically.
534 const int num_nodes = optimized_graph_->node_size();
535 // Set up a compressed version of the graph to save a constant factor in the
536 // expensive algorithm below. Also cache the set of control outputs and the
537 // highest index of a target of any control output from each node.
538 int num_controls = 0;
539 std::vector<std::vector<int>> outputs(num_nodes);
540 std::vector<gtl::InlinedVector<std::pair<int, int>, 2>> control_outputs(
541 num_nodes);
542 // target_range[i] contains the range of node indices for which to compute
543 // longest paths starting from node i.
544 std::vector<std::pair<int, int>> target_range(num_nodes, {num_nodes, -1});
545 for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
546 const NodeDef& node = optimized_graph_->node(node_idx);
547 if (ModifiesFrameInfo(node) || !HasOpDef(node)) {
548 // Ignore function nodes and nodes that modify frame info.
549 continue;
550 }
551 for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) {
552 const string& input = node.input(input_slot);
553 const NodeDef* input_node = node_map_->GetNode(input);
554 if (ModifiesFrameInfo(*input_node) || IsMerge(*input_node)) {
555 // Ignore edges from nodes that modify frame info and from Merge nodes,
556 // because we cannot know which of it's input paths executes.
557 continue;
558 }
559 const int input_node_idx = node_to_idx_[input_node];
560 outputs[input_node_idx].push_back(node_idx);
561 target_range[input_node_idx].first =
562 std::min(target_range[input_node_idx].first, node_idx);
563 if (IsControlInput(input)) {
564 ++num_controls;
565 control_outputs[input_node_idx].emplace_back(node_idx, input_slot);
566 target_range[input_node_idx].second =
567 std::max(target_range[input_node_idx].second, node_idx);
568 }
569 }
570 }
571
572 // Run the longest path in DAG algorithm for each source node that has control
573 // outputs. If, for any target node of a control output, there exists a path
574 // of length > 1, we can drop that control dependency.
575 int num_controls_removed = 0;
576 std::vector<DistanceFromSource> longest_distance(num_nodes);
577 // Map from target_index -> set of (input_slot, source_index), representing
578 // the control edges to remove. We sort them in reverse order by input slot,
579 // such that when we swap them out so we don't clobber the
580 // node(target).input() repeated field.
581 typedef std::pair<int, int> InputSlotAndSource;
582 absl::flat_hash_map<
583 int, std::set<InputSlotAndSource, std::greater<InputSlotAndSource>>>
584 control_edges_to_remove;
585 for (int source = 0; source < num_nodes; ++source) {
586 if (target_range[source].first >= target_range[source].second ||
587 target_range[source].second <= source) {
588 continue;
589 }
590 // Compute the set of nodes in the transitive fanout of source with
591 // topological sort index in [target_range.first : target_range.second]]
592 // to which there exists a path of length 2 or more from source.
593 std::fill(longest_distance.begin() + target_range[source].first,
594 longest_distance.begin() + target_range[source].second + 1, ZERO);
595 LongestPathsLowerBounds(source, target_range[source], outputs,
596 &longest_distance);
597
598 // If the longest path from source to target of a control dependency is
599 // longer than 1, there exists an alternate path, and we can eliminate the
600 // redundant direct control dependency.
601 for (const auto& control_output : control_outputs[source]) {
602 const int target = control_output.first;
603 if (longest_distance[target] == TWO_OR_GREATER) {
604 const int input_slot = control_output.second;
605 control_edges_to_remove[target].emplace(input_slot, source);
606 }
607 }
608 }
609 for (const auto& it : control_edges_to_remove) {
610 const int target = it.first;
611 NodeDef* target_node = optimized_graph_->mutable_node(target);
612 for (const InputSlotAndSource& slot_and_source : it.second) {
613 const int input_slot = slot_and_source.first;
614 const int source = slot_and_source.second;
615 const NodeDef& source_node = optimized_graph_->node(source);
616 CHECK_LT(input_slot, target_node->input_size());
617 target_node->mutable_input()->SwapElements(input_slot,
618 target_node->input_size() - 1);
619 node_map_->RemoveOutput(source_node.name(), target_node->name());
620 target_node->mutable_input()->RemoveLast();
621 ++num_controls_removed;
622 }
623 }
624 VLOG(1) << "Removed " << num_controls_removed << " out of " << num_controls
625 << " control dependencies";
626 return Status::OK();
627 }
628
BuildNodeToIdx()629 void DependencyOptimizer::BuildNodeToIdx() {
630 // Set up &node -> index map.
631 node_to_idx_.clear();
632 for (int i = 0; i < optimized_graph_->node_size(); ++i) {
633 const NodeDef& node = optimized_graph_->node(i);
634 node_to_idx_[&node] = i;
635 }
636 }
637
638 // Suppose there are cross-device control inputs to node C from multiple nodes
639 // that are located on another device, e.g., we have control edges:
640 // A->C, B->C
641 // where A and B are on device X and C is on device Y.
642 // We can reduce cross-device communication by introducing an intermediate
643 // NoOp node C' on device X and rewriting the control edges to:
644 // A->C', B->C', C' -> C
GroupCrossDeviceControlEdges(bool host_granularity)645 void DependencyOptimizer::GroupCrossDeviceControlEdges(bool host_granularity) {
646 VLOG(1)
647 << "DependencyOptimizer::GroupCrossDeviceControlEdges host_granularity="
648 << host_granularity;
649 const int num_nodes = optimized_graph_->node_size();
650 for (int i = 0; i < num_nodes; ++i) {
651 NodeDef* node = optimized_graph_->mutable_node(i);
652 if (node->device().empty()) continue;
653 string rest, node_device = node->device();
654 if (host_granularity) {
655 DeviceNameUtils::SplitDeviceName(node->device(), &node_device, &rest);
656 }
657
658 // Creates new noop nodes for devices on which multiple control inputs are
659 // located.
660
661 // Map keyed by device name to the newly introduced Noop node for that
662 // device. A nullptr value means that we have only seen a single node on
663 // that device.
664 std::map<string, NodeDef*> noops;
665 int num_noops = 0;
666 for (int j = 0; j < node->input_size(); ++j) {
667 if (IsControlInput(node->input(j))) {
668 const NodeDef* input = node_map_->GetNode(node->input(j));
669 if (input == nullptr || input->device().empty()) continue;
670 string input_device = input->device();
671 if (host_granularity) {
672 DeviceNameUtils::SplitDeviceName(input->device(), &input_device,
673 &rest);
674 }
675 if (input_device != node_device) {
676 VLOG(2) << "Cross-device " << node->name() << " " << input->device()
677 << " -> " << node->device();
678 auto emplace_result = noops.emplace(input_device, nullptr);
679 if (!emplace_result.second &&
680 emplace_result.first->second == nullptr) {
681 VLOG(2) << "Duplicate input device from " << node->name();
682 // This is the second cross-device control input from the same
683 // device. Creates an intermediate noop node on that device.
684 string group_name;
685 NodeDef* noop;
686 // Creates a fresh node name; there may be conflicting names from
687 // a previous iteration of the optimizer.
688 do {
689 group_name = AddPrefixToNodeName(
690 node->name(),
691 strings::StrCat("GroupCrossDeviceControlEdges_", num_noops));
692 noop = node_map_->GetNode(group_name);
693 ++num_noops;
694 } while (noop != nullptr);
695 noop = optimized_graph_->add_node();
696 noop->set_name(group_name);
697 noop->set_device(input->device());
698 noop->set_op("NoOp");
699 node_map_->AddNode(noop->name(), noop);
700 emplace_result.first->second = noop;
701 VLOG(1) << "GroupCrossDeviceControlEdges: Added "
702 << SummarizeNodeDef(*noop);
703 }
704 }
705 }
706 }
707
708 // Reroute existing control edges to go via the newly introduced NoOp nodes.
709 int pos = 0;
710 while (pos < node->input_size()) {
711 const string& input_name = node->input(pos);
712 if (IsControlInput(input_name)) {
713 NodeDef* input = node_map_->GetNode(input_name);
714 if (input == nullptr) {
715 ++pos;
716 } else {
717 string input_device = input->device();
718 if (host_granularity) {
719 DeviceNameUtils::SplitDeviceName(input->device(), &input_device,
720 &rest);
721 }
722 auto it = noops.find(input_device);
723 if (it == noops.end() || it->second == nullptr) {
724 ++pos;
725 } else {
726 VLOG(2) << "Rewriting input from " << input_name;
727 node->mutable_input()->SwapElements(pos, node->input_size() - 1);
728 node->mutable_input()->RemoveLast();
729 it->second->add_input(AsControlDependency(*input));
730 node_map_->UpdateOutput(input_name, node->name(),
731 it->second->name());
732 }
733 }
734 } else {
735 ++pos;
736 }
737 }
738 for (const auto& entry : noops) {
739 if (entry.second) {
740 node->add_input(AsControlDependency(*entry.second));
741 node_map_->AddOutput(entry.second->name(), node->name());
742 }
743 }
744 }
745 }
746
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)747 Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
748 GraphDef* optimized_graph) {
749 optimized_graph_ = optimized_graph;
750 *optimized_graph_ = item.graph;
751 nodes_to_preserve_ = item.NodesToPreserve();
752 fetch_nodes_known_ = !item.fetch.empty();
753 CleanControlInputs();
754
755 const int num_iterations = 2;
756 for (int iteration = 0; iteration < num_iterations; ++iteration) {
757 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
758 Status topo_sort_status;
759 // Perform topological sort to prepare the graph for transitive reduction.
760 topo_sort_status = TopologicalSort(optimized_graph_);
761 // Set up index-based graph datastructures to speed up analysis steps below.
762 node_map_.reset(new NodeMap(optimized_graph_));
763 BuildNodeToIdx();
764
765 if (topo_sort_status.ok()) {
766 // Remove redundant control dependencies.
767 TF_RETURN_IF_ERROR(TransitiveReduction());
768 } else {
769 LOG(ERROR) << "Iteration = " << iteration
770 << ", topological sort failed with message: "
771 << topo_sort_status.error_message();
772 }
773 // Turn nodes with only control outputs into NoOps, prune NoOp and Identity
774 // nodes.
775 TF_RETURN_IF_ERROR(OptimizeDependencies());
776
777 // Dedup control inputs.
778 CleanControlInputs();
779
780 // Merge multiple control edges from the same device.
781 GroupCrossDeviceControlEdges(/*host_granularity=*/false);
782
783 // Merge control edges from the same host to reduce RPC traffic.
784 GroupCrossDeviceControlEdges(/*host_granularity=*/true);
785 }
786
787 return Status::OK();
788 }
789
790 } // end namespace grappler
791 } // end namespace tensorflow
792