/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"

#include <algorithm>
#include <deque>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_topology_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/grappler/utils/traversal.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"

using tensorflow::strings::StrCat;

namespace tensorflow {
namespace grappler {
namespace {

using TensorVector = gtl::InlinedVector<TensorValue, 4>;

class LoopInvariantNodeMotionOptimizer {
 public:
  explicit LoopInvariantNodeMotionOptimizer(GraphDef* optimized_graph)
      : optimized_graph_(optimized_graph) {}
  virtual ~LoopInvariantNodeMotionOptimizer() = default;
  Status Optimize();

 private:
  Status FindInvariantNodes(NodeDef* node);
  Status RevertInvariantNodes();
  Status MoveInvariantNodes(const int frame_id);
  Status HandleInvariantNode(NodeDef* node, const int num_outputs,
                             const int frame_id);
  Status HandleConst(NodeDef* node, const int num_outputs, const int frame_id);
  Status HandleInvariantEnter(NodeDef* node, const int num_outputs);

  GraphDef* optimized_graph_;  // Not owned.
  std::unique_ptr<NodeMap> node_map_;
  std::map<NodeDef*, int> invariant_nodes_;
  std::set<int> empty_set_;
  // TODO(rmlarsen): Use a vector instead of a map, since frame ids are dense.
  std::map<int, std::set<int>> frame_children_;
  std::map<int, int> frame_parent_;
  std::map<int, const NodeDef*> loop_cond_;
  std::map<int, std::vector<NodeDef*>> invariant_enters_;
  int new_enter_id_;
};
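
// For orientation, a rough sketch of the transformation this class performs.
// Given a while loop whose body recomputes a value that depends only on
// loop-invariant inputs, e.g.
//
//   Enter(is_constant=true) --> Mul <-- Const     (inside the frame)
//
// the Mul produces the same value on every iteration, so it can be hoisted
// out of the frame and computed once. Enter nodes with is_constant=true are
// the seeds of the analysis: anything computable from them and from
// constants alone is loop invariant.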

Status LoopInvariantNodeMotionOptimizer::HandleInvariantEnter(
    NodeDef* node, const int num_outputs) {
  auto consumers = node_map_->GetOutputs(node->name());
  std::vector<string> enter_control_inputs;
  string enter_input;
  for (auto& input : node->input()) {
    if (IsControlInput(input)) {
      enter_control_inputs.push_back(input);
    } else {
      enter_input = input;
    }
  }
  for (auto* consumer : consumers) {
    if (invariant_nodes_.count(consumer)) {
      for (int i = 0; i < consumer->input_size(); ++i) {
        if (NodeName(consumer->input(i)) == node->name()) {
          consumer->set_input(i, enter_input);
          node_map_->AddOutput(NodeName(enter_input), consumer->name());
          node_map_->RemoveOutput(node->name(), consumer->name());
        }
      }
      for (auto& control_input : enter_control_inputs) {
        consumer->add_input(control_input);
        node_map_->AddOutput(NodeName(control_input), consumer->name());
      }
    }
  }
  return Status::OK();
}
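
// A sketch of the rewiring HandleInvariantEnter performs, assuming the shown
// consumer is invariant: the constant Enter is bypassed, each invariant
// consumer is rewired to the Enter's data input, and the Enter's control
// inputs are forwarded to the consumers.
//
//   before:  x --> Enter(is_constant=true) --> invariant_op
//   after:   x ------------------------------> invariant_op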

Status LoopInvariantNodeMotionOptimizer::HandleConst(NodeDef* node,
                                                     const int num_outputs,
                                                     const int frame_id) {
  NodeDef* const_node = nullptr;
  if (num_outputs == 0) {
    // All consumers are invariant, so the constant itself can be moved into
    // the parent frame. Drop its control inputs from this frame first.
    const_node = node;
    node_map_->RemoveInputs(node->name());
    node->clear_input();
  } else {
    // Some consumers are variant, so the constant has to stay in this frame.
    // Create a copy outside the frame (in the parent frame) for the invariant
    // consumers.
    const string const_node_name =
        AddPrefixToNodeName(node->name(), kLoopOptimizer);
    const_node = node_map_->GetNode(const_node_name);
    if (const_node == nullptr) {
      const_node = optimized_graph_->add_node();
      const_node->set_name(const_node_name);
      const_node->set_op("Const");
      const_node->set_device(node->device());
      *const_node->mutable_attr() = node->attr();
      node_map_->AddNode(const_node->name(), const_node);
    }
    auto consumers = node_map_->GetOutputs(node->name());
    for (auto* consumer : consumers) {
      if (invariant_nodes_.count(consumer)) {
        for (int i = 0; i < consumer->input_size(); ++i) {
          if (NodeName(consumer->input(i)) == node->name()) {
            if (IsControlInput(consumer->input(i))) {
              *consumer->mutable_input(i) = AsControlDependency(*const_node);
            } else {
              *consumer->mutable_input(i) = const_node->name();
            }
            node_map_->AddOutput(const_node->name(), consumer->name());
            node_map_->RemoveOutput(node->name(), consumer->name());
          }
        }
      }
    }
  }
  // Add a control input from the parent frame's loop condition so the
  // constant only fires while the parent loop is live.
  auto parent_it = frame_parent_.find(frame_id);
  if (parent_it != frame_parent_.end()) {
    int parent_id = parent_it->second;
    auto loop_cond_it = loop_cond_.find(parent_id);
    if (loop_cond_it == loop_cond_.end()) {
      return errors::InvalidArgument("Frame ", parent_id,
                                     " doesn't have a LoopCond node");
    }
    auto& loop_cond_name = loop_cond_it->second->name();
    NodeDef* switch_node = nullptr;
    for (auto* node : node_map_->GetOutputs(loop_cond_name)) {
      if (node->op() == "Switch") {
        switch_node = node;
        break;
      }
    }
    if (!switch_node) {
      return errors::InvalidArgument("LoopCond node of Frame ", parent_id,
                                     " doesn't connect to any Switch node");
    }
    string switch_output = StrCat(switch_node->name(), ":1");
    const string ctrl_dep = ConstantFolding::AddControlDependency(
        switch_output, optimized_graph_, node_map_.get());
    const_node->add_input(ctrl_dep);
    node_map_->AddOutput(NodeName(ctrl_dep), const_node->name());
  }
  return Status::OK();
}
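
// A sketch of the Const rewrite above, assuming the frame is nested inside a
// parent loop. A constant with both invariant and variant consumers is
// duplicated rather than moved:
//
//   before:  Const --> invariant_op
//                 \--> variant_op
//
//   after:   LoopOptimizer/Const --> invariant_op   (in the parent frame)
//            Const ---------------> variant_op      (stays in this frame)
//
// The control dependency on the parent Switch's ":1" (live) output keeps the
// hoisted constant from firing when the parent loop is dead.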

Status LoopInvariantNodeMotionOptimizer::HandleInvariantNode(
    NodeDef* node, const int num_outputs, const int frame_id) {
  // Control inputs to the invariant node from within this frame must be
  // removed before the node is moved out of the frame.
  for (int i = 0; i < node->input_size(); ++i) {
    if (IsControlInput(node->input(i))) {
      node->mutable_input()->SwapElements(i, node->input_size() - 1);
      node->mutable_input()->RemoveLast();
      --i;  // Re-examine the input swapped into position i.
    }
  }
  if (num_outputs == 0) {
    return Status::OK();
  }

  DataTypeVector input_types;
  DataTypeVector output_types;
  OpRegistryInterface* op_registry = OpRegistry::Global();
  const OpRegistrationData* op_reg_data = nullptr;
  TF_RETURN_IF_ERROR(op_registry->LookUp(node->op(), &op_reg_data));
  TF_RETURN_IF_ERROR(InOutTypesForNode(*node, op_reg_data->op_def, &input_types,
                                       &output_types));

  auto consumers = node_map_->GetOutputs(node->name());
  string fname = invariant_enters_[frame_id][0]->attr().at("frame_name").s();
  int piterations =
      invariant_enters_[frame_id][0]->attr().at("parallel_iterations").i();
  for (auto* consumer : consumers) {
    if (!invariant_nodes_.count(consumer)) {
      for (int i = 0; i < consumer->input_size(); ++i) {
        int port;
        string node_name = ParseNodeName(consumer->input(i), &port);
        if (node_name != node->name()) {
          continue;
        }
        if (port < 0) {
          return errors::InvalidArgument(
              "Invariant node should not have control outputs "
              "to variant nodes");
        }
        DataType output_type = output_types[port];
        NodeDef* new_enter = optimized_graph_->add_node();
        new_enter->set_op("Enter");
        new_enter->set_device(node->device());
        new_enter->set_name(AddPrefixToNodeName(
            StrCat(fname, "_enter_", new_enter_id_++), kLoopOptimizer));
        AttrValue data_type;
        data_type.set_type(output_type);
        new_enter->mutable_attr()->insert({"T", data_type});
        AttrValue frame_name;
        frame_name.set_s(fname);
        new_enter->mutable_attr()->insert({"frame_name", frame_name});
        AttrValue is_const;
        is_const.set_b(true);
        new_enter->mutable_attr()->insert({"is_constant", is_const});
        AttrValue parallel_iterations;
        parallel_iterations.set_i(piterations);
        new_enter->mutable_attr()->insert(
            {"parallel_iterations", parallel_iterations});
        new_enter->add_input(consumer->input(i));
        *consumer->mutable_input(i) = new_enter->name();
        node_map_->AddNode(new_enter->name(), new_enter);
        node_map_->AddOutput(node->name(), new_enter->name());
        node_map_->AddOutput(new_enter->name(), consumer->name());
      }
    }
  }
  return Status::OK();
}
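
// A sketch of the rewrite above: once a generic invariant node is hoisted,
// each variant consumer is reattached through a fresh constant Enter so the
// hoisted value re-enters the frame:
//
//   before:  invariant_op:0 --> variant_op     (both inside the frame)
//   after:   invariant_op:0 --> Enter(is_constant=true) --> variant_op
//
// The new Enter copies frame_name and parallel_iterations from the frame's
// existing invariant Enter nodes, and its "T" attr comes from the hoisted
// node's inferred output type at that port.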

Status LoopInvariantNodeMotionOptimizer::MoveInvariantNodes(
    const int frame_id) {
  for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();
       ++iter) {
    auto* invariant_node = iter->first;
    const int num_outputs = iter->second;
    if (IsEnter(*invariant_node)) {
      TF_RETURN_IF_ERROR(HandleInvariantEnter(invariant_node, num_outputs));
    } else if (IsConstant(*invariant_node)) {
      TF_RETURN_IF_ERROR(HandleConst(invariant_node, num_outputs, frame_id));
    } else {
      TF_RETURN_IF_ERROR(
          HandleInvariantNode(invariant_node, num_outputs, frame_id));
    }
  }
  return Status::OK();
}

Status LoopInvariantNodeMotionOptimizer::RevertInvariantNodes() {
  std::deque<const NodeDef*> reverted_nodes;
  for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();) {
    bool erased = false;
    const auto* node = iter->first;
    if (!IsConstant(*node) && !IsEnter(*node) && iter->second > 0) {
      auto& consumers = node_map_->GetOutputs(node->name());
      for (auto* consumer : consumers) {
        if (!invariant_nodes_.count(consumer)) {
          for (const auto& input : consumer->input()) {
            if (IsControlInput(input) && NodeName(input) == node->name()) {
              reverted_nodes.push_back(node);
              invariant_nodes_.erase(iter++);
              erased = true;
              break;
            }
          }
          if (erased) break;
        }
      }
    }
    if (!erased) ++iter;
  }
  while (!reverted_nodes.empty()) {
    const auto* node = reverted_nodes.front();
    reverted_nodes.pop_front();
    std::set<NodeDef*> producers;
    for (const auto& input : node->input()) {
      auto* producer = node_map_->GetNode(input);
      auto iter = invariant_nodes_.find(producer);
      if (iter != invariant_nodes_.end()) {
        if (IsControlInput(input) && !IsConstant(*producer) &&
            !IsEnter(*producer)) {
          reverted_nodes.push_back(producer);
          invariant_nodes_.erase(iter);
        } else {
          producers.insert(producer);
        }
      }
    }
    for (auto* producer : producers) {
      auto iter = invariant_nodes_.find(producer);
      if (iter != invariant_nodes_.end()) {
        ++iter->second;
      }
    }
    for (auto* consumer : node_map_->GetOutputs(node->name())) {
      auto iter = invariant_nodes_.find(consumer);
      if (iter != invariant_nodes_.end()) {
        reverted_nodes.push_back(consumer);
        invariant_nodes_.erase(iter);
      }
    }
  }
  return Status::OK();
}
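
// Why reverting is needed: a control edge from an invariant node to a variant
// consumer cannot be rewired through a constant Enter (a control edge carries
// no value to re-enter the frame), so such a producer has to stay in the
// frame. RevertInvariantNodes therefore removes it from the invariant set,
// along with anything whose invariance depended on it, before any node is
// moved.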

Status LoopInvariantNodeMotionOptimizer::FindInvariantNodes(
    NodeDef* start_node) {
  std::vector<NodeDef*> stack;
  stack.reserve(32);
  stack.push_back(start_node);
  while (!stack.empty()) {
    NodeDef* node = stack.back();
    stack.pop_back();
    auto consumers = node_map_->GetOutputs(node->name());
    invariant_nodes_.emplace(node, consumers.size());
    for (auto* consumer : consumers) {
      if (invariant_nodes_.count(consumer) || ModifiesFrameInfo(*consumer)) {
        continue;
      }
      bool is_invariant = true;
      for (const auto& input : consumer->input()) {
        if (!IsControlInput(input)) {
          const string name = NodeName(input);
          auto* producer = node_map_->GetNode(name);
          if (!invariant_nodes_.count(producer)) {
            if (IsConstant(*producer)) {
              invariant_nodes_.insert(
                  std::make_pair(producer, node_map_->GetOutputs(name).size()));
            } else {
              is_invariant = false;
              break;
            }
          }
        }
      }
      if (is_invariant) {
        std::set<NodeDef*> producers;
        for (const auto& input : consumer->input()) {
          auto* producer = node_map_->GetNode(input);
          producers.insert(producer);
        }
        for (auto* producer : producers) {
          auto iter = invariant_nodes_.find(producer);
          if (iter != invariant_nodes_.end()) {
            --iter->second;
          }
        }
        stack.push_back(consumer);
      }
    }
  }
  return Status::OK();
}
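
// A note on the bookkeeping above: invariant_nodes_ maps each invariant node
// to the number of its consumers not (yet) proven invariant. The count starts
// at the full fanout size and is decremented as consumers are proven
// invariant. For example, a node with three consumers of which two are
// invariant ends with count 1: it still needs a new Enter for the one
// consumer left in the frame, while a final count of zero means the node can
// move out wholesale.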

Status LoopInvariantNodeMotionOptimizer::Optimize() {
  node_map_.reset(new NodeMap(optimized_graph_));
  FrameView frame_view;
  // TODO(ezhulenev): Use GraphView when migrated from NodeMap.
  TF_RETURN_IF_ERROR(frame_view.InferFromGraph(*optimized_graph_));

  std::deque<int> worklist;
  for (const NodeDef& node : optimized_graph_->node()) {
    const std::vector<int>& frame_ids = frame_view.Frames(node);

    if (frame_ids.size() >= 3) {
      for (unsigned int i = 1; i < frame_ids.size() - 1; ++i) {
        frame_parent_[frame_ids[i]] = frame_ids[i - 1];
        frame_children_[frame_ids[i]].insert(frame_ids[i + 1]);
      }
    }
    if (frame_ids.size() >= 2) {
      frame_children_[frame_ids[0]].insert(frame_ids[1]);
      frame_parent_[frame_ids.back()] = frame_ids[frame_ids.size() - 2];
    }
    if (!frame_ids.empty()) {
      frame_children_.insert(std::make_pair(frame_ids.back(), empty_set_));
      if (node.op() == "LoopCond") {
        if (loop_cond_.count(frame_ids.back())) {
          return errors::InvalidArgument(
              "Loop ", frame_ids.back(),
              " has more than one LoopCond node: ", node.name(), " and ",
              loop_cond_[frame_ids.back()]->name());
        }
        loop_cond_[frame_ids.back()] = &node;
      }
      if (IsEnter(node) && node.attr().at("is_constant").b()) {
        invariant_enters_[frame_ids.back()].push_back(
            const_cast<NodeDef*>(&node));
      }
    }
  }

  for (auto it = frame_children_.begin(); it != frame_children_.end(); ++it) {
    if (it->second.empty()) {
      worklist.push_back(it->first);
    }
  }

  while (!worklist.empty()) {
    int frame_id = worklist.front();
    new_enter_id_ = 0;
    worklist.pop_front();
    auto parent_it = frame_parent_.find(frame_id);
    if (parent_it != frame_parent_.end()) {
      int parent_id = parent_it->second;
      frame_children_[parent_id].erase(frame_id);
      if (frame_children_[parent_id].empty()) {
        worklist.push_back(parent_id);
      }
    }

    if (invariant_enters_[frame_id].empty()) {
      continue;
    }
    invariant_nodes_.clear();
    for (auto* enter : invariant_enters_[frame_id]) {
      TF_RETURN_IF_ERROR(FindInvariantNodes(enter));
    }

    // Revert invariant nodes that have control outputs to variant nodes.
    TF_RETURN_IF_ERROR(RevertInvariantNodes());

    TF_RETURN_IF_ERROR(MoveInvariantNodes(frame_id));
  }
  return Status::OK();
}
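
// Frames form a forest (frame_parent_ / frame_children_), and the worklist
// above visits it leaves-first. For a nest like
//
//   outer { inner1 { } inner2 { } }
//
// inner1 and inner2 are optimized before outer, so nodes hoisted out of an
// inner frame land in outer and may be hoisted again when outer itself is
// processed.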

std::vector<int> GetStackPushNodesToConvert(
    const GraphTopologyView& graph_view,
    const std::unordered_set<string>& nodes_to_preserve, int stack_node_idx) {
  VLOG(1) << "Stack node: " << graph_view.graph()->node(stack_node_idx).name();

  const std::unordered_set<string> op_types_to_traverse(
      {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch",
       "Identity", "RefIdentity"});
  const auto is_op_to_traverse = [&](const NodeDef* node) -> bool {
    return op_types_to_traverse.find(node->op()) != op_types_to_traverse.end();
  };

  std::vector<int> nodes_to_convert;
  std::vector<int> fanouts;

  DfsTraversal(graph_view, {graph_view.GetNode(stack_node_idx)},
               TraversalDirection::kFollowOutputs,
               DfsPredicates::Advance(is_op_to_traverse),
               DfsCallbacks::PreOrder([&](const NodeDef* node) {
                 const absl::optional<int> idx = graph_view.GetNodeIndex(*node);
                 fanouts.push_back(idx.value());
               }));

  for (int fanout_idx : fanouts) {
    const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx);
    VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name();
    if (IsStackPushOp(fanout_node)) {
      // Check that the stack itself is not a node we want to preserve. This
      // can happen when the graph we have contains only the forward pass for
      // a loop (as when the forward and backward passes are split across
      // different functions).
      if (graph_view.HasNode(fanout_node.input(0))) {
        const NodeDef* stack_node = graph_view.GetNode(fanout_node.input(0));
        while (stack_node->op() != "Stack" && stack_node->op() != "StackV2" &&
               stack_node->input_size() > 0 &&
               graph_view.HasNode(stack_node->input(0))) {
          stack_node = graph_view.GetNode(stack_node->input(0));
        }
        if (nodes_to_preserve.find(stack_node->name()) ==
            nodes_to_preserve.end()) {
          nodes_to_convert.push_back(fanout_idx);
        }
      } else {
        nodes_to_convert.push_back(fanout_idx);
      }
    } else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) ||
               op_types_to_traverse.find(fanout_node.op()) !=
                   op_types_to_traverse.end()) {
      continue;
    } else if (!IsStackPopOp(fanout_node) ||
               (!graph_view.GetFanout(fanout_idx).empty() ||
                nodes_to_preserve.find(fanout_node.name()) !=
                    nodes_to_preserve.end())) {
      // The node is either a stack pop with consumers or something unexpected,
      // so we leave the graph alone.
      nodes_to_convert.clear();
      break;
    }
  }

  return nodes_to_convert;
}
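
// A sketch of what qualifies for conversion: a stack push whose stack has no
// live pops, as in a graph that kept only the forward pass of a loop:
//
//   StackV2 --> Enter --> StackPushV2      (no StackPopV2 with consumers)
//
// The DFS above only walks through stack and frame plumbing ops
// (op_types_to_traverse); encountering a pop with consumers, a preserved
// node, or any unexpected op aborts the conversion for the whole stack.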

Status RemoveStackOps(const std::unordered_set<string>& nodes_to_preserve,
                      GraphDef* optimized_graph) {
  NodeMap node_map(optimized_graph);
  GraphTopologyView graph_view;
  TF_RETURN_IF_ERROR(graph_view.InitializeFromGraph(*optimized_graph));

  for (int node_idx = 0; node_idx < optimized_graph->node_size(); ++node_idx) {
    if (IsStackOp(optimized_graph->node(node_idx))) {
      for (int push_node_idx : GetStackPushNodesToConvert(
               graph_view, nodes_to_preserve, node_idx)) {
        // We found push nodes without corresponding pops. Convert them to
        // Identity ops that pass the data through, and add a control
        // dependency from the op supplying the stack handle.
        NodeDef* push_node = optimized_graph->mutable_node(push_node_idx);
        VLOG(1) << "Converting " << push_node_idx << " : "
                << push_node->DebugString();
        if (push_node->attr().count("swap_memory") != 0) {
          push_node->mutable_attr()->erase("swap_memory");
        }
        push_node->set_op("Identity");
        push_node->mutable_input()->SwapElements(0, 1);
        const string ctrl_dep = ConstantFolding::AddControlDependency(
            push_node->input(1), optimized_graph, &node_map);
        push_node->set_input(1, ctrl_dep);
        VLOG(1) << "After converting: " << push_node->DebugString();
      }
    }
  }
  return Status::OK();
}
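
// A sketch of the conversion itself:
//
//   before:  StackPushV2(stack_handle, data)
//   after:   Identity(data, ^stack_handle_producer)
//
// Swapping inputs 0 and 1 puts the data tensor in the Identity's data slot,
// and the stack handle is demoted to a control dependency, which preserves
// ordering without keeping the stack value alive.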

bool IsSimpleBinaryOperator(const NodeDef& node) {
  return (IsLess(node) || IsLessEqual(node) || IsGreater(node) ||
          IsGreaterEqual(node) || IsEqual(node));
}

Status EvaluateBoolOpForConstantOperands(const NodeDef& op_node,
                                         const NodeDef& constant_operand_0,
                                         const NodeDef& constant_operand_1,
                                         DeviceBase* cpu_device,
                                         ResourceMgr* resource_mgr,
                                         bool* value) {
  TensorVector inputs;

  const TensorProto& raw_val_0 = constant_operand_0.attr().at("value").tensor();
  Tensor value_0(raw_val_0.dtype(), raw_val_0.tensor_shape());
  CHECK(value_0.FromProto(raw_val_0));
  inputs.emplace_back(&value_0);
  const TensorProto& raw_val_1 = constant_operand_1.attr().at("value").tensor();
  Tensor value_1(raw_val_1.dtype(), raw_val_1.tensor_shape());
  CHECK(value_1.FromProto(raw_val_1));
  inputs.emplace_back(&value_1);

  TensorVector outputs;
  TF_RETURN_IF_ERROR(
      EvaluateNode(op_node, inputs, cpu_device, resource_mgr, &outputs));

  if (outputs.size() != 1 || outputs[0].tensor == nullptr) {
    return Status(error::INVALID_ARGUMENT, "Expected one output.");
  }
  *value = outputs[0].tensor->scalar<bool>()();
  delete outputs[0].tensor;

  return Status::OK();
}

// TODO(lyandy): Consolidate with ConstantFolding implementation.
bool IsReallyConstant(const NodeDef& node,
                      const absl::flat_hash_set<string>& feed_nodes) {
  if (!IsConstant(node)) {
    return false;
  }
  // If the node is fed, it's not constant anymore.
  return feed_nodes.find(node.name()) == feed_nodes.end();
}

Status CheckForDeadFanout(const MutableGraphView& view,
                          const NodeDef& switch_node, const NodeMap& node_map,
                          const absl::flat_hash_set<string>& feed_nodes,
                          DeviceBase* cpu_device, ResourceMgr* resource_mgr,
                          bool* has_dead_fanout, int* dead_fanout) {
  *has_dead_fanout = false;
  GraphView::InputPort switch_loopcond_port(&switch_node, 1);
  const NodeDef* switch_predicate =
      view.GetRegularFanin(switch_loopcond_port).node;

  // CASE 1: The switch is controlled by a constant predicate.
  if (IsReallyConstant(*switch_predicate, feed_nodes)) {
    Tensor selector;
    CHECK(selector.FromProto(switch_predicate->attr().at("value").tensor()));
    *has_dead_fanout = true;
    *dead_fanout = selector.scalar<bool>()() ? 0 : 1;
  }

  GraphView::InputPort switch_input_port(&switch_node, 0);
  const NodeDef* switch_input = view.GetRegularFanin(switch_input_port).node;

  // CASE 2: Zero-iteration while loop.
  // We check whether this is a while loop whose condition is a simple binary
  // operator that returns false for the initialization value.
  // TODO(srjoglekar): Improve to work with arbitrary predicate subgraphs.
  if (!IsMerge(*switch_input)) {
    return Status::OK();
  }

  // Find the boolean op feeding the predicate node.
  NodeDef* switch_ctrl_node = nullptr;
  for (int i = 0; i < switch_predicate->input().size(); ++i) {
    NodeDef* node = node_map.GetNode(switch_predicate->input(i));
    if (IsSimpleBinaryOperator(*node)) {
      switch_ctrl_node = node;
    }
  }
  if (switch_ctrl_node == nullptr) {
    return Status::OK();
  }
  // Find the Merge node and the constant operand of the condition node, if
  // available.
  NodeDef* merge_node = nullptr;
  NodeDef* constant_ctrl_input = nullptr;
  int constant_index = 0;
  for (int i = 0; i < switch_ctrl_node->input().size(); ++i) {
    NodeDef* node = node_map.GetNode(switch_ctrl_node->input(i));
    if (IsMerge(*node)) {
      merge_node = node;
    }
    if (IsReallyConstant(*node, feed_nodes)) {
      constant_ctrl_input = node;
      constant_index = i;
    }
  }
  if (merge_node == nullptr || constant_ctrl_input == nullptr) {
    return Status::OK();
  }
  // Find the initialization constant (via Enter, if one exists).
  NodeDef* enter_node = nullptr;
  NodeDef* constant_init_node = nullptr;
  for (const auto& input : merge_node->input()) {
    NodeDef* node = node_map.GetNode(input);
    if (IsEnter(*node)) {
      enter_node = node;
    }
    if (IsReallyConstant(*node, feed_nodes)) {
      constant_init_node = node;
    }
  }
  if (enter_node != nullptr) {
    if (constant_init_node != nullptr) return Status::OK();
    for (const auto& input : enter_node->input()) {
      NodeDef* node = node_map.GetNode(input);
      if (IsReallyConstant(*node, feed_nodes)) {
        constant_init_node = node;
      }
    }
  }
  if (constant_init_node == nullptr) {
    return Status::OK();
  }

  // Check if there will be zero iterations. This happens only if the
  // condition evaluates to false with respect to the initialization value.
  NodeDef* operand_0 =
      constant_index ? constant_init_node : constant_ctrl_input;
  NodeDef* operand_1 =
      constant_index ? constant_ctrl_input : constant_init_node;
  bool constant_switch_value;
  TF_RETURN_IF_ERROR(EvaluateBoolOpForConstantOperands(
      *switch_ctrl_node, *operand_0, *operand_1, cpu_device, resource_mgr,
      &constant_switch_value));
  if (!constant_switch_value) {
    *has_dead_fanout = true;
    *dead_fanout = 1;
  }
  return Status::OK();
}
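
// A sketch of the zero-iteration pattern matched by CASE 2, corresponding to
// a loop like `while (i < limit) ...` with a constant initializer:
//
//   Const(init) --> Enter --> Merge --> Less <-- Const(limit)
//                               |        |
//                               |     LoopCond
//                               v        v
//                             Switch(data, pred)
//
// Evaluating the comparison on the initialization value decides whether the
// body ever runs; if it is false, the Switch's true output (port 1) is dead.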

}  // namespace

LoopOptimizer::LoopOptimizer()
    : opt_level_(RewriterConfig::ON),
      cpu_device_(nullptr),
      options_(LoopOptimizerOptions::Default(RewriterConfig::ON)) {}

LoopOptimizer::LoopOptimizer(RewriterConfig::Toggle opt_level,
                             DeviceBase* cpu_device)
    : opt_level_(opt_level),
      cpu_device_(cpu_device),
      options_(LoopOptimizerOptions::Default(RewriterConfig::ON)) {
  resource_mgr_.reset(new ResourceMgr());
}

Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                               GraphDef* optimized_graph) {
  *optimized_graph = item.graph;
  // Set up helper data structures.
  if (options_.enable_loop_invariant_node_motion) {
    LoopInvariantNodeMotionOptimizer linm_optimizer(optimized_graph);
    TF_RETURN_IF_ERROR(linm_optimizer.Optimize());
  }
  if (options_.enable_stack_push_removal) {
    TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph));
  }
  if (options_.enable_dead_branch_removal) {
    // TODO(srjoglekar): Figure out if we can optimize NodeMap creations across
    // optimizer passes.
    NodeMap node_map(optimized_graph);
    absl::flat_hash_set<string> feed_nodes;
    for (const auto& feed : item.feed) {
      feed_nodes.insert(NodeName(feed.first));
    }
    TF_RETURN_IF_ERROR(RemoveDeadBranches(item.NodesToPreserve(), node_map,
                                          feed_nodes, optimized_graph));
  }

  return Status::OK();
}

Status LoopOptimizer::RemoveDeadBranches(
    const std::unordered_set<string>& nodes_to_preserve,
    const NodeMap& node_map, const absl::flat_hash_set<string>& feed_nodes,
    GraphDef* optimized_graph) {
  std::unordered_set<const NodeDef*> dead_nodes;
  std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs;
  // TODO(bsteiner): Also rewrite switches as identity. For now we just record
  // them.
  absl::flat_hash_set<GraphView::OutputPort> identity_switches;

  MutableGraphView view(optimized_graph);
  for (const NodeDef& node : optimized_graph->node()) {
    if (!IsSwitch(node)) {
      continue;
    }
    if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
      continue;
    }

    int dead_fanout;
    bool has_dead_fanout;
    TF_RETURN_IF_ERROR(CheckForDeadFanout(view, node, node_map, feed_nodes,
                                          cpu_device_, resource_mgr_.get(),
                                          &has_dead_fanout, &dead_fanout));
    if (!has_dead_fanout) {
      continue;
    }
    GraphView::OutputPort dead(&node, dead_fanout);
    identity_switches.insert(dead);

    SetVector<MutableGraphView::InputPort, absl::Hash<MutableGraphView::Port>>
        zombie_inputs;
    for (const MutableGraphView::InputPort& port : view.GetFanout(dead)) {
      if (dead_nodes.find(port.node) == dead_nodes.end()) {
        zombie_inputs.PushBack(port);
      }
    }
    // If we encounter a single node that must be preserved in the fanout of
    // the switch node, we need to preserve the entire switch fanout: we
    // therefore work on a local copy that only gets committed to the master
    // copy once the whole fanout has been explored.
    std::unordered_set<const NodeDef*> local_dead_nodes = dead_nodes;
    std::unordered_map<NodeDef*, std::set<int>> local_dead_merge_inputs =
        dead_merge_inputs;
    bool found_node_to_preserve = false;
    while (!found_node_to_preserve && !zombie_inputs.Empty()) {
      MutableGraphView::InputPort dead = zombie_inputs.PopBack();
      if (nodes_to_preserve.find(dead.node->name()) !=
          nodes_to_preserve.end()) {
        found_node_to_preserve = true;
        break;
      }

      if (local_dead_nodes.find(dead.node) != local_dead_nodes.end()) {
        continue;
      }

      if (IsMerge(*dead.node)) {
        const int num_data_inputs = dead.node->attr().at("N").i();
        if (num_data_inputs > 2) {
          // This never happens in practice, so we'll just skip these to
          // simplify the code for now.
          found_node_to_preserve = true;
          break;
        }
        MutableGraphView::OutputPort value_index(dead.node, 1);
        const absl::flat_hash_set<MutableGraphView::InputPort>& index_fanout =
            view.GetFanout(value_index);
        if (!index_fanout.empty()) {
          // The 2nd output (which indicates which input is propagated) is
          // connected. This never happens in practice, so we'll just skip this
          // case to simplify the code for now.
          found_node_to_preserve = true;
          break;
        }

        bool fully_dead = false;
        // A Merge node can become really dead only if all its data inputs are
        // dead. Merge always waits for all its control edges, but they do not
        // change the node's deadness.
        if (dead.port_id >= 0) {
          local_dead_merge_inputs[dead.node].insert(dead.port_id);
          if (local_dead_merge_inputs[dead.node].size() == num_data_inputs) {
            fully_dead = true;
          }
        } else {
          // Keep track of all Merge nodes, even if they do not have dead data
          // inputs. We'll need to clean up dead control edges for them later.
          local_dead_merge_inputs.insert({dead.node, {}});
        }
        if (fully_dead) {
          local_dead_merge_inputs.erase(dead.node);
          local_dead_nodes.insert(dead.node);
          for (const MutableGraphView::InputPort& port :
               view.GetFanouts(*dead.node, true)) {
            zombie_inputs.PushBack(port);
          }
        }
      } else if (dead.node->op() == "ControlTrigger") {
        // ControlTrigger has different semantics, so don't touch it.
        found_node_to_preserve = true;
        break;
      } else {
        if (local_dead_nodes.insert(dead.node).second) {
          for (const MutableGraphView::InputPort& dead_fanout :
               view.GetFanouts(*dead.node, true)) {
            zombie_inputs.PushBack(dead_fanout);
          }
        }
      }
    }
    if (!found_node_to_preserve) {
      std::swap(dead_nodes, local_dead_nodes);
      std::swap(dead_merge_inputs, local_dead_merge_inputs);
    }
  }

  std::vector<int> nodes_idx_to_delete;
  nodes_idx_to_delete.reserve(dead_nodes.size());
  for (int i = 0; i < optimized_graph->node_size(); ++i) {
    if (dead_nodes.count(&optimized_graph->node(i)))
      nodes_idx_to_delete.push_back(i);
  }

  // Names of the nodes that were removed from the graph.
  absl::flat_hash_set<absl::string_view> dead_node_names;
  dead_node_names.reserve(dead_nodes.size());
  for (const NodeDef* dead_node : dead_nodes)
    dead_node_names.insert(dead_node->name());

  // Remove dead inputs from Merge nodes that were not pruned from the graph.
  for (const auto& itr : dead_merge_inputs) {
    NodeDef* dead_node = itr.first;
    if (dead_nodes.find(dead_node) != dead_nodes.end()) {
      // The node has been pruned since all its inputs are dead.
      continue;
    }
    // Remove dead data inputs.
    const std::set<int>& dead_inputs = itr.second;
    for (int index : dead_inputs) {
      dead_node->mutable_input()->DeleteSubrange(index, 1);
    }
    // Turn the Merge into an Identity only if we deleted data inputs.
    if (!dead_inputs.empty()) {
      dead_node->set_op("Identity");
      dead_node->mutable_attr()->erase("N");
    }
    // Remove control inputs that come from dead nodes.
    int pos = 0;
    while (pos < dead_node->input_size()) {
      TensorId tensor = ParseTensorName(dead_node->input(pos));
      if (tensor.index() == Graph::kControlSlot &&
          dead_node_names.contains(tensor.node())) {
        auto* inputs = dead_node->mutable_input();
        inputs->SwapElements(pos, dead_node->input_size() - 1);
        inputs->RemoveLast();
      } else {
        ++pos;
      }
    }
  }

  EraseNodesFromGraph(std::move(nodes_idx_to_delete), optimized_graph);

  return Status::OK();
}
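
// A sketch of the Merge cleanup above: a Merge that lost one of its two data
// inputs to dead-branch pruning degenerates into a pass-through,
//
//   before:  Merge(live_input, dead_input)   with attr N = 2
//   after:   Identity(live_input)            ("N" attr erased)
//
// and any control inputs it still has on pruned nodes are dropped so the
// graph stays well formed.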

void LoopOptimizer::Feedback(Cluster* /*cluster*/, const GrapplerItem& /*item*/,
                             const GraphDef& /*optimized_graph*/,
                             double /*result*/) {
  // Nothing to do for LoopOptimizer.
}

}  // end namespace grappler
}  // end namespace tensorflow