• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
17 
18 #include <unordered_set>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/grappler/costs/graph_properties.h"
25 #include "tensorflow/core/grappler/grappler_item.h"
26 #include "tensorflow/core/grappler/op_types.h"
27 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
28 #include "tensorflow/core/grappler/utils.h"
29 #include "tensorflow/core/grappler/utils/topological_sort.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/stringpiece.h"
32 #include "tensorflow/core/lib/gtl/inlined_vector.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 #include "tensorflow/core/util/device_name_utils.h"
36 
37 namespace tensorflow {
38 namespace grappler {
39 
40 namespace {
41 
RemoveControlInput(NodeDef * node,const string & control_input_to_remove,NodeMap * node_map)42 bool RemoveControlInput(NodeDef* node, const string& control_input_to_remove,
43                         NodeMap* node_map) {
44   for (int pos = node->input_size() - 1; pos >= 0; --pos) {
45     const string& input = node->input(pos);
46     if (input[0] != '^') break;
47     if (input == control_input_to_remove) {
48       node->mutable_input()->SwapElements(pos, node->input_size() - 1);
49       node->mutable_input()->RemoveLast();
50       node_map->RemoveOutput(NodeName(input), node->name());
51       return true;
52     }
53   }
54   return false;
55 }
56 
57 }  // namespace
58 
SafeToRemoveIdentity(const NodeDef & node) const59 bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
60   if (!IsIdentity(node) && !IsIdentityN(node)) {
61     return true;
62   }
63 
64   if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
65     return false;
66   }
67   if (!fetch_nodes_known_) {
68     // The output values of this node may be needed.
69     return false;
70   }
71 
72   if (node.input_size() < 1) {
73     // Node lacks input, is invalid
74     return false;
75   }
76 
77   const NodeDef* input = node_map_->GetNode(NodeName(node.input(0)));
78   CHECK(input != nullptr) << "node = " << node.name()
79                           << " input = " << node.input(0);
80   // Don't remove Identity nodes corresponding to Variable reads or following
81   // Recv.
82   if (IsVariable(*input) || IsRecv(*input)) {
83     return false;
84   }
85   for (const auto& consumer : node_map_->GetOutputs(node.name())) {
86     if (node.input_size() > 1 && (IsRetval(*consumer) || IsMerge(*consumer))) {
87       return false;
88     }
89     if (IsSwitch(*input)) {
90       for (const string& consumer_input : consumer->input()) {
91         if (consumer_input == AsControlDependency(node.name())) {
92           return false;
93         }
94       }
95     }
96   }
97   return true;
98 }
99 
SafeToConvertToNoOp(const NodeDef & node) const100 bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
101   if (HasRegularOutputs(node, *node_map_)) {
102     // The output values of this node may be needed.
103     VLOG(3) << "Not safe to convert '" << node.name()
104             << " to NoOp. Node has outputs.";
105     return false;
106   }
107   if (!fetch_nodes_known_) {
108     VLOG(3) << "Not safe to convert '" << node.name()
109             << " to NoOp. Fetches unknown.";
110     return false;
111   }
112   if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
113     VLOG(3) << "Not safe to convert to NoOp: " << node.name()
114             << " is in preserve set.";
115     return false;
116   }
117   if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node)) {
118     VLOG(3) << "Not safe to convert '" << node.name()
119             << " to NoOp. Node modifies frame info.";
120     return false;
121   }
122   // Ops reading variables are marked as stateful, but are safe to remove if
123   // redundant.
124   static const absl::flat_hash_set<string>* gather_ops =
125       new absl::flat_hash_set<string>{"Gather", "GatherV2", "GatherNd",
126                                       "ResourceGather", "ResourceGatherNd"};
127   const bool is_variable_read =
128       IsReadVariableOp(node) || IsReadVariablesOp(node) ||
129       gather_ops->find(node.op()) != gather_ops->end();
130   if (!is_variable_read && !IsFreeOfSideEffect(node)) {
131     VLOG(3) << "Not safe to convert '" << node.name()
132             << " to NoOp. Node has side effect.";
133     return false;
134   }
135   if (node.op().rfind("Submodel", 0) == 0) {
136     return false;
137   }
138   const OpDef* op_def = nullptr;
139   Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
140   if (!status.ok() || op_def->output_arg_size() == 0) {
141     return false;
142   }
143   const std::unordered_set<string> do_not_rewrite_ops{
144       "Assert",     "CheckNumerics",         "_Retval",
145       "_Arg",       "_ParallelConcatUpdate", "TPUExecute",
146       "TPUCompile", "ControlTrigger"};
147   if (do_not_rewrite_ops.find(node.op()) != do_not_rewrite_ops.end()) {
148     return false;
149   }
150   if (!SafeToRemoveIdentity(node)) {
151     return false;
152   }
153   return true;
154 }
155 
NumEdgesIfBypassed(const NodeDef & node,const std::vector<NodeDef * > & output_nodes) const156 int DependencyOptimizer::NumEdgesIfBypassed(
157     const NodeDef& node, const std::vector<NodeDef*>& output_nodes) const {
158   const bool is_multi_input_identity_n =
159       IsIdentityN(node) && !IsIdentityNSingleInput(node);
160   const int num_outputs = output_nodes.size();
161   const int num_inputs = node.input_size();
162 
163   if (is_multi_input_identity_n) {
164     // multi-input identity_n with input/output control dependencies will likely
165     // increase number of edges after optimization.
166     int num_edges_if_bypassed(0);
167     for (const string& input_node_name : node.input()) {
168       if (IsControlInput(input_node_name)) {
169         num_edges_if_bypassed += num_outputs;
170       } else {
171         ++num_edges_if_bypassed;
172       }
173     }
174 
175     for (auto consumer : output_nodes) {
176       for (int j = 0; j < consumer->input_size(); ++j) {
177         const TensorId consumer_input = ParseTensorName(consumer->input(j));
178         if (consumer_input.node() == node.name()) {
179           if (IsControlInput(consumer_input)) {
180             num_edges_if_bypassed += num_inputs;
181           } else {
182             ++num_edges_if_bypassed;
183           }
184         }
185       }
186     }
187     return num_edges_if_bypassed;
188   } else {
189     return num_inputs * num_outputs;
190   }
191 }
192 
BypassingNodeIsBeneficial(const NodeDef & node,const std::vector<NodeDef * > & input_nodes,const std::vector<NodeDef * > & output_nodes) const193 bool DependencyOptimizer::BypassingNodeIsBeneficial(
194     const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
195     const std::vector<NodeDef*>& output_nodes) const {
196   const bool is_identity = IsIdentity(node) || IsIdentityNSingleInput(node);
197   const bool is_multi_input_identity_n =
198       IsIdentityN(node) && !IsIdentityNSingleInput(node);
199   const int num_outputs = output_nodes.size();
200   const int num_inputs = node.input_size();
201 
202   if (NumEdgesIfBypassed(node, output_nodes) > num_inputs + num_outputs) {
203     return false;
204   }
205 
206   // Make sure that we don't increase the number of edges that cross
207   // device boundaries.
208   if ((num_inputs == 1 && num_outputs > 1 &&
209        input_nodes[0]->device() != node.device()) ||
210       (num_inputs > 1 && num_outputs == 1 &&
211        output_nodes[0]->device() != node.device())) {
212     return false;
213   }
214 
215   // TODO(rmlarsen): Not all device crossings are equally expensive.
216   // Assign a cost to each based on device affinity and compute a
217   // cost before and after.
218   const string& node_dev = node.device();
219   int num_cross_in = 0;
220   for (NodeDef* input_node : input_nodes) {
221     num_cross_in += static_cast<int>(input_node->device() != node_dev);
222   }
223   int num_cross_out = 0;
224   for (NodeDef* output_node : output_nodes) {
225     num_cross_out += static_cast<int>(output_node->device() != node_dev);
226   }
227 
228   // Make sure we do not increase the number of device crossings.
229   const int num_cross_before = num_cross_in + num_cross_out;
230   int num_cross_after = 0;
231   for (NodeDef* input_node : input_nodes) {
232     for (NodeDef* output_node : output_nodes) {
233       num_cross_after +=
234           static_cast<int>(input_node->device() != output_node->device());
235     }
236   }
237   if (num_cross_after > num_cross_before) {
238     return false;
239   }
240 
241   if ((is_identity || is_multi_input_identity_n) && num_cross_in > 0 &&
242       num_cross_out > 0 && num_cross_after > 0) {
243     // This identity node follows a device crossing, so it might be
244     // following a _Recv node after partitioning. Do not remove such nodes,
245     // unless they only have consumers on the same device as themselves.
246     return false;
247   }
248 
249   return true;
250 }
251 
OptimizeNode(int node_idx,SetVector<int> * nodes_to_simplify,std::set<int> * nodes_to_delete)252 void DependencyOptimizer::OptimizeNode(int node_idx,
253                                        SetVector<int>* nodes_to_simplify,
254                                        std::set<int>* nodes_to_delete) {
255   NodeDef* node = optimized_graph_->mutable_node(node_idx);
256   const bool is_noop = IsNoOp(*node);
257   const bool is_identity = IsIdentity(*node) || IsIdentityNSingleInput(*node);
258   const bool is_multi_input_identity =
259       IsIdentityN(*node) && !IsIdentityNSingleInput(*node);
260   const string node_name = node->name();
261   // Constant nodes with no input control dependency are always executed early,
262   // so we can prune all their output control dependencies.
263   if (IsConstant(*node) && node->input_size() == 0) {
264     const auto output_nodes = node_map_->GetOutputs(node_name);
265     for (NodeDef* fanout : output_nodes) {
266       bool optimize_fanout = false;
267       bool data_connection = false;
268       for (int i = fanout->input_size() - 1; i >= 0; --i) {
269         const TensorId input_tensor = ParseTensorName(fanout->input(i));
270         if (input_tensor.node() == node_name) {
271           if (input_tensor.index() < 0) {
272             fanout->mutable_input()->SwapElements(i, fanout->input_size() - 1);
273             fanout->mutable_input()->RemoveLast();
274             optimize_fanout = true;
275           } else {
276             data_connection = true;
277           }
278         }
279       }
280       if (optimize_fanout) {
281         nodes_to_simplify->PushBack(node_to_idx_[fanout]);
282         if (!data_connection) {
283           node_map_->RemoveOutput(node_name, fanout->name());
284         }
285       }
286     }
287     if (node_map_->GetOutputs(node_name).empty() && fetch_nodes_known_ &&
288         nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) {
289       // Mark the node for deletion.
290       nodes_to_delete->insert(node_to_idx_[node]);
291     }
292     return;
293   }
294 
295   // Change ops that only have control dependencies as outputs to NoOps.
296   if (!is_noop && SafeToConvertToNoOp(*node)) {
297     VLOG(2) << "***** Replacing  " << node_name << " (" << node->op()
298             << ") with NoOp.";
299     // The outputs of this node are not consumed. Replace its inputs with
300     // control dependencies and replace the op itself with the NoOp op.
301     std::unordered_set<string> ctrl_inputs;
302     int pos = 0;
303     while (pos < node->input_size()) {
304       const string old_input = node->input(pos);
305       if (IsControlInput(old_input)) {
306         if (!ctrl_inputs.insert(old_input).second) {
307           // We found a duplicate control input. Remove it.
308           node->mutable_input()->SwapElements(pos, node->input_size() - 1);
309           node->mutable_input()->RemoveLast();
310         } else {
311           ++pos;
312         }
313         continue;
314       }
315       // Replace a normal input with a control input.
316       const string ctrl_input = ConstantFolding::AddControlDependency(
317           old_input, optimized_graph_, node_map_.get());
318       ctrl_inputs.insert(ctrl_input);
319       node->set_input(pos, ctrl_input);
320       node_map_->UpdateInput(node_name, old_input, ctrl_input);
321       const NodeDef* old_input_node = node_map_->GetNode(old_input);
322       nodes_to_simplify->PushBack(node_to_idx_[old_input_node]);
323       ++pos;
324     }
325     node->set_op("NoOp");
326     EraseRegularNodeAttributes(node);
327     DedupControlInputs(node);
328     nodes_to_simplify->PushBack(node_to_idx_[node]);
329     return;
330   }
331 
332   // Remove NoOp nodes if the product of their fan-in and fan-out is less than
333   // or equal to the sum of the fan-in and fan-out. The non-trivial rewrites
334   // take the following form:
335   //
336   // Case a)
337   //    x --^> +------+                x --^> +---+
338   //    y --^> | NoOp | --^> a   ==>   y --^> | a |
339   //    ...    |      |                  ...  |   |
340   //    z --^> +------+                z --^> +---+
341   //
342   // Case b)
343   //           +------+ --^> a         +---+ --^> a
344   //    x --^> | NoOp | --^> b  ==>    | x | --^> b
345   //           |      | ...            |   | ...
346   //           +------+ --^> c         +---+ --^> c
347   // Case c)
348   //           +------+                x ---^> a
349   //    x --^> | NoOp | --^> a  ==>      \/
350   //    y --^> |      | --^> b           /\
351   //           +------+                y ---^> b
352   //
353   // We only apply this optimization if we don't increase the number of control
354   // edges across device boundaries, e.g. in cases a) and b) if NoOp and
355   // a and x, respectively, are on the same device. Control edges across device
356   // boundaries require inter-device communication (Send/Recv pairs to be
357   // inserted in the graph), which is very costly.
358   //
359   // We also remove identity nodes, subject to the same constraints on number of
360   // resulting control edges and device boundary crossings:
361   //
362   // Case a)
363   //          +----------+ ---> a       +---+ ---> a
364   //    x --> | Identity | --^> b  ==>  | x | --^> b
365   //          |          | ...          |   | ...
366   //          +----------+ --^> c       +---+ --^> c
367   //
368   // Case b)
369   //    x ---> +----------+ ---> a      x ---> +---+
370   //    y --^> | Identity |        ==>  y --^> | a |
371   //    ...    |          |               ...  |   |
372   //    z --^> +----------+             z --^> +---+
373   //
374   // Case c)
375   //           +----------+             x ---> +---+
376   //    x ---> | Identity | ---> a ==>   \--^> | a |
377   //    y --^> |          | --^> b       /\    +---+
378   //           +----------+             y --^> b
379 
380   if (is_noop || ((is_identity || is_multi_input_identity) &&
381                   SafeToRemoveIdentity(*node))) {
382     const int num_inputs = node->input_size();
383     std::vector<NodeDef*> input_nodes;
384     for (int i = 0; i < num_inputs; ++i) {
385       NodeDef* input_node = node_map_->GetNode(node->input(i));
386       if (input_node == nullptr) {
387         LOG(ERROR) << "Invalid input " << node->input(i);
388         return;
389       }
390       input_nodes.push_back(input_node);
391     }
392     const auto& output_node_set = node_map_->GetOutputs(node_name);
393     const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
394                                              output_node_set.end());
395 
396     if (!BypassingNodeIsBeneficial(*node, input_nodes, output_nodes)) {
397       return;
398     }
399 
400     VLOG(2) << "***** Rerouting input around\n" << node->DebugString();
401     // Now remove the node and re-wire its inputs to its outputs.
402     for (auto consumer : output_nodes) {
403       bool updated_consumer = false;
404       VLOG(2) << "consumer before:\n" << consumer->DebugString();
405       // Remove dependency on node from consumer.
406       for (int i = 0; i < num_inputs; ++i) {
407         const NodeDef* input = input_nodes[i];
408         // Forward dependency from input to consumer if it doesn't already
409         // depend on it.
410         if ((is_identity && i == 0) ||
411             (is_multi_input_identity && !IsControlInput(node->input(i)))) {
412           // Replace regular input from Identity node.
413           string new_input;
414           const string& input_to_forward = node->input(i);
415           CHECK(!IsControlInput(input_to_forward));
416           for (int j = 0; j < consumer->input_size(); ++j) {
417             const TensorId old_input = ParseTensorName(consumer->input(j));
418             if (old_input.node() == node_name) {
419               if (old_input.index() == i) {
420                 // Regular input
421                 new_input = input_to_forward;
422                 node_map_->UpdateInput(consumer->name(),
423                                        string(old_input.node()), new_input);
424                 consumer->set_input(j, new_input);
425               } else if (old_input.index() == -1) {
426                 // Control dependency
427                 new_input = AsControlDependency(NodeName(input_to_forward));
428                 node_map_->UpdateInput(consumer->name(),
429                                        string(old_input.node()), new_input);
430                 consumer->set_input(j, new_input);
431               }
432             }
433           }
434           updated_consumer = true;
435         } else {
436           // Forward dependency from input to consumer if it doesn't already
437           // depend on it.
438           if (node_map_->GetOutputs(input->name()).count(consumer) == 0) {
439             consumer->add_input(AsControlDependency(input->name()));
440             node_map_->AddOutput(input->name(), consumer->name());
441             nodes_to_simplify->PushBack(node_to_idx_[input]);
442             updated_consumer = true;
443           }
444         }
445       }
446       updated_consumer |= RemoveControlInput(
447           consumer, AsControlDependency(node_name), node_map_.get());
448       if (updated_consumer) {
449         nodes_to_simplify->PushBack(node_to_idx_[consumer]);
450       }
451       VLOG(2) << "consumer after:\n" << consumer->DebugString();
452     }
453     node_map_->RemoveOutputs(node_name);
454     if (fetch_nodes_known_ &&
455         nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) {
456       // Mark the node for deletion.
457       nodes_to_delete->insert(node_idx);
458 
459       // Disconnect the node from its inputs to enable further optimizations.
460       node_map_->RemoveInputs(node_name);
461       node->clear_input();
462     }
463   }
464 }
465 
CleanControlInputs()466 void DependencyOptimizer::CleanControlInputs() {
467   for (int i = 0; i < optimized_graph_->node_size(); ++i) {
468     DedupControlInputs(optimized_graph_->mutable_node(i));
469   }
470 }
471 
OptimizeDependencies()472 Status DependencyOptimizer::OptimizeDependencies() {
473   SetVector<int> nodes_to_simplify;
474   std::set<int> nodes_to_delete;
475   for (int i = 0; i < optimized_graph_->node_size(); ++i) {
476     const NodeDef& node = optimized_graph_->node(i);
477     if (IsNoOp(node) || IsIdentity(node) || IsIdentityN(node) ||
478         IsConstant(node) || SafeToConvertToNoOp(node)) {
479       nodes_to_simplify.PushBack(i);
480     }
481   }
482   while (!nodes_to_simplify.Empty()) {
483     int node_to_simplify = nodes_to_simplify.PopBack();
484     // Discard nodes that were marked for deletion already.
485     while (nodes_to_delete.find(node_to_simplify) != nodes_to_delete.end()) {
486       node_to_simplify = nodes_to_simplify.PopBack();
487     }
488     OptimizeNode(node_to_simplify, &nodes_to_simplify, &nodes_to_delete);
489   }
490 
491   if (fetch_nodes_known_) {
492     VLOG(1) << "Deleted " << nodes_to_delete.size() << " out of "
493             << optimized_graph_->node_size() << " nodes.";
494     EraseNodesFromGraph(nodes_to_delete, optimized_graph_);
495     node_map_.reset(new NodeMap(optimized_graph_));
496     BuildNodeToIdx();
497   }
498   return Status::OK();
499 }
500 
501 namespace {
502 
// Saturating lower bound on the longest path length from a source node.
enum DistanceFromSource : uint8 { ZERO = 0, ONE = 1, TWO_OR_GREATER = 2 };

// Walks the fanout of `source` depth-first and records, for every reached node
// whose index lies in the inclusive `target_range`, a lower bound on the
// longest path from `source` to it.
//
// `outputs[i]` lists the fanout indices of node i. `longest_distance` is
// updated in place: the first visit of a node sets ZERO -> ONE, any later
// visit sets it to TWO_OR_GREATER. Nodes outside the range, and nodes already
// at TWO_OR_GREATER, are not expanded further. Because a node reached only
// once stays at ONE regardless of the path length, the result is a lower
// bound, not the exact distance.
void LongestPathsLowerBounds(
    int source, const std::pair<int, int>& target_range,
    const std::vector<std::vector<int>>& outputs,
    std::vector<DistanceFromSource>* longest_distance) {
  const int range_first = target_range.first;
  const int range_last = target_range.second;

  // LIFO worklist of nodes whose fanout still has to be expanded.
  std::vector<int> pending;
  pending.push_back(source);
  while (!pending.empty()) {
    const int current = pending.back();
    pending.pop_back();
    for (const int successor : outputs[current]) {
      // 1) Only nodes in the target range can be on paths from source to one
      //    of its control outputs.
      if (successor < range_first || successor > range_last) {
        continue;
      }
      // 2) Since only a lower bound is needed, nodes already proven to have a
      //    path of length > 1 from the source are skipped.
      DistanceFromSource& bound = (*longest_distance)[successor];
      if (bound == TWO_OR_GREATER) {
        continue;
      }
      bound = (bound == ZERO) ? ONE : TWO_OR_GREATER;
      pending.push_back(successor);
    }
  }
}
529 
530 }  // namespace
531 
TransitiveReduction()532 Status DependencyOptimizer::TransitiveReduction() {
533   // PRECONDITION: optimized_graph_ must be sorted topologically.
534   const int num_nodes = optimized_graph_->node_size();
535   // Set up a compressed version of the graph to save a constant factor in the
536   // expensive algorithm below. Also cache the set of control outputs and the
537   // highest index of a target of any control output from each node.
538   int num_controls = 0;
539   std::vector<std::vector<int>> outputs(num_nodes);
540   std::vector<gtl::InlinedVector<std::pair<int, int>, 2>> control_outputs(
541       num_nodes);
542   // target_range[i] contains the range of node indices for which to compute
543   // longest paths starting from node i.
544   std::vector<std::pair<int, int>> target_range(num_nodes, {num_nodes, -1});
545   for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
546     const NodeDef& node = optimized_graph_->node(node_idx);
547     if (ModifiesFrameInfo(node) || !HasOpDef(node)) {
548       // Ignore function nodes and nodes that modify frame info.
549       continue;
550     }
551     for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) {
552       const string& input = node.input(input_slot);
553       const NodeDef* input_node = node_map_->GetNode(input);
554       if (ModifiesFrameInfo(*input_node) || IsMerge(*input_node)) {
555         // Ignore edges from nodes that modify frame info and from Merge nodes,
556         // because we cannot know which of it's input paths executes.
557         continue;
558       }
559       const int input_node_idx = node_to_idx_[input_node];
560       outputs[input_node_idx].push_back(node_idx);
561       target_range[input_node_idx].first =
562           std::min(target_range[input_node_idx].first, node_idx);
563       if (IsControlInput(input)) {
564         ++num_controls;
565         control_outputs[input_node_idx].emplace_back(node_idx, input_slot);
566         target_range[input_node_idx].second =
567             std::max(target_range[input_node_idx].second, node_idx);
568       }
569     }
570   }
571 
572   // Run the longest path in DAG algorithm for each source node that has control
573   // outputs. If, for any target node of a control output, there exists a path
574   // of length > 1, we can drop that control dependency.
575   int num_controls_removed = 0;
576   std::vector<DistanceFromSource> longest_distance(num_nodes);
577   // Map from target_index -> set of (input_slot, source_index), representing
578   // the control edges to remove. We sort them in reverse order by input slot,
579   // such that when we swap them out so we don't clobber the
580   // node(target).input() repeated field.
581   typedef std::pair<int, int> InputSlotAndSource;
582   absl::flat_hash_map<
583       int, std::set<InputSlotAndSource, std::greater<InputSlotAndSource>>>
584       control_edges_to_remove;
585   for (int source = 0; source < num_nodes; ++source) {
586     if (target_range[source].first >= target_range[source].second ||
587         target_range[source].second <= source) {
588       continue;
589     }
590     // Compute the set of nodes in the transitive fanout of source with
591     // topological sort index in [target_range.first : target_range.second]]
592     // to which there exists a path of length 2 or more from source.
593     std::fill(longest_distance.begin() + target_range[source].first,
594               longest_distance.begin() + target_range[source].second + 1, ZERO);
595     LongestPathsLowerBounds(source, target_range[source], outputs,
596                             &longest_distance);
597 
598     // If the longest path from source to target of a control dependency is
599     // longer than 1, there exists an alternate path, and we can eliminate the
600     // redundant direct control dependency.
601     for (const auto& control_output : control_outputs[source]) {
602       const int target = control_output.first;
603       if (longest_distance[target] == TWO_OR_GREATER) {
604         const int input_slot = control_output.second;
605         control_edges_to_remove[target].emplace(input_slot, source);
606       }
607     }
608   }
609   for (const auto& it : control_edges_to_remove) {
610     const int target = it.first;
611     NodeDef* target_node = optimized_graph_->mutable_node(target);
612     for (const InputSlotAndSource& slot_and_source : it.second) {
613       const int input_slot = slot_and_source.first;
614       const int source = slot_and_source.second;
615       const NodeDef& source_node = optimized_graph_->node(source);
616       CHECK_LT(input_slot, target_node->input_size());
617       target_node->mutable_input()->SwapElements(input_slot,
618                                                  target_node->input_size() - 1);
619       node_map_->RemoveOutput(source_node.name(), target_node->name());
620       target_node->mutable_input()->RemoveLast();
621       ++num_controls_removed;
622     }
623   }
624   VLOG(1) << "Removed " << num_controls_removed << " out of " << num_controls
625           << " control dependencies";
626   return Status::OK();
627 }
628 
BuildNodeToIdx()629 void DependencyOptimizer::BuildNodeToIdx() {
630   // Set up &node -> index map.
631   node_to_idx_.clear();
632   for (int i = 0; i < optimized_graph_->node_size(); ++i) {
633     const NodeDef& node = optimized_graph_->node(i);
634     node_to_idx_[&node] = i;
635   }
636 }
637 
638 // Suppose there are cross-device control inputs to node C from multiple nodes
639 // that are located on another device, e.g., we have control edges:
640 // A->C, B->C
641 // where A and B are on device X and C is on device Y.
642 // We can reduce cross-device communication by introducing an intermediate
643 // NoOp node C' on device X and rewriting the control edges to:
644 // A->C', B->C', C' -> C
GroupCrossDeviceControlEdges(bool host_granularity)645 void DependencyOptimizer::GroupCrossDeviceControlEdges(bool host_granularity) {
646   VLOG(1)
647       << "DependencyOptimizer::GroupCrossDeviceControlEdges host_granularity="
648       << host_granularity;
649   const int num_nodes = optimized_graph_->node_size();
650   for (int i = 0; i < num_nodes; ++i) {
651     NodeDef* node = optimized_graph_->mutable_node(i);
652     if (node->device().empty()) continue;
653     string rest, node_device = node->device();
654     if (host_granularity) {
655       DeviceNameUtils::SplitDeviceName(node->device(), &node_device, &rest);
656     }
657 
658     // Creates new noop nodes for devices on which multiple control inputs are
659     // located.
660 
661     // Map keyed by device name to the newly introduced Noop node for that
662     // device. A nullptr value means that we have only seen a single node on
663     // that device.
664     std::map<string, NodeDef*> noops;
665     int num_noops = 0;
666     for (int j = 0; j < node->input_size(); ++j) {
667       if (IsControlInput(node->input(j))) {
668         const NodeDef* input = node_map_->GetNode(node->input(j));
669         if (input == nullptr || input->device().empty()) continue;
670         string input_device = input->device();
671         if (host_granularity) {
672           DeviceNameUtils::SplitDeviceName(input->device(), &input_device,
673                                            &rest);
674         }
675         if (input_device != node_device) {
676           VLOG(2) << "Cross-device " << node->name() << " " << input->device()
677                   << " -> " << node->device();
678           auto emplace_result = noops.emplace(input_device, nullptr);
679           if (!emplace_result.second &&
680               emplace_result.first->second == nullptr) {
681             VLOG(2) << "Duplicate input device from " << node->name();
682             // This is the second cross-device control input from the same
683             // device. Creates an intermediate noop node on that device.
684             string group_name;
685             NodeDef* noop;
686             // Creates a fresh node name; there may be conflicting names from
687             // a previous iteration of the optimizer.
688             do {
689               group_name = AddPrefixToNodeName(
690                   node->name(),
691                   strings::StrCat("GroupCrossDeviceControlEdges_", num_noops));
692               noop = node_map_->GetNode(group_name);
693               ++num_noops;
694             } while (noop != nullptr);
695             noop = optimized_graph_->add_node();
696             noop->set_name(group_name);
697             noop->set_device(input->device());
698             noop->set_op("NoOp");
699             node_map_->AddNode(noop->name(), noop);
700             emplace_result.first->second = noop;
701             VLOG(1) << "GroupCrossDeviceControlEdges: Added "
702                     << SummarizeNodeDef(*noop);
703           }
704         }
705       }
706     }
707 
708     // Reroute existing control edges to go via the newly introduced NoOp nodes.
709     int pos = 0;
710     while (pos < node->input_size()) {
711       const string& input_name = node->input(pos);
712       if (IsControlInput(input_name)) {
713         NodeDef* input = node_map_->GetNode(input_name);
714         if (input == nullptr) {
715           ++pos;
716         } else {
717           string input_device = input->device();
718           if (host_granularity) {
719             DeviceNameUtils::SplitDeviceName(input->device(), &input_device,
720                                              &rest);
721           }
722           auto it = noops.find(input_device);
723           if (it == noops.end() || it->second == nullptr) {
724             ++pos;
725           } else {
726             VLOG(2) << "Rewriting input from " << input_name;
727             node->mutable_input()->SwapElements(pos, node->input_size() - 1);
728             node->mutable_input()->RemoveLast();
729             it->second->add_input(AsControlDependency(*input));
730             node_map_->UpdateOutput(input_name, node->name(),
731                                     it->second->name());
732           }
733         }
734       } else {
735         ++pos;
736       }
737     }
738     for (const auto& entry : noops) {
739       if (entry.second) {
740         node->add_input(AsControlDependency(*entry.second));
741         node_map_->AddOutput(entry.second->name(), node->name());
742       }
743     }
744   }
745 }
746 
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)747 Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
748                                      GraphDef* optimized_graph) {
749   optimized_graph_ = optimized_graph;
750   *optimized_graph_ = item.graph;
751   nodes_to_preserve_ = item.NodesToPreserve();
752   fetch_nodes_known_ = !item.fetch.empty();
753   CleanControlInputs();
754 
755   const int num_iterations = 2;
756   for (int iteration = 0; iteration < num_iterations; ++iteration) {
757     GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
758     Status topo_sort_status;
759     // Perform topological sort to prepare the graph for transitive reduction.
760     topo_sort_status = TopologicalSort(optimized_graph_);
761     // Set up index-based graph datastructures to speed up analysis steps below.
762     node_map_.reset(new NodeMap(optimized_graph_));
763     BuildNodeToIdx();
764 
765     if (topo_sort_status.ok()) {
766       // Remove redundant control dependencies.
767       TF_RETURN_IF_ERROR(TransitiveReduction());
768     } else {
769       LOG(ERROR) << "Iteration = " << iteration
770                  << ", topological sort failed with message: "
771                  << topo_sort_status.error_message();
772     }
773     // Turn nodes with only control outputs into NoOps, prune NoOp and Identity
774     // nodes.
775     TF_RETURN_IF_ERROR(OptimizeDependencies());
776 
777     // Dedup control inputs.
778     CleanControlInputs();
779 
780     // Merge multiple control edges from the same device.
781     GroupCrossDeviceControlEdges(/*host_granularity=*/false);
782 
783     // Merge control edges from the same host to reduce RPC traffic.
784     GroupCrossDeviceControlEdges(/*host_granularity=*/true);
785   }
786 
787   return Status::OK();
788 }
789 
790 }  // end namespace grappler
791 }  // end namespace tensorflow
792