1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
17
18 #include <unordered_map>
19 #include <unordered_set>
20
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/op.h"
23 #include "tensorflow/core/grappler/costs/graph_properties.h"
24 #include "tensorflow/core/grappler/grappler_item.h"
25 #include "tensorflow/core/grappler/op_types.h"
26 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/grappler/utils/topological_sort.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/stringpiece.h"
31 #include "tensorflow/core/lib/gtl/inlined_vector.h"
32 #include "tensorflow/core/lib/strings/str_util.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/util/device_name_utils.h"
35
36 namespace tensorflow {
37 namespace grappler {
38
39 namespace {
40
RemoveInput(NodeDef * node,const string & input,NodeMap * node_map)41 bool RemoveInput(NodeDef* node, const string& input, NodeMap* node_map) {
42 bool removed_input = false;
43 int pos = 0;
44 while (pos < node->input_size()) {
45 if (node->input(pos) == input) {
46 node->mutable_input()->SwapElements(pos, node->input_size() - 1);
47 node->mutable_input()->RemoveLast();
48 node_map->RemoveOutput(NodeName(input), node->name());
49 removed_input = true;
50 } else {
51 ++pos;
52 }
53 }
54 return removed_input;
55 }
56
57 } // namespace
58
SafeToRemoveIdentity(const NodeDef & node) const59 bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
60 if (!IsIdentity(node) && !IsIdentityN(node)) {
61 return true;
62 }
63
64 if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
65 return false;
66 }
67 if (!fetch_nodes_known_) {
68 // The output values of this node may be needed.
69 return false;
70 }
71 const NodeDef* input = node_map_->GetNode(NodeName(node.input(0)));
72 CHECK(input != nullptr) << "node = " << node.name()
73 << " input = " << node.input(0);
74 // Don't remove Identity nodes corresponding to Variable reads or following
75 // Recv.
76 if (IsVariable(*input) || IsRecv(*input)) {
77 return false;
78 }
79 for (const auto& consumer : node_map_->GetOutputs(node.name())) {
80 if (node.input_size() > 1 && IsMerge(*consumer)) {
81 return false;
82 }
83 if (IsSwitch(*input)) {
84 for (const string& consumer_input : consumer->input()) {
85 if (consumer_input == AsControlDependency(node.name())) {
86 return false;
87 }
88 }
89 }
90 }
91 return true;
92 }
93
SafeToConvertToNoOp(const NodeDef & node) const94 bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
95 if (!fetch_nodes_known_ ||
96 nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
97 return false;
98 }
99 if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node) ||
100 !IsFreeOfSideEffect(node)) {
101 return false;
102 }
103 if (node.op().rfind("Submodel", 0) == 0) {
104 return false;
105 }
106 const OpDef* op_def = nullptr;
107 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
108 if (!status.ok() || op_def->output_arg_size() == 0) {
109 return false;
110 }
111 const std::unordered_set<string> do_not_rewrite_ops{
112 "Assert", "CheckNumerics", "_Retval",
113 "_Arg", "_ParallelConcatUpdate", "TPUExecute",
114 "TPUCompile", "ControlTrigger"};
115 if (do_not_rewrite_ops.find(node.op()) != do_not_rewrite_ops.end()) {
116 return false;
117 }
118 if (!SafeToRemoveIdentity(node)) {
119 return false;
120 }
121 if (NumNonControlOutputs(node, *node_map_) > 0) {
122 // The output values of this node may be needed.
123 return false;
124 }
125 return true;
126 }
127
NumEdgesIfBypassed(const NodeDef & node,const std::vector<NodeDef * > & output_nodes) const128 int DependencyOptimizer::NumEdgesIfBypassed(
129 const NodeDef& node, const std::vector<NodeDef*>& output_nodes) const {
130 const bool is_multi_input_identity_n =
131 IsIdentityN(node) && !IsIdentityNSingleInput(node);
132 const int num_outputs = output_nodes.size();
133 const int num_inputs = node.input_size();
134
135 if (is_multi_input_identity_n) {
136 // multi-input identity_n with input/output control dependencies will likely
137 // increase number of edges after optimization.
138 int num_edges_if_bypassed(0);
139 for (string input_node_name : node.input()) {
140 if (IsControlInput(input_node_name)) {
141 num_edges_if_bypassed += num_outputs;
142 } else {
143 ++num_edges_if_bypassed;
144 }
145 }
146
147 for (auto consumer : output_nodes) {
148 for (int j = 0; j < consumer->input_size(); ++j) {
149 const TensorId consumer_input = ParseTensorName(consumer->input(j));
150 if (consumer_input.node() == node.name()) {
151 if (IsControlInput(consumer_input)) {
152 num_edges_if_bypassed += num_inputs;
153 } else {
154 ++num_edges_if_bypassed;
155 }
156 }
157 }
158 }
159 return num_edges_if_bypassed;
160 } else {
161 return num_inputs * num_outputs;
162 }
163 }
164
BypassingNodeIsBeneficial(const NodeDef & node,const std::vector<NodeDef * > & input_nodes,const std::vector<NodeDef * > & output_nodes) const165 bool DependencyOptimizer::BypassingNodeIsBeneficial(
166 const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
167 const std::vector<NodeDef*>& output_nodes) const {
168 const bool is_identity = IsIdentity(node) || IsIdentityNSingleInput(node);
169 const bool is_multi_input_identity_n =
170 IsIdentityN(node) && !IsIdentityNSingleInput(node);
171 const int num_outputs = output_nodes.size();
172 const int num_inputs = node.input_size();
173
174 if (NumEdgesIfBypassed(node, output_nodes) > num_inputs + num_outputs) {
175 return false;
176 }
177
178 // Make sure that we don't increase the number of edges that cross
179 // device boundaries.
180 if ((num_inputs == 1 && num_outputs > 1 &&
181 input_nodes[0]->device() != node.device()) ||
182 (num_inputs > 1 && num_outputs == 1 &&
183 output_nodes[0]->device() != node.device())) {
184 return false;
185 }
186
187 // TODO(rmlarsen): Not all device crossings are equally expensive.
188 // Assign a cost to each based on device affinity and compute a
189 // cost before and after.
190 const string& node_dev = node.device();
191 int num_cross_in = 0;
192 for (NodeDef* input_node : input_nodes) {
193 num_cross_in += static_cast<int>(input_node->device() != node_dev);
194 }
195 int num_cross_out = 0;
196 for (NodeDef* output_node : output_nodes) {
197 num_cross_out += static_cast<int>(output_node->device() != node_dev);
198 }
199
200 // Make sure we do not increase the number of device crossings.
201 const int num_cross_before = num_cross_in + num_cross_out;
202 int num_cross_after = 0;
203 for (NodeDef* input_node : input_nodes) {
204 for (NodeDef* output_node : output_nodes) {
205 num_cross_after +=
206 static_cast<int>(input_node->device() != output_node->device());
207 }
208 }
209 if (num_cross_after > num_cross_before) {
210 return false;
211 }
212
213 if ((is_identity || is_multi_input_identity_n) && num_cross_in > 0 &&
214 num_cross_out > 0 && num_cross_after > 0) {
215 // This identity node follows a device crossing, so it might be
216 // following a _Recv node after partioning. Do not remove such nodes,
217 // unless they only have consumers on the same device as themselves.
218 return false;
219 }
220
221 return true;
222 }
223
// Tries to simplify the node at index `node_idx` of optimized_graph_. The
// following rewrites are tried in order and at most one of them is applied:
//  1. A Constant without inputs has its outgoing control edges pruned. Once it
//     has no outputs left it is marked for deletion, provided the fetch nodes
//     are known and it is not preserved.
//  2. A node accepted by SafeToConvertToNoOp() becomes a NoOp: its regular
//     inputs are turned into control inputs, duplicate control inputs are
//     dropped and its attributes are cleared.
//  3. A NoOp, or an Identity/IdentityN accepted by SafeToRemoveIdentity(), is
//     bypassed when BypassingNodeIsBeneficial() agrees. Its inputs are rewired
//     to its consumers; then, subject to the same fetch/preserve conditions,
//     it is marked for deletion and disconnected from its inputs.
// Nodes whose edges change are pushed onto `nodes_to_simplify` for another
// visit. Nodes to delete are recorded by index in `nodes_to_delete`; the
// graph itself is not compacted here.
//
// NOTE(review): node_to_idx_ is indexed with operator[]. If it is a std map
// type, a NodeDef missing from it (e.g. a node appended to the graph after
// BuildNodeToIdx() ran, if AddControlDependency can do that) would silently
// map to index 0 -- confirm.
void DependencyOptimizer::OptimizeNode(int node_idx,
                                       SetVector<int>* nodes_to_simplify,
                                       std::set<int>* nodes_to_delete) {
  NodeDef* node = optimized_graph_->mutable_node(node_idx);
  const bool is_noop = IsNoOp(*node);
  const bool is_identity = IsIdentity(*node) || IsIdentityNSingleInput(*node);
  const bool is_multi_input_identity =
      IsIdentityN(*node) && !IsIdentityNSingleInput(*node);
  const string node_name = node->name();
  // Constant nodes with no input control dependency are always executed early,
  // so we can prune all their output control dependencies.
  if (IsConstant(*node) && node->input_size() == 0) {
    // Copied on purpose: RemoveOutput() below edits this node's fanout while
    // we iterate.
    const std::set<NodeDef*> output_nodes = node_map_->GetOutputs(node_name);
    for (NodeDef* fanout : output_nodes) {
      bool optimize_fanout = false;
      bool data_connection = false;
      // Walk backwards so that swap-with-last removal only moves an input
      // that has already been examined into slot i.
      for (int i = fanout->input_size() - 1; i >= 0; --i) {
        const TensorId input_tensor = ParseTensorName(fanout->input(i));
        if (input_tensor.node() == node_name) {
          if (input_tensor.index() < 0) {
            fanout->mutable_input()->SwapElements(i, fanout->input_size() - 1);
            fanout->mutable_input()->RemoveLast();
            optimize_fanout = true;
          } else {
            data_connection = true;
          }
        }
      }
      if (optimize_fanout) {
        nodes_to_simplify->PushBack(node_to_idx_[fanout]);
        // Keep the NodeMap edge if the fanout still reads a data output.
        if (!data_connection) {
          node_map_->RemoveOutput(node_name, fanout->name());
        }
      }
    }
    if (node_map_->GetOutputs(node_name).empty() && fetch_nodes_known_ &&
        nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) {
      // Mark the node for deletion.
      nodes_to_delete->insert(node_to_idx_[node]);
    }
    return;
  }

  // Change ops that only have control dependencies as outputs to NoOps.
  if (!is_noop && SafeToConvertToNoOp(*node)) {
    VLOG(1) << "***** Replacing " << node_name << " (" << node->op()
            << ") with NoOp.";
    // The outputs of this node are not consumed. Replace its inputs with
    // control dependencies and replace the op itself with the NoOp op.
    std::unordered_set<string> ctrl_inputs;
    int pos = 0;
    while (pos < node->input_size()) {
      // Copied by value: set_input() below overwrites this slot.
      const string old_input = node->input(pos);
      if (IsControlInput(old_input)) {
        if (!ctrl_inputs.insert(old_input).second) {
          // We found a duplicate control input. Remove it.
          node->mutable_input()->SwapElements(pos, node->input_size() - 1);
          node->mutable_input()->RemoveLast();
        } else {
          ++pos;
        }
        continue;
      }
      // Replace a normal input with a control input.
      const string ctrl_input = ConstantFolding::AddControlDependency(
          old_input, optimized_graph_, node_map_.get());
      ctrl_inputs.insert(ctrl_input);
      node->set_input(pos, ctrl_input);
      node_map_->UpdateInput(node_name, old_input, ctrl_input);
      // The producer lost a data consumer; revisit it.
      const NodeDef* old_input_node = node_map_->GetNode(old_input);
      nodes_to_simplify->PushBack(node_to_idx_[old_input_node]);
      ++pos;
    }
    node->set_op("NoOp");
    node->clear_attr();
    // Revisit the new NoOp: it may now be removable by the bypass rewrite.
    nodes_to_simplify->PushBack(node_to_idx_[node]);
    return;
  }

  // Remove NoOp nodes if the product of their fan-in and fan-out is less than
  // or equal to the sum of the fan-in and fan-out. The non-trivial rewrites
  // take the following form:
  //
  // Case a)
  //    x --^> +------+                x --^> +---+
  //    y --^> | NoOp | --^> a   ==>   y --^> | a |
  //    ...    |      |                  ...  |   |
  //    z --^> +------+                z --^> +---+
  //
  // Case b)
  //           +------+ --^> a         +---+ --^> a
  //    x --^> | NoOp | --^> b  ==>    | x | --^> b
  //           |      | ...            |   | ...
  //           +------+ --^> c         +---+ --^> c
  // Case c)
  //           +------+                x ---^> a
  //    x --^> | NoOp | --^> a  ==>      \/
  //    y --^> |      | --^> b           /\      (edges cross)
  //           +------+                y ---^> b
  //
  // We only apply this optimization if we don't increase the number of control
  // edges across device boundaries, e.g. in cases a) and b) if NoOp and
  // a and x, respectively, are on the same device. Control edges across device
  // boundaries require inter-device communication (Send/Recv pairs to be
  // inserted in the graph), which is very costly.
  //
  // We also remove identity nodes, subject to the same constraints on number of
  // resulting control edges and device boundary crossings:
  //
  // Case a)
  //          +----------+ ---> a       +---+ ---> a
  //    x --> | Identity | --^> b  ==>  | x | --^> b
  //          |          | ...          |   | ...
  //          +----------+ --^> c       +---+ --^> c
  //
  // Case b)
  //    x ---> +----------+ ---> a      x ---> +---+
  //    y --^> | Identity |        ==>  y --^> | a |
  //    ...    |          |               ...  |   |
  //    z --^> +----------+             z --^> +---+
  //
  // Case c)
  //           +----------+             x ---> +---+
  //    x ---> | Identity | ---> a ==>   \--^> | a |
  //    y --^> |          | --^> b     /\   +---+
  //           +----------+           y --^> b

  if (is_noop || ((is_identity || is_multi_input_identity) &&
                  SafeToRemoveIdentity(*node))) {
    // Snapshot the consumers: the rewiring below edits the NodeMap.
    const auto& output_node_set = node_map_->GetOutputs(node_name);
    const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
                                             output_node_set.end());
    const int num_inputs = node->input_size();
    // input_nodes[i] is the producer of node->input(i).
    std::vector<NodeDef*> input_nodes;
    for (int i = 0; i < num_inputs; ++i) {
      NodeDef* input_node = node_map_->GetNode(node->input(i));
      if (input_node == nullptr) {
        LOG(ERROR) << "Invalid input " << node->input(i);
        return;
      }
      input_nodes.push_back(input_node);
    }

    if (!BypassingNodeIsBeneficial(*node, input_nodes, output_nodes)) {
      return;
    }

    VLOG(1) << "***** Rerouting input around\n" << node->DebugString();
    // Now remove the node and re-wire its inputs to its outputs.
    for (auto consumer : output_nodes) {
      bool updated_consumer = false;
      VLOG(1) << "consumer before:\n" << consumer->DebugString();
      for (int i = 0; i < num_inputs; ++i) {
        const NodeDef* input = input_nodes[i];
        // Regular inputs of an Identity/IdentityN are forwarded in place of
        // the matching output; every other input (control inputs, or any
        // input of a NoOp) becomes a control dependency of the consumer.
        if ((is_identity && i == 0) ||
            (is_multi_input_identity && !IsControlInput(node->input(i)))) {
          // Replace regular input from Identity node.
          string new_input;
          const string& input_to_forward = node->input(i);
          CHECK(!IsControlInput(input_to_forward));
          for (int j = 0; j < consumer->input_size(); ++j) {
            const TensorId old_input = ParseTensorName(consumer->input(j));
            if (old_input.node() == node_name) {
              // Output i of an Identity(N) carries input i.
              if (old_input.index() == i) {
                // Regular input
                new_input = input_to_forward;
                node_map_->UpdateInput(consumer->name(), old_input.ToString(),
                                       new_input);
                consumer->set_input(j, new_input);
              } else if (old_input.index() == -1) {
                // Control dependency
                new_input = AsControlDependency(NodeName(input_to_forward));
                node_map_->UpdateInput(consumer->name(), old_input.ToString(),
                                       new_input);
                consumer->set_input(j, new_input);
              }
            }
          }
          updated_consumer = true;
        } else {
          // Forward dependency from input to consumer if it doesn't already
          // depend on it.
          if (node_map_->GetOutputs(input->name()).count(consumer) == 0) {
            consumer->add_input(AsControlDependency(input->name()));
            node_map_->AddOutput(input->name(), consumer->name());
            nodes_to_simplify->PushBack(node_to_idx_[input]);
            updated_consumer = true;
          }
        }
      }
      // Remove dependency on node from consumer.
      updated_consumer |= RemoveInput(consumer, AsControlDependency(node_name),
                                      node_map_.get());
      if (updated_consumer) {
        nodes_to_simplify->PushBack(node_to_idx_[consumer]);
      }
      VLOG(1) << "consumer after:\n" << consumer->DebugString();
    }
    node_map_->RemoveOutputs(node_name);
    if (fetch_nodes_known_ &&
        nodes_to_preserve_.find(node_name) == nodes_to_preserve_.end()) {
      // Mark the node for deletion.
      nodes_to_delete->insert(node_idx);

      // Disconnect the node from its inputs to enable further optimizations.
      node_map_->RemoveInputs(node_name);
      node->clear_input();
    }
  }
}
436
CleanControlInputs()437 void DependencyOptimizer::CleanControlInputs() {
438 for (int i = 0; i < optimized_graph_->node_size(); ++i) {
439 DedupControlInputs(optimized_graph_->mutable_node(i));
440 }
441 }
442
OptimizeDependencies()443 Status DependencyOptimizer::OptimizeDependencies() {
444 SetVector<int> nodes_to_simplify;
445 std::set<int> nodes_to_delete;
446 for (int i = 0; i < optimized_graph_->node_size(); ++i) {
447 const NodeDef& node = optimized_graph_->node(i);
448 if (IsNoOp(node) || IsIdentity(node) || IsIdentityN(node) ||
449 IsConstant(node) || SafeToConvertToNoOp(node)) {
450 nodes_to_simplify.PushBack(i);
451 }
452 }
453 while (!nodes_to_simplify.Empty()) {
454 int node_to_simplify = nodes_to_simplify.PopBack();
455 // Discard nodes that were marked for deletion already.
456 while (nodes_to_delete.find(node_to_simplify) != nodes_to_delete.end()) {
457 node_to_simplify = nodes_to_simplify.PopBack();
458 }
459 OptimizeNode(node_to_simplify, &nodes_to_simplify, &nodes_to_delete);
460 }
461
462 if (fetch_nodes_known_) {
463 VLOG(1) << "Deleted " << nodes_to_delete.size() << " out of "
464 << optimized_graph_->node_size() << " nodes.";
465 EraseNodesFromGraph(nodes_to_delete, optimized_graph_);
466 node_map_.reset(new NodeMap(optimized_graph_));
467 BuildNodeToIdx();
468 }
469 return Status::OK();
470 }
471
// Removes control edges that are implied by another path in the graph.
// For every source node with outgoing control edges, it computes the longest
// path (in edges) from the source to each later node in topological order.
// A control edge source -> target is dropped when the longest path from
// source to target is longer than 1, i.e. another path already orders them.
// Function nodes (no OpDef), nodes that modify frame info, and edges leaving
// frame-modifying or Merge nodes are excluded from the analysis.
//
// PRECONDITION: optimized_graph_ is topologically sorted, and node_map_ and
// node_to_idx_ describe it.
// Worst case is O(num_nodes * num_edges): each source may scan every edge
// into the nodes between it and its farthest control target.
// Always returns Status::OK().
Status DependencyOptimizer::TransitiveReduction() {
  // PRECONDITION: optimized_graph_ must be sorted topologically.
  const int num_nodes = optimized_graph_->node_size();
  // Set up a compressed version of the graph to save a constant factor in the
  // expensive algorithm below: inputs[i] holds the indices of node i's
  // producers. Also cache, for each node, its control outputs as
  // (target index, input slot in target) pairs.
  int num_controls = 0;
  std::vector<gtl::InlinedVector<int, 4>> inputs(num_nodes);
  std::vector<gtl::InlinedVector<std::pair<int, int>, 2>> control_outputs(
      num_nodes);
  for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
    const NodeDef& node = optimized_graph_->node(node_idx);
    if (ModifiesFrameInfo(node) || !HasOpDef(node)) {
      // Ignore function nodes and nodes that modify frame info.
      continue;
    }
    for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) {
      const string& input = node.input(input_slot);
      // NOTE(review): dereferenced without a null check; presumably a
      // successful topological sort guarantees every input resolves -- confirm.
      const NodeDef* input_node = node_map_->GetNode(input);
      if (ModifiesFrameInfo(*input_node) || IsMerge(*input_node)) {
        // Ignore edges from nodes that modify frame info and from Merge nodes,
        // because we cannot know which of its input paths executes.
        continue;
      }
      const int input_node_idx = node_to_idx_[input_node];
      inputs[node_idx].push_back(input_node_idx);
      if (IsControlInput(input)) {
        ++num_controls;
        control_outputs[input_node_idx].emplace_back(node_idx, input_slot);
      }
    }
  }

  // Run the longest path in DAG algorithm for each source node that has control
  // outputs. If, for any target node of a control output, there exists a path
  // of length > 1, we can drop that control dependency.
  int num_controls_removed = 0;
  std::vector<int> longest_distance(num_nodes);
  // Map from target_index -> set of (input_slot, source_index), representing
  // the control edges to remove. Each set is ordered by input slot in
  // descending order, so that the swap-with-last removal of one slot never
  // moves another slot that is still scheduled for removal within
  // node(target).input().
  typedef std::pair<int, int> InputSlotAndSource;
  std::unordered_map<
      int, std::set<InputSlotAndSource, std::greater<InputSlotAndSource>>>
      control_edges_to_remove;
  for (int source = 0; source < num_nodes; ++source) {
    int highest_control_target = -1;
    for (const auto& control_output : control_outputs[source]) {
      if (control_output.first > highest_control_target) {
        highest_control_target = control_output.first;
      }
    }
    // No control outputs, or none pointing forward in topological order.
    if (highest_control_target <= source) {
      continue;
    }
    // Only distances in [source, highest_control_target] are read below.
    std::fill(longest_distance.begin() + source,
              longest_distance.begin() + highest_control_target + 1, 0);
    for (int target = source + 1; target <= highest_control_target; ++target) {
      for (int input : inputs[target]) {
        // If the input node is before source in the topo order, no path
        // source -> input -> target can exist and we can skip it.
        // Also only extend a path from the source itself or from nodes that
        // have a path from source, indicated by longest_distance[input] > 0.
        if (input == source ||
            (input > source && longest_distance[input] > 0)) {
          // If source -> input -> target is longer than the longest
          // path so far from source -> target, update the longest_distance.
          int candidate_longest_distance = longest_distance[input] + 1;
          if (candidate_longest_distance > longest_distance[target]) {
            longest_distance[target] = candidate_longest_distance;
          }
        }
      }
    }

    // If the longest path from source to target of a control dependency is
    // longer than 1, there exists an alternate path, and we can eliminate the
    // redundant direct control dependency.
    for (const auto& control_output : control_outputs[source]) {
      const int target = control_output.first;
      if (longest_distance[target] > 1) {
        const int input_slot = control_output.second;
        control_edges_to_remove[target].emplace(input_slot, source);
      }
    }
  }

  for (const auto& it : control_edges_to_remove) {
    const int target = it.first;
    NodeDef* target_node = optimized_graph_->mutable_node(target);
    for (const InputSlotAndSource& slot_and_source : it.second) {
      const int input_slot = slot_and_source.first;
      const int source = slot_and_source.second;
      const NodeDef& source_node = optimized_graph_->node(source);
      CHECK_LT(input_slot, target_node->input_size());
      target_node->mutable_input()->SwapElements(input_slot,
                                                 target_node->input_size() - 1);
      node_map_->RemoveOutput(source_node.name(), target_node->name());
      target_node->mutable_input()->RemoveLast();
      ++num_controls_removed;
    }
  }
  VLOG(1) << "Removed " << num_controls_removed << " out of " << num_controls
          << " control dependencies";
  return Status::OK();
}
579
BuildNodeToIdx()580 void DependencyOptimizer::BuildNodeToIdx() {
581 // Set up &node -> index map.
582 node_to_idx_.clear();
583 for (int i = 0; i < optimized_graph_->node_size(); ++i) {
584 const NodeDef& node = optimized_graph_->node(i);
585 node_to_idx_[&node] = i;
586 }
587 }
588
589 // Suppose there are cross-device control inputs to node C from multiple nodes
590 // that are located on another device, e.g., we have control edges:
591 // A->C, B->C
592 // where A and B are on device X and C is on device Y.
593 // We can reduce cross-device communication by introducing an intermediate
594 // NoOp node C' on device X and rewriting the control edges to:
595 // A->C', B->C', C' -> C
GroupCrossDeviceControlEdges()596 void DependencyOptimizer::GroupCrossDeviceControlEdges() {
597 const int num_nodes = optimized_graph_->node_size();
598 for (int i = 0; i < num_nodes; ++i) {
599 NodeDef* node = optimized_graph_->mutable_node(i);
600 if (node->device().empty()) continue;
601
602 // Creates new noop nodes for devices on which multiple control inputs are
603 // located.
604
605 // Map keyed by device name to the newly introduced Noop node for that
606 // device. A nullptr value means that we have only seen a single node on
607 // that device.
608 std::map<string, NodeDef*> noops;
609 int num_noops = 0;
610 for (int j = 0; j < node->input_size(); ++j) {
611 if (IsControlInput(node->input(j))) {
612 const NodeDef* input = node_map_->GetNode(node->input(j));
613 if (input != nullptr && !input->device().empty() &&
614 input->device() != node->device()) {
615 auto emplace_result = noops.emplace(input->device(), nullptr);
616 if (!emplace_result.second &&
617 emplace_result.first->second == nullptr) {
618 // This is the second cross-device control input from the same
619 // device. Creates an intermediate noop node on that device.
620 string group_name;
621 NodeDef* noop;
622 // Creates a fresh node name; there may be conflicting names from
623 // a previous iteration of the optimizer.
624 do {
625 group_name = AddPrefixToNodeName(
626 node->name(),
627 strings::StrCat("GroupCrossDeviceControlEdges_", num_noops));
628 noop = node_map_->GetNode(group_name);
629 ++num_noops;
630 } while (noop != nullptr);
631 noop = optimized_graph_->add_node();
632 noop->set_name(group_name);
633 noop->set_device(input->device());
634 noop->set_op("NoOp");
635 node_map_->AddNode(noop->name(), noop);
636 emplace_result.first->second = noop;
637 }
638 }
639 }
640 }
641
642 // Reroute existing control edges to go via the newly introduced NoOp nodes.
643 int pos = 0;
644 while (pos < node->input_size()) {
645 const string& input_name = node->input(pos);
646 if (IsControlInput(input_name)) {
647 NodeDef* input = node_map_->GetNode(input_name);
648 if (input == nullptr) {
649 ++pos;
650 } else {
651 auto it = noops.find(input->device());
652 if (it == noops.end() || it->second == nullptr) {
653 ++pos;
654 } else {
655 node->mutable_input()->SwapElements(pos, node->input_size() - 1);
656 node->mutable_input()->RemoveLast();
657 it->second->add_input(AsControlDependency(*input));
658 node_map_->UpdateOutput(input_name, node->name(),
659 it->second->name());
660 }
661 }
662 } else {
663 ++pos;
664 }
665 }
666 for (const auto& entry : noops) {
667 if (entry.second) {
668 node->add_input(AsControlDependency(*entry.second));
669 node_map_->AddOutput(entry.second->name(), node->name());
670 }
671 }
672 }
673 }
674
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)675 Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
676 GraphDef* optimized_graph) {
677 optimized_graph_ = optimized_graph;
678 *optimized_graph_ = item.graph;
679 nodes_to_preserve_ = item.NodesToPreserve();
680 fetch_nodes_known_ = !item.fetch.empty();
681 CleanControlInputs();
682
683 const int num_iterations = 2;
684 for (int iteration = 0; iteration < num_iterations; ++iteration) {
685 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
686 Status topo_sort_status;
687 // Perform topological sort to prepare the graph for transitive reduction.
688 topo_sort_status = TopologicalSort(optimized_graph_);
689 // Set up index-based graph datastructures to speed up analysis steps below.
690 node_map_.reset(new NodeMap(optimized_graph_));
691 BuildNodeToIdx();
692
693 if (topo_sort_status.ok()) {
694 // Remove redundant control dependencies.
695 TF_RETURN_IF_ERROR(TransitiveReduction());
696 } else {
697 LOG(ERROR) << "Iteration = " << iteration
698 << ", topological sort failed with message: "
699 << topo_sort_status.error_message();
700 }
701 // Turn nodes with only control outputs into NoOps, prune NoOp and Identity
702 // nodes.
703 TF_RETURN_IF_ERROR(OptimizeDependencies());
704
705 // Dedup control inputs.
706 CleanControlInputs();
707
708 GroupCrossDeviceControlEdges();
709 }
710
711 return Status::OK();
712 }
713
// Intentionally empty: DependencyOptimizer ignores all feedback arguments.
void DependencyOptimizer::Feedback(Cluster* /*cluster*/,
                                   const GrapplerItem& /*item*/,
                                   const GraphDef& /*optimized_graph*/,
                                   double /*result*/) {
  // Nothing to do for DependencyOptimizer.
}
720
721 } // end namespace grappler
722 } // end namespace tensorflow
723