• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
17 
18 #include <unordered_map>
19 #include <unordered_set>
20 
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/op.h"
23 #include "tensorflow/core/grappler/costs/graph_properties.h"
24 #include "tensorflow/core/grappler/grappler_item.h"
25 #include "tensorflow/core/grappler/op_types.h"
26 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/grappler/utils/topological_sort.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/stringpiece.h"
31 #include "tensorflow/core/lib/gtl/inlined_vector.h"
32 #include "tensorflow/core/lib/strings/str_util.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/util/device_name_utils.h"
35 
36 namespace tensorflow {
37 namespace grappler {
38 
39 namespace {
40 
RemoveInput(NodeDef * node,const string & input,NodeMap * node_map)41 bool RemoveInput(NodeDef* node, const string& input, NodeMap* node_map) {
42   bool removed_input = false;
43   int pos = 0;
44   while (pos < node->input_size()) {
45     if (node->input(pos) == input) {
46       node->mutable_input()->SwapElements(pos, node->input_size() - 1);
47       node->mutable_input()->RemoveLast();
48       node_map->RemoveOutput(NodeName(input), node->name());
49       removed_input = true;
50     } else {
51       ++pos;
52     }
53   }
54   return removed_input;
55 }
56 
57 }  // namespace
58 
SafeToRemoveIdentity(const NodeDef & node) const59 bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
60   if (!IsIdentity(node) && !IsIdentityN(node)) {
61     return true;
62   }
63 
64   if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
65     return false;
66   }
67   if (!fetch_nodes_known_) {
68     // The output values of this node may be needed.
69     return false;
70   }
71   const NodeDef* input = node_map_->GetNode(NodeName(node.input(0)));
72   CHECK(input != nullptr) << "node = " << node.name()
73                           << " input = " << node.input(0);
74   // Don't remove Identity nodes corresponding to Variable reads or following
75   // Recv.
76   if (IsVariable(*input) || IsRecv(*input)) {
77     return false;
78   }
79   for (const auto& consumer : node_map_->GetOutputs(node.name())) {
80     if (node.input_size() > 1 && IsMerge(*consumer)) {
81       return false;
82     }
83     if (IsSwitch(*input)) {
84       for (const string& consumer_input : consumer->input()) {
85         if (consumer_input == AsControlDependency(node.name())) {
86           return false;
87         }
88       }
89     }
90   }
91   return true;
92 }
93 
SafeToConvertToNoOp(const NodeDef & node) const94 bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
95   if (!fetch_nodes_known_ ||
96       nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
97     return false;
98   }
99   if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node) ||
100       !IsFreeOfSideEffect(node)) {
101     return false;
102   }
103   if (node.op().rfind("Submodel", 0) == 0) {
104     return false;
105   }
106   const OpDef* op_def = nullptr;
107   Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
108   if (!status.ok() || op_def->output_arg_size() == 0) {
109     return false;
110   }
111   const std::unordered_set<string> do_not_rewrite_ops{
112       "Assert",     "CheckNumerics",         "_Retval",
113       "_Arg",       "_ParallelConcatUpdate", "TPUExecute",
114       "TPUCompile", "ControlTrigger"};
115   if (do_not_rewrite_ops.find(node.op()) != do_not_rewrite_ops.end()) {
116     return false;
117   }
118   if (!SafeToRemoveIdentity(node)) {
119     return false;
120   }
121   if (NumNonControlOutputs(node, *node_map_) > 0) {
122     // The output values of this node may be needed.
123     return false;
124   }
125   return true;
126 }
127 
NumEdgesIfBypassed(const NodeDef & node,const std::vector<NodeDef * > & output_nodes) const128 int DependencyOptimizer::NumEdgesIfBypassed(
129     const NodeDef& node, const std::vector<NodeDef*>& output_nodes) const {
130   const bool is_multi_input_identity_n =
131       IsIdentityN(node) && !IsIdentityNSingleInput(node);
132   const int num_outputs = output_nodes.size();
133   const int num_inputs = node.input_size();
134 
135   if (is_multi_input_identity_n) {
136     // multi-input identity_n with input/output control dependencies will likely
137     // increase number of edges after optimization.
138     int num_edges_if_bypassed(0);
139     for (string input_node_name : node.input()) {
140       if (IsControlInput(input_node_name)) {
141         num_edges_if_bypassed += num_outputs;
142       } else {
143         ++num_edges_if_bypassed;
144       }
145     }
146 
147     for (auto consumer : output_nodes) {
148       for (int j = 0; j < consumer->input_size(); ++j) {
149         const TensorId consumer_input = ParseTensorName(consumer->input(j));
150         if (consumer_input.node() == node.name()) {
151           if (IsControlInput(consumer_input)) {
152             num_edges_if_bypassed += num_inputs;
153           } else {
154             ++num_edges_if_bypassed;
155           }
156         }
157       }
158     }
159     return num_edges_if_bypassed;
160   } else {
161     return num_inputs * num_outputs;
162   }
163 }
164 
BypassingNodeIsBeneficial(const NodeDef & node,const std::vector<NodeDef * > & input_nodes,const std::vector<NodeDef * > & output_nodes) const165 bool DependencyOptimizer::BypassingNodeIsBeneficial(
166     const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
167     const std::vector<NodeDef*>& output_nodes) const {
168   const bool is_identity = IsIdentity(node) || IsIdentityNSingleInput(node);
169   const bool is_multi_input_identity_n =
170       IsIdentityN(node) && !IsIdentityNSingleInput(node);
171   const int num_outputs = output_nodes.size();
172   const int num_inputs = node.input_size();
173 
174   if (NumEdgesIfBypassed(node, output_nodes) > num_inputs + num_outputs) {
175     return false;
176   }
177 
178   // Make sure that we don't increase the number of edges that cross
179   // device boundaries.
180   if ((num_inputs == 1 && num_outputs > 1 &&
181        input_nodes[0]->device() != node.device()) ||
182       (num_inputs > 1 && num_outputs == 1 &&
183        output_nodes[0]->device() != node.device())) {
184     return false;
185   }
186 
187   // TODO(rmlarsen): Not all device crossings are equally expensive.
188   // Assign a cost to each based on device affinity and compute a
189   // cost before and after.
190   const string& node_dev = node.device();
191   int num_cross_in = 0;
192   for (NodeDef* input_node : input_nodes) {
193     num_cross_in += static_cast<int>(input_node->device() != node_dev);
194   }
195   int num_cross_out = 0;
196   for (NodeDef* output_node : output_nodes) {
197     num_cross_out += static_cast<int>(output_node->device() != node_dev);
198   }
199 
200   // Make sure we do not increase the number of device crossings.
201   const int num_cross_before = num_cross_in + num_cross_out;
202   int num_cross_after = 0;
203   for (NodeDef* input_node : input_nodes) {
204     for (NodeDef* output_node : output_nodes) {
205       num_cross_after +=
206           static_cast<int>(input_node->device() != output_node->device());
207     }
208   }
209   if (num_cross_after > num_cross_before) {
210     return false;
211   }
212 
213   if ((is_identity || is_multi_input_identity_n) && num_cross_in > 0 &&
214       num_cross_out > 0 && num_cross_after > 0) {
215     // This identity node follows a device crossing, so it might be
216     // following a _Recv node after partioning. Do not remove such nodes,
217     // unless they only have consumers on the same device as themselves.
218     return false;
219   }
220 
221   return true;
222 }
223 
// Attempts to simplify the node at `node_idx` in optimized_graph_. Three
// cases apply:
//   1. A Constant with no inputs has its outgoing control edges pruned. It is
//      marked for deletion once it has no outputs left, provided fetch nodes
//      are known and the node is not preserved.
//   2. A node accepted by SafeToConvertToNoOp() becomes a NoOp. Its data
//      inputs are rewritten as control inputs, duplicate control inputs are
//      dropped, and its attributes are cleared.
//   3. A NoOp, or an Identity/IdentityN accepted by SafeToRemoveIdentity(), is
//      bypassed when BypassingNodeIsBeneficial() agrees: its inputs are wired
//      directly to its consumers. It is then marked for deletion and
//      disconnected, unless fetch nodes are unknown or it must be preserved.
// Nodes whose neighborhood changed are pushed onto `nodes_to_simplify` so they
// get revisited. Indices of nodes to erase are added to `nodes_to_delete`.
void DependencyOptimizer::OptimizeNode(int node_idx,
                                       SetVector<int>* nodes_to_simplify,
                                       std::set<int>* nodes_to_delete) {
  NodeDef* node = optimized_graph_->mutable_node(node_idx);
  const bool is_noop = IsNoOp(*node);
  const bool is_identity = IsIdentity(*node) || IsIdentityNSingleInput(*node);
  const bool is_multi_input_identity =
      IsIdentityN(*node) && !IsIdentityNSingleInput(*node);
  const string node_name = node->name();
  // Constant nodes with no input control dependency are always executed early,
  // so we can prune all their output control dependencies.
  if (IsConstant(*node) && node->input_size() == 0) {
    // Copy the fanout set. RemoveOutput() below edits node_map_'s outputs for
    // node_name while this loop is still iterating.
    const std::set<NodeDef*> output_nodes = node_map_->GetOutputs(node_name);
    for (NodeDef* fanout : output_nodes) {
      bool optimize_fanout = false;
      // Tracks whether a data edge from the constant survives. If one does,
      // the NodeMap fanout entry must be kept.
      bool data_connection = false;
      // Iterate backwards. Each removal swaps the current slot with the last
      // one, and that element has already been visited.
      for (int i = fanout->input_size() - 1; i >= 0; --i) {
        const TensorId input_tensor = ParseTensorName(fanout->input(i));
        if (input_tensor.node() == node_name) {
          if (input_tensor.index() < 0) {
            fanout->mutable_input()->SwapElements(i, fanout->input_size() - 1);
            fanout->mutable_input()->RemoveLast();
            optimize_fanout = true;
          } else {
            data_connection = true;
          }
        }
      }
      if (optimize_fanout) {
        nodes_to_simplify->PushBack(node_to_idx_[fanout]);
        if (!data_connection) {
          node_map_->RemoveOutput(node_name, fanout->name());
        }
      }
    }
    if (node_map_->GetOutputs(node_name).empty() && fetch_nodes_known_ &&
        nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) {
      // Mark the node for deletion.
      nodes_to_delete->insert(node_to_idx_[node]);
    }
    return;
  }

  // Change ops that only have control dependencies as outputs to NoOps.
  if (!is_noop && SafeToConvertToNoOp(*node)) {
    VLOG(1) << "***** Replacing  " << node_name << " (" << node->op()
            << ") with NoOp.";
    // The outputs of this node are not consumed. Replace its inputs with
    // control dependencies and replace the op itself with the NoOp op.
    // ctrl_inputs holds every control input kept so far, so later duplicates
    // can be dropped.
    std::unordered_set<string> ctrl_inputs;
    int pos = 0;
    while (pos < node->input_size()) {
      // Copied by value because the slot is overwritten below.
      const string old_input = node->input(pos);
      if (IsControlInput(old_input)) {
        if (!ctrl_inputs.insert(old_input).second) {
          // We found a duplicate control input. Remove it.
          node->mutable_input()->SwapElements(pos, node->input_size() - 1);
          node->mutable_input()->RemoveLast();
        } else {
          ++pos;
        }
        continue;
      }
      // Replace a normal input with a control input.
      const string ctrl_input = ConstantFolding::AddControlDependency(
          old_input, optimized_graph_, node_map_.get());
      ctrl_inputs.insert(ctrl_input);
      node->set_input(pos, ctrl_input);
      node_map_->UpdateInput(node_name, old_input, ctrl_input);
      const NodeDef* old_input_node = node_map_->GetNode(old_input);
      nodes_to_simplify->PushBack(node_to_idx_[old_input_node]);
      ++pos;
    }
    node->set_op("NoOp");
    node->clear_attr();
    // Revisit the new NoOp so the bypass rewrite below can apply to it.
    nodes_to_simplify->PushBack(node_to_idx_[node]);
    return;
  }

  // Remove NoOp nodes if the product of their fan-in and fan-out is less than
  // or equal to the sum of the fan-in and fan-out. The non-trivial rewrites
  // take the following form:
  //
  // Case a)
  //    x --^> +------+                x --^> +---+
  //    y --^> | NoOp | --^> a   ==>   y --^> | a |
  //    ...    |      |                  ...  |   |
  //    z --^> +------+                z --^> +---+
  //
  // Case b)
  //           +------+ --^> a         +---+ --^> a
  //    x --^> | NoOp | --^> b  ==>    | x | --^> b
  //           |      | ...            |   | ...
  //           +------+ --^> c         +---+ --^> c
  // Case c)
  //           +------+                x ---^> a
  //    x --^> | NoOp | --^> a  ==>      \/
  //    y --^> |      | --^> b           /\
  //           +------+                y ---^> b
  //
  // We only apply this optimization if we don't increase the number of control
  // edges across device boundaries, e.g. in cases a) and b) if NoOp and
  // a and x, respectively, are on the same device. Control edges across device
  // boundaries require inter-device communication (Send/Recv pairs to be
  // inserted in the graph), which is very costly.
  //
  // We also remove identity nodes, subject to the same constraints on number of
  // resulting control edges and device boundary crossings:
  //
  // Case a)
  //          +----------+ ---> a       +---+ ---> a
  //    x --> | Identity | --^> b  ==>  | x | --^> b
  //          |          | ...          |   | ...
  //          +----------+ --^> c       +---+ --^> c
  //
  // Case b)
  //    x ---> +----------+ ---> a      x ---> +---+
  //    y --^> | Identity |        ==>  y --^> | a |
  //    ...    |          |               ...  |   |
  //    z --^> +----------+             z --^> +---+
  //
  // Case c)
  //           +----------+             x ---> +---+
  //    x ---> | Identity | ---> a ==>   \--^> | a |
  //    y --^> |          | --^> b       /\    +---+
  //           +----------+             y --^> b

  if (is_noop || ((is_identity || is_multi_input_identity) &&
                  SafeToRemoveIdentity(*node))) {
    // Snapshot the consumers. node_map_ is modified while they are rewired.
    const auto& output_node_set = node_map_->GetOutputs(node_name);
    const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
                                             output_node_set.end());
    const int num_inputs = node->input_size();
    // input_nodes[i] is the producer of node->input(i).
    std::vector<NodeDef*> input_nodes;
    for (int i = 0; i < num_inputs; ++i) {
      NodeDef* input_node = node_map_->GetNode(node->input(i));
      if (input_node == nullptr) {
        LOG(ERROR) << "Invalid input " << node->input(i);
        return;
      }
      input_nodes.push_back(input_node);
    }

    if (!BypassingNodeIsBeneficial(*node, input_nodes, output_nodes)) {
      return;
    }

    VLOG(1) << "***** Rerouting input around\n" << node->DebugString();
    // Now remove the node and re-wire its inputs to its outputs.
    for (auto consumer : output_nodes) {
      bool updated_consumer = false;
      VLOG(1) << "consumer before:\n" << consumer->DebugString();
      for (int i = 0; i < num_inputs; ++i) {
        const NodeDef* input = input_nodes[i];
        // Forward dependency from input to consumer if it doesn't already
        // depend on it.
        if ((is_identity && i == 0) ||
            (is_multi_input_identity && !IsControlInput(node->input(i)))) {
          // Replace regular input from Identity node.
          // NOTE(review): for a multi-input IdentityN, a consumer's
          // "^node_name" entry is rewritten to "^<input 0>" while handling
          // i == 0. Later data inputs (i > 0) then no longer find it, so a
          // control-only consumer gets no edge from them, although
          // NumEdgesIfBypassed() counts one edge per input for such a
          // consumer. Confirm whether those inputs should be forwarded too.
          string new_input;
          const string& input_to_forward = node->input(i);
          CHECK(!IsControlInput(input_to_forward));
          for (int j = 0; j < consumer->input_size(); ++j) {
            const TensorId old_input = ParseTensorName(consumer->input(j));
            if (old_input.node() == node_name) {
              if (old_input.index() == i) {
                // Regular input
                new_input = input_to_forward;
                node_map_->UpdateInput(consumer->name(), old_input.ToString(),
                                       new_input);
                consumer->set_input(j, new_input);
              } else if (old_input.index() == -1) {
                // Control dependency
                new_input = AsControlDependency(NodeName(input_to_forward));
                node_map_->UpdateInput(consumer->name(), old_input.ToString(),
                                       new_input);
                consumer->set_input(j, new_input);
              }
            }
          }
          updated_consumer = true;
        } else {
          // Forward dependency from input to consumer if it doesn't already
          // depend on it.
          if (node_map_->GetOutputs(input->name()).count(consumer) == 0) {
            consumer->add_input(AsControlDependency(input->name()));
            node_map_->AddOutput(input->name(), consumer->name());
            nodes_to_simplify->PushBack(node_to_idx_[input]);
            updated_consumer = true;
          }
        }
      }
      // Remove dependency on node from consumer.
      updated_consumer |= RemoveInput(consumer, AsControlDependency(node_name),
                                      node_map_.get());
      if (updated_consumer) {
        nodes_to_simplify->PushBack(node_to_idx_[consumer]);
      }
      VLOG(1) << "consumer after:\n" << consumer->DebugString();
    }
    node_map_->RemoveOutputs(node_name);
    if (fetch_nodes_known_ &&
        nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) {
      // Mark the node for deletion.
      nodes_to_delete->insert(node_idx);

      // Disconnect the node from its inputs to enable further optimizations.
      node_map_->RemoveInputs(node_name);
      node->clear_input();
    }
  }
}
436 
CleanControlInputs()437 void DependencyOptimizer::CleanControlInputs() {
438   for (int i = 0; i < optimized_graph_->node_size(); ++i) {
439     DedupControlInputs(optimized_graph_->mutable_node(i));
440   }
441 }
442 
OptimizeDependencies()443 Status DependencyOptimizer::OptimizeDependencies() {
444   SetVector<int> nodes_to_simplify;
445   std::set<int> nodes_to_delete;
446   for (int i = 0; i < optimized_graph_->node_size(); ++i) {
447     const NodeDef& node = optimized_graph_->node(i);
448     if (IsNoOp(node) || IsIdentity(node) || IsIdentityN(node) ||
449         IsConstant(node) || SafeToConvertToNoOp(node)) {
450       nodes_to_simplify.PushBack(i);
451     }
452   }
453   while (!nodes_to_simplify.Empty()) {
454     int node_to_simplify = nodes_to_simplify.PopBack();
455     // Discard nodes that were marked for deletion already.
456     while (nodes_to_delete.find(node_to_simplify) != nodes_to_delete.end()) {
457       node_to_simplify = nodes_to_simplify.PopBack();
458     }
459     OptimizeNode(node_to_simplify, &nodes_to_simplify, &nodes_to_delete);
460   }
461 
462   if (fetch_nodes_known_) {
463     VLOG(1) << "Deleted " << nodes_to_delete.size() << " out of "
464             << optimized_graph_->node_size() << " nodes.";
465     EraseNodesFromGraph(nodes_to_delete, optimized_graph_);
466     node_map_.reset(new NodeMap(optimized_graph_));
467     BuildNodeToIdx();
468   }
469   return Status::OK();
470 }
471 
TransitiveReduction()472 Status DependencyOptimizer::TransitiveReduction() {
473   // PRECONDITION: optimized_graph_ must be sorted topologically.
474   const int num_nodes = optimized_graph_->node_size();
475   // Set up a compressed version of the graph to save a constant factor in the
476   // expensive algorithm below. Also cache the set of control outputs and the
477   // highest index of a target of any control output from each node.
478   int num_controls = 0;
479   std::vector<gtl::InlinedVector<int, 4>> inputs(num_nodes);
480   std::vector<gtl::InlinedVector<std::pair<int, int>, 2>> control_outputs(
481       num_nodes);
482   for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
483     const NodeDef& node = optimized_graph_->node(node_idx);
484     if (ModifiesFrameInfo(node) || !HasOpDef(node)) {
485       // Ignore function nodes and nodes that modify frame info.
486       continue;
487     }
488     for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) {
489       const string& input = node.input(input_slot);
490       const NodeDef* input_node = node_map_->GetNode(input);
491       if (ModifiesFrameInfo(*input_node) || IsMerge(*input_node)) {
492         // Ignore edges from nodes that modify frame info and from Merge nodes,
493         // because we cannot know which of it's input paths executes.
494         continue;
495       }
496       const int input_node_idx = node_to_idx_[input_node];
497       inputs[node_idx].push_back(input_node_idx);
498       if (IsControlInput(input)) {
499         ++num_controls;
500         control_outputs[input_node_idx].emplace_back(node_idx, input_slot);
501       }
502     }
503   }
504 
505   // Run the longest path in DAG algorithm for each source node that has control
506   // outputs. If, for any target node of a control output, there exists a path
507   // of length > 1, we can drop that control dependency.
508   int num_controls_removed = 0;
509   std::vector<int> longest_distance(num_nodes);
510   // Map from target_index -> set of (input_slot, source_index), representing
511   // the control edges to remove. We sort them in reverse order by input slot,
512   // such that when we swap them out so we don't clobber the
513   // node(target).input() repeated field.
514   typedef std::pair<int, int> InputSlotAndSource;
515   std::unordered_map<
516       int, std::set<InputSlotAndSource, std::greater<InputSlotAndSource>>>
517       control_edges_to_remove;
518   for (int source = 0; source < num_nodes; ++source) {
519     int highest_control_target = -1;
520     for (const auto& control_output : control_outputs[source]) {
521       if (control_output.first > highest_control_target) {
522         highest_control_target = control_output.first;
523       }
524     }
525     if (highest_control_target <= source) {
526       continue;
527     }
528     std::fill(longest_distance.begin() + source,
529               longest_distance.begin() + highest_control_target + 1, 0);
530     for (int target = source + 1; target <= highest_control_target; ++target) {
531       for (int input : inputs[target]) {
532         // If the input node is before source in the topo order, no path
533         // source -> input -> target can exits and we can skip it.
534         // Also only extend a path from the source itself or from nodes that
535         // have a path from source, indicated by longest_distance[input] > 0.
536         if (input == source ||
537             (input > source && longest_distance[input] > 0)) {
538           // If source -> input -> target is longer than the longest
539           // path so far from source -> target, update the longest_distance.
540           int candidate_longest_distance = longest_distance[input] + 1;
541           if (candidate_longest_distance > longest_distance[target]) {
542             longest_distance[target] = candidate_longest_distance;
543           }
544         }
545       }
546     }
547 
548     // If the longest path from source to target of a control dependency is
549     // longer than 1, there exists an alternate path, and we can eliminate the
550     // redundant direct control dependency.
551     for (const auto& control_output : control_outputs[source]) {
552       const int target = control_output.first;
553       if (longest_distance[target] > 1) {
554         const int input_slot = control_output.second;
555         control_edges_to_remove[target].emplace(input_slot, source);
556       }
557     }
558   }
559 
560   for (const auto& it : control_edges_to_remove) {
561     const int target = it.first;
562     NodeDef* target_node = optimized_graph_->mutable_node(target);
563     for (const InputSlotAndSource& slot_and_source : it.second) {
564       const int input_slot = slot_and_source.first;
565       const int source = slot_and_source.second;
566       const NodeDef& source_node = optimized_graph_->node(source);
567       CHECK_LT(input_slot, target_node->input_size());
568       target_node->mutable_input()->SwapElements(input_slot,
569                                                  target_node->input_size() - 1);
570       node_map_->RemoveOutput(source_node.name(), target_node->name());
571       target_node->mutable_input()->RemoveLast();
572       ++num_controls_removed;
573     }
574   }
575   VLOG(1) << "Removed " << num_controls_removed << " out of " << num_controls
576           << " control dependencies";
577   return Status::OK();
578 }
579 
BuildNodeToIdx()580 void DependencyOptimizer::BuildNodeToIdx() {
581   // Set up &node -> index map.
582   node_to_idx_.clear();
583   for (int i = 0; i < optimized_graph_->node_size(); ++i) {
584     const NodeDef& node = optimized_graph_->node(i);
585     node_to_idx_[&node] = i;
586   }
587 }
588 
589 // Suppose there are cross-device control inputs to node C from multiple nodes
590 // that are located on another device, e.g., we have control edges:
591 // A->C, B->C
592 // where A and B are on device X and C is on device Y.
593 // We can reduce cross-device communication by introducing an intermediate
594 // NoOp node C' on device X and rewriting the control edges to:
595 // A->C', B->C', C' -> C
GroupCrossDeviceControlEdges()596 void DependencyOptimizer::GroupCrossDeviceControlEdges() {
597   const int num_nodes = optimized_graph_->node_size();
598   for (int i = 0; i < num_nodes; ++i) {
599     NodeDef* node = optimized_graph_->mutable_node(i);
600     if (node->device().empty()) continue;
601 
602     // Creates new noop nodes for devices on which multiple control inputs are
603     // located.
604 
605     // Map keyed by device name to the newly introduced Noop node for that
606     // device. A nullptr value means that we have only seen a single node on
607     // that device.
608     std::map<string, NodeDef*> noops;
609     int num_noops = 0;
610     for (int j = 0; j < node->input_size(); ++j) {
611       if (IsControlInput(node->input(j))) {
612         const NodeDef* input = node_map_->GetNode(node->input(j));
613         if (input != nullptr && !input->device().empty() &&
614             input->device() != node->device()) {
615           auto emplace_result = noops.emplace(input->device(), nullptr);
616           if (!emplace_result.second &&
617               emplace_result.first->second == nullptr) {
618             // This is the second cross-device control input from the same
619             // device. Creates an intermediate noop node on that device.
620             string group_name;
621             NodeDef* noop;
622             // Creates a fresh node name; there may be conflicting names from
623             // a previous iteration of the optimizer.
624             do {
625               group_name = AddPrefixToNodeName(
626                   node->name(),
627                   strings::StrCat("GroupCrossDeviceControlEdges_", num_noops));
628               noop = node_map_->GetNode(group_name);
629               ++num_noops;
630             } while (noop != nullptr);
631             noop = optimized_graph_->add_node();
632             noop->set_name(group_name);
633             noop->set_device(input->device());
634             noop->set_op("NoOp");
635             node_map_->AddNode(noop->name(), noop);
636             emplace_result.first->second = noop;
637           }
638         }
639       }
640     }
641 
642     // Reroute existing control edges to go via the newly introduced NoOp nodes.
643     int pos = 0;
644     while (pos < node->input_size()) {
645       const string& input_name = node->input(pos);
646       if (IsControlInput(input_name)) {
647         NodeDef* input = node_map_->GetNode(input_name);
648         if (input == nullptr) {
649           ++pos;
650         } else {
651           auto it = noops.find(input->device());
652           if (it == noops.end() || it->second == nullptr) {
653             ++pos;
654           } else {
655             node->mutable_input()->SwapElements(pos, node->input_size() - 1);
656             node->mutable_input()->RemoveLast();
657             it->second->add_input(AsControlDependency(*input));
658             node_map_->UpdateOutput(input_name, node->name(),
659                                     it->second->name());
660           }
661         }
662       } else {
663         ++pos;
664       }
665     }
666     for (const auto& entry : noops) {
667       if (entry.second) {
668         node->add_input(AsControlDependency(*entry.second));
669         node_map_->AddOutput(entry.second->name(), node->name());
670       }
671     }
672   }
673 }
674 
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)675 Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
676                                      GraphDef* optimized_graph) {
677   optimized_graph_ = optimized_graph;
678   *optimized_graph_ = item.graph;
679   nodes_to_preserve_ = item.NodesToPreserve();
680   fetch_nodes_known_ = !item.fetch.empty();
681   CleanControlInputs();
682 
683   const int num_iterations = 2;
684   for (int iteration = 0; iteration < num_iterations; ++iteration) {
685     GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
686     Status topo_sort_status;
687     // Perform topological sort to prepare the graph for transitive reduction.
688     topo_sort_status = TopologicalSort(optimized_graph_);
689     // Set up index-based graph datastructures to speed up analysis steps below.
690     node_map_.reset(new NodeMap(optimized_graph_));
691     BuildNodeToIdx();
692 
693     if (topo_sort_status.ok()) {
694       // Remove redundant control dependencies.
695       TF_RETURN_IF_ERROR(TransitiveReduction());
696     } else {
697       LOG(ERROR) << "Iteration = " << iteration
698                  << ", topological sort failed with message: "
699                  << topo_sort_status.error_message();
700     }
701     // Turn nodes with only control outputs into NoOps, prune NoOp and Identity
702     // nodes.
703     TF_RETURN_IF_ERROR(OptimizeDependencies());
704 
705     // Dedup control inputs.
706     CleanControlInputs();
707 
708     GroupCrossDeviceControlEdges();
709   }
710 
711   return Status::OK();
712 }
713 
Feedback(Cluster *,const GrapplerItem &,const GraphDef &,double)714 void DependencyOptimizer::Feedback(Cluster* /*cluster*/,
715                                    const GrapplerItem& /*item*/,
716                                    const GraphDef& /*optimized_graph*/,
717                                    double /*result*/) {
718   // Nothing to do for DependencyOptimizer.
719 }
720 
721 }  // end namespace grappler
722 }  // end namespace tensorflow
723