• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/lower_while_op.h"
17 
18 #include "tensorflow/core/common_runtime/inline_function_utils.h"
19 #include "tensorflow/core/framework/node_def_builder.h"
20 #include "tensorflow/core/framework/types.pb.h"
21 #include "tensorflow/core/graph/graph.h"
22 #include "tensorflow/core/graph/node_builder.h"
23 
24 namespace tensorflow {
25 
26 namespace {
27 
28 using NodeOut = NodeBuilder::NodeOut;
29 
30 constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
31     LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
32 
33 // Helper to convert a functional While op to its lowered form.
34 //
35 // Example:
36 //
37 // Input graph:
38 //
39 // loop_var -> WhileOp<cond_func, body_func> -> consumer
40 //
41 // Output graph(top to down flow):
42 //
43 //                   loop_var
44 //                      |
45 //                    Enter
46 //                      |
47 //  cond_func ---<--- Merge  ---<--- NextIteration
48 //      |               |                |
49 //      V               V                ^
50 //      |               |                |
51 //  LoopCond  --->--- Switch --->--- body_func
52 //                      |
53 //                     Exit
54 //                      |
55 //                   consumer
56 //
57 // DT_RESOURCE tensors are handled specially:
58 //
59 // resource_loop_var -> Enter[is_constant=True] -> cond_func and body_func
60 //      |
61 //      V
62 //   consumer
class LowerWhileHelper {
 public:
  // Convenience entry point: lowers `while_op` (whose cond/body functions are
  // given by `cond_fn`/`body_fn`) in place in `graph`, then returns the status
  // of the whole rewrite.
  static Status Run(Node* while_op, const NameAttrList& cond_fn,
                    const NameAttrList& body_fn, int parallel_iterations,
                    Graph* graph, bool keep_node_fetchable) {
    LowerWhileHelper helper(while_op, cond_fn, body_fn, parallel_iterations,
                            graph, keep_node_fetchable);
    return helper.RunInternal();
  }

 private:
  // Create a LowerWhileHelper to create the lowering of While op that has cond
  // and body functions named `cond_fn_name` and `body_fn_name` respectively in
  // the given graph.
  LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn,
                   const NameAttrList& body_fn, int parallel_iterations,
                   Graph* graph, bool keep_node_fetchable);

  // Runs all Create*/Update* steps below in their required order.
  Status RunInternal();

  // Fills `op_input_output_to_lowered_node_`: non-resource loop inputs get
  // consecutive indices; resource inputs keep the -1 sentinel.
  void InitializeInputOutputToLoweredNodeMap();

  // Creates an Enter node for each `while_op_` input and adds them to
  // `enter_nodes_`. If the `while_op_` has an incoming control edge from a
  // `src` node we add a control edge from `src` to each Enter node.
  Status CreateEnterNodes();

  // Creates a Merge node for each Enter node and adds to `merge_nodes_`.
  // Initially both inputs of a Merge node are the Enter node. Input at
  // index 1 is later updated to the output of NextIteration node in
  // `UpdateMergeNodes`.
  Status CreateMergeNodes();

  // Creates the call node for cond func and stores in `cond_call_node_`.
  Status CreateCondFuncCallNode();

  // Creates a Switch node for each loop var and adds to `switch_nodes_`.
  // Output at index 1(true) of a Switch node is fed into the loop body.
  // Output at index 0(false) of a Switch node is fed into the Exit nodes.
  Status CreateSwitchNodes();

  // Creates the call node for body func and stores in `body_call_node_`.
  Status CreateBodyFuncCallNode();

  // Creates an Exit node for each loop var and adds to `exit_nodes_`. These
  // are fed into the consumers of the `while_op_`.
  Status CreateExitNodes();

  // Creates a NextIteration node for each loop var and adds to
  // `next_iteration_nodes_`.
  Status CreateNextIterationNodes();

  // Updates input at index 1 of each merge node created in `CreateMergeNodes`
  // to use the output of NextIteration node created in
  // `CreateNextIterationNodes` instead.
  Status UpdateMergeNodes();

  // Updates consumers of the original `while_op_` to instead use the outputs
  // from the exit nodes in `exit_nodes_`. Also updates any outgoing control
  // edges to depend on `lowered_while_executed_` instead.
  Status UpdateConsumers();

  // Returns unique name containing the name of the While op being rewritten
  // (name_), infix and a suffix to ensure it is unique within the graph.
  string NewName(const string& infix);

  // Returns whether the While op's input/output at `index` is a `DT_RESOURCE`.
  bool IsResource(int index);

  // The original While op.
  Node* while_op_;
  // The call node for the cond branch.
  Node* cond_call_node_;
  // The LoopCond node specifying the loop termination condition.
  Node* loop_cond_node_;
  // The call node for the body branch.
  Node* body_call_node_;
  // The node with the same name as the original While op:
  //   (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true'.
  //   (b) NoOp node with control edge from 'lowered_while_executed_' otherwise.
  Node* lowered_while_output_;
  // The NoOp node with control edges from all Exit nodes. This node will be
  // used as a source of outgoing control edges from lowered While node.
  Node* lowered_while_executed_;
  // The graph being rewritten; all lowered nodes are added to it.
  Graph* graph_;
  // Name of the `while_op_`.
  string name_;
  // Max number of parallel_iterations for the while loop.
  const int parallel_iterations_;
  // Whether to keep a fetchable node under the original While op's name.
  bool keep_node_fetchable_;

  // Debug info copied from the original While op into every lowered node.
  NodeDebugInfo debug_info_;
  // Builders for the cond and body function call nodes.
  NodeBuilder cond_call_builder_;
  NodeBuilder body_call_builder_;

  // `Enter` nodes, one per loop input/output.
  // Note: `Enter` nodes with type `DT_RESOURCE` have attr `is_constant=True`.
  std::vector<Node*> enter_nodes_;

  // Merge/Switch/NextIteration/Exit nodes, one per non-resource loop
  // input/output.
  std::vector<Node*> merge_nodes_;
  std::vector<Node*> switch_nodes_;
  std::vector<Node*> exit_nodes_;
  std::vector<Node*> next_iterations_nodes_;
  // Maps from the loop input/output indices to their corresponding
  // Merge/Switch/NextIteration/Exit node indices. For inputs/outputs of
  // `DT_RESOURCE` type there are no Merge/Switch/NextIteration/Exit nodes
  // in which case the mapping contains -1.
  std::vector<int> op_input_output_to_lowered_node_;

  // Number of data inputs of `while_op_` (inputs and outputs match 1:1).
  size_t num_loop_inputs_;
};
176 
LowerWhileHelper::LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn,
                                   const NameAttrList& body_fn,
                                   int parallel_iterations, Graph* graph,
                                   bool keep_node_fetchable)
    : while_op_(while_op),
      graph_(graph),
      name_(while_op->name()),
      parallel_iterations_(parallel_iterations),
      keep_node_fetchable_(keep_node_fetchable),
      debug_info_(*while_op_),
      // Members initialize in declaration order, so `graph_` and `name_`
      // (both used by NewName) are set before the two builders below run.
      cond_call_builder_(NewName("cond"), cond_fn.name(), graph->op_registry(),
                         &debug_info_),
      body_call_builder_(NewName("body"), body_fn.name(), graph->op_registry(),
                         &debug_info_),
      num_loop_inputs_(while_op_->num_inputs()) {
  // Mark both function calls for multi-device lowering and forward any
  // attributes carried on the cond/body NameAttrLists.
  cond_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true);
  for (const auto& i : cond_fn.attr()) {
    cond_call_builder_.Attr(i.first, i.second);
  }
  body_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true);
  for (const auto& i : body_fn.attr()) {
    body_call_builder_.Attr(i.first, i.second);
  }
  // We intentionally `resize` instead of `reserve` space in `enter_nodes_`
  // because we need to set its elements out of order in `CreateEnterNodes`.
  enter_nodes_.resize(num_loop_inputs_);
  merge_nodes_.reserve(num_loop_inputs_);
  switch_nodes_.reserve(num_loop_inputs_);
  exit_nodes_.reserve(num_loop_inputs_);
  next_iterations_nodes_.reserve(num_loop_inputs_);
  // -1 marks inputs (DT_RESOURCE) that get no Merge/Switch/NextIteration/Exit.
  op_input_output_to_lowered_node_.resize(num_loop_inputs_, -1);
}
209 
Status LowerWhileHelper::RunInternal() {
  // Step order matters: Switch nodes need `loop_cond_node_` (created by
  // CreateCondFuncCallNode), the body call consumes the Switch nodes, and
  // UpdateMergeNodes needs the NextIteration nodes to already exist.
  InitializeInputOutputToLoweredNodeMap();
  TF_RETURN_IF_ERROR(CreateEnterNodes());
  TF_RETURN_IF_ERROR(CreateMergeNodes());
  TF_RETURN_IF_ERROR(CreateCondFuncCallNode());
  TF_RETURN_IF_ERROR(CreateSwitchNodes());
  TF_RETURN_IF_ERROR(CreateBodyFuncCallNode());
  TF_RETURN_IF_ERROR(CreateExitNodes());
  TF_RETURN_IF_ERROR(CreateNextIterationNodes());
  TF_RETURN_IF_ERROR(UpdateMergeNodes());
  TF_RETURN_IF_ERROR(UpdateConsumers());
  return Status::OK();
}
223 
InitializeInputOutputToLoweredNodeMap()224 void LowerWhileHelper::InitializeInputOutputToLoweredNodeMap() {
225   int counter = 0;
226   for (int i = 0; i < num_loop_inputs_; i++) {
227     if (!IsResource(i)) {
228       op_input_output_to_lowered_node_[i] = counter++;
229     }
230   }
231 }
232 
Status LowerWhileHelper::CreateEnterNodes() {
  // Note: `Node::input_edge` runs in O(num_inputs) so we use
  // `Node::input_edges` instead so that below loop runs in O(num_inputs) time
  // and not O(num_inputs^2).
  std::vector<const Edge*> edges;
  TF_RETURN_IF_ERROR(while_op_->input_edges(&edges));
  for (const Edge* edge : edges) {
    Node* enter_node;
    NodeBuilder builder =
        NodeBuilder(NewName("enter"), "Enter", graph_->op_registry(),
                    &debug_info_)
            .Input(NodeOut(edge->src(), edge->src_output()))
            .Attr("frame_name", name_)
            .Attr("parallel_iterations", parallel_iterations_)
            .Device(edge->src()->requested_device())
            .AssignedDevice(edge->src()->assigned_device_name());
    // DT_RESOURCE inputs get `is_constant=True` (see the DT_RESOURCE note in
    // the file-header comment above LowerWhileHelper).
    if (IsResource(edge->dst_input())) {
      builder.Attr("is_constant", true);
    }
    TF_RETURN_IF_ERROR(builder.Finalize(graph_, &enter_node));
    // Index by destination input so `enter_nodes_` matches the While op's
    // input order regardless of the order edges are visited in.
    enter_nodes_[edge->dst_input()] = enter_node;
  }
  // Create a NoOp node that takes incoming control inputs of the original While
  // op as control inputs and use it as a control input for all Enter nodes.
  std::vector<Node*> control_inputs;
  for (const Edge* e : while_op_->in_edges()) {
    if (e->IsControlEdge()) {
      control_inputs.push_back(e->src());
    }
  }
  if (!control_inputs.empty()) {
    Node* incoming_control_node;
    TF_RETURN_IF_ERROR(NodeBuilder(NewName("LoopControlInputs"), "NoOp",
                                   graph_->op_registry(), &debug_info_)
                           .ControlInputs(control_inputs)
                           .Device(while_op_->requested_device())
                           .Finalize(graph_, &incoming_control_node));
    for (Node* n : enter_nodes_) {
      graph_->AddControlEdge(incoming_control_node, n);
    }
  }
  return Status::OK();
}
276 
CreateMergeNodes()277 Status LowerWhileHelper::CreateMergeNodes() {
278   for (Node* enter_node : enter_nodes_) {
279     if (enter_node->output_type(0) == DT_RESOURCE) {
280       continue;
281     }
282     Node* merge_node;
283     TF_RETURN_IF_ERROR(
284         NodeBuilder(NewName("merge"), "Merge", graph_->op_registry(),
285                     &debug_info_)
286             .Input({NodeOut(enter_node, 0), NodeOut(enter_node, 0)})
287             .Device(enter_node->requested_device())
288             .AssignedDevice(enter_node->assigned_device_name())
289             .Finalize(graph_, &merge_node));
290     merge_nodes_.emplace_back(merge_node);
291   }
292   return Status::OK();
293 }
294 
Status LowerWhileHelper::CreateCondFuncCallNode() {
  // Resource loop vars read from their (constant) Enter nodes; all other loop
  // vars read from their Merge nodes so each iteration sees the latest value.
  for (int i = 0; i < num_loop_inputs_; i++) {
    if (IsResource(i)) {
      cond_call_builder_.Input(NodeOut(enter_nodes_[i], 0));
    } else {
      cond_call_builder_.Input(
          NodeOut(merge_nodes_[op_input_output_to_lowered_node_[i]], 0));
    }
  }
  cond_call_builder_.Device(while_op_->requested_device());
  TF_RETURN_IF_ERROR(cond_call_builder_.Finalize(graph_, &cond_call_node_));
  // Add a control edge to make sure the Const nodes in the cond function
  // are in the same frame as the rest of the function, otherwise
  // `BuildControlFlowInfo` throws an error.
  // NOTE(review): indexing merge_nodes_[0] assumes at least one non-resource
  // loop var -- confirm callers never lower an all-resource While.
  graph_->AddControlEdge(merge_nodes_[0], cond_call_node_);
  // The cond call's output 0 drives a LoopCond node, which in turn feeds the
  // predicate input of every Switch node.
  TF_RETURN_IF_ERROR(NodeBuilder(NewName("LoopCond"), "LoopCond",
                                 graph_->op_registry(), &debug_info_)
                         .Input(NodeOut(cond_call_node_, 0))
                         .Device(while_op_->requested_device())
                         .Finalize(graph_, &loop_cond_node_));
  return Status::OK();
}
317 
CreateSwitchNodes()318 Status LowerWhileHelper::CreateSwitchNodes() {
319   for (int i = 0; i < num_loop_inputs_; i++) {
320     if (IsResource(i)) {
321       continue;
322     }
323     string op_name;
324     {
325       const Node* input_node;
326       TF_RETURN_IF_ERROR(while_op_->input_node(i, &input_node));
327       op_name = strings::StrCat(input_node->name(), "_switch");
328     }
329     Node* merge_node = merge_nodes_[op_input_output_to_lowered_node_[i]];
330     Node* switch_node;
331     string op_type = "Switch";
332     if (IsRefType(merge_node->output_type(0))) {
333       op_type = "RefSwitch";
334     }
335     TF_RETURN_IF_ERROR(NodeBuilder(NewName(op_name), op_type,
336                                    graph_->op_registry(), &debug_info_)
337                            .Input(NodeOut(merge_node, 0))
338                            .Input(NodeOut(loop_cond_node_, 0))
339                            .Device(merge_node->requested_device())
340                            .AssignedDevice(merge_node->assigned_device_name())
341                            .Finalize(graph_, &switch_node));
342     switch_nodes_.emplace_back(switch_node);
343   }
344   return Status::OK();
345 }
346 
Status LowerWhileHelper::CreateBodyFuncCallNode() {
  // Resource loop vars come from the constant Enter nodes; all other loop
  // vars come from the true (index 1) output of their Switch nodes.
  for (int i = 0; i < num_loop_inputs_; i++) {
    if (IsResource(i)) {
      body_call_builder_.Input(NodeOut(enter_nodes_[i], 0));
    } else {
      body_call_builder_.Input(
          NodeOut(switch_nodes_[op_input_output_to_lowered_node_[i]], 1));
    }
  }
  body_call_builder_.Device(while_op_->requested_device());
  TF_RETURN_IF_ERROR(body_call_builder_.Finalize(graph_, &body_call_node_));
  // Add a control edge to make sure the Const nodes in the body function
  // are in the same frame as the rest of the function, otherwise
  // `BuildControlFlowInfo` throws an error.
  // TODO(srbs): The choice of input at index 0 seems arbitrary(is it?) however
  // this is how tf.while_loop does it. Can this affect performance if the 0th
  // node is not the first one to be ready? Can we speed that case up using some
  // sort of multi-input Merge?
  // NOTE(review): indexing switch_nodes_[0] assumes at least one non-resource
  // loop var -- confirm callers never lower an all-resource While.
  Node* body_control_node_;
  string op_type = "Identity";
  if (IsRefType(switch_nodes_[0]->output_type(1))) {
    op_type = "RefIdentity";
  }
  TF_RETURN_IF_ERROR(NodeBuilder(NewName("loop_body_control"), op_type,
                                 graph_->op_registry(), &debug_info_)
                         .Input(NodeOut(switch_nodes_[0], 1))
                         .Device(while_op_->requested_device())
                         .Finalize(graph_, &body_control_node_));
  graph_->AddControlEdge(body_control_node_, body_call_node_);
  return Status::OK();
}
378 
Status LowerWhileHelper::CreateExitNodes() {
  // Collects one NodeOut per While output, used for the fetchable IdentityN.
  std::vector<NodeOut> outputs;
  outputs.reserve(num_loop_inputs_);
  for (int i = 0; i < num_loop_inputs_; i++) {
    if (IsResource(i)) {
      // Note(srbs): A resource output of this While should never be used but we
      // need this for the IdentityN node below.
      OutputTensor resource_tensor;
      TF_RETURN_IF_ERROR(enter_nodes_[i]->input_tensor(0, &resource_tensor));
      outputs.emplace_back(resource_tensor);
    } else {
      // Exit nodes consume the false (index 0) output of the Switch node,
      // i.e. the loop-var value when the loop condition turns false.
      Node* exit_node;
      TF_RETURN_IF_ERROR(
          NodeBuilder(NewName("exit"), "Exit", graph_->op_registry(),
                      &debug_info_)
              .Input(NodeOut(switch_nodes_[op_input_output_to_lowered_node_[i]],
                             0))
              .Device(switch_nodes_[op_input_output_to_lowered_node_[i]]
                          ->requested_device())
              .AssignedDevice(switch_nodes_[op_input_output_to_lowered_node_[i]]
                                  ->assigned_device_name())
              .Finalize(graph_, &exit_node));
      exit_nodes_.emplace_back(exit_node);
      outputs.emplace_back(NodeOut(exit_node, 0));
    }
  }

  // We split data and control outputs of lowered while op, because otherwise
  // after lowering of multi-device loop body we might end up with DT_RESOURCE
  // inputs from multiple devices coming into IdentityN.

  // Add a NoOp node that has control edges from all Exit nodes. This node is
  // used for rewriting control edges with the original while op as src.
  TF_RETURN_IF_ERROR(NodeBuilder(NewName("LoopExecuted"), "NoOp",
                                 OpRegistry::Global(), &debug_info_)
                         .ControlInputs(exit_nodes_)
                         .Device(while_op_->requested_device())
                         .Finalize(graph_, &lowered_while_executed_));

  if (keep_node_fetchable_) {
    // Add an IdentityN node that has the same outputs and same name as the
    // original functional While op. This is used for fetching the output of the
    // While node by name in calls to sess.run.
    TF_RETURN_IF_ERROR(
        NodeBuilder(name_, "IdentityN", OpRegistry::Global(), &debug_info_)
            .Input(outputs)
            .Device(while_op_->requested_device())
            .Finalize(graph_, &lowered_while_output_));
  } else {
    // Even if we don't plan to fetch tensors from the lowered While op, we must
    // keep it a valid source of control edges, because it might be a part of
    // function control output set.
    TF_RETURN_IF_ERROR(
        NodeBuilder(name_, "NoOp", OpRegistry::Global(), &debug_info_)
            .ControlInput(lowered_while_executed_)
            .Device(while_op_->requested_device())
            .Finalize(graph_, &lowered_while_output_));
  }

  return Status::OK();
}
440 
CreateNextIterationNodes()441 Status LowerWhileHelper::CreateNextIterationNodes() {
442   for (int i = 0; i < num_loop_inputs_; i++) {
443     Node* next_iteration;
444     if (IsResource(i)) {
445       continue;
446     }
447     Node* merge_node = merge_nodes_[op_input_output_to_lowered_node_[i]];
448     TF_RETURN_IF_ERROR(NodeBuilder(NewName("next_iteration"), "NextIteration",
449                                    graph_->op_registry(), &debug_info_)
450                            .Input(NodeOut(body_call_node_, i))
451                            .ControlInput(body_call_node_)
452                            .Device(merge_node->requested_device())
453                            .AssignedDevice(merge_node->assigned_device_name())
454                            .Finalize(graph_, &next_iteration));
455     next_iterations_nodes_.emplace_back(next_iteration);
456   }
457   return Status::OK();
458 }
459 
UpdateMergeNodes()460 Status LowerWhileHelper::UpdateMergeNodes() {
461   for (int i = 0; i < merge_nodes_.size(); i++) {
462     TF_RETURN_IF_ERROR(
463         graph_->UpdateEdge(next_iterations_nodes_[i], 0, merge_nodes_[i], 1));
464   }
465   return Status::OK();
466 }
467 
UpdateConsumers()468 Status LowerWhileHelper::UpdateConsumers() {
469   for (const Edge* e : while_op_->out_edges()) {
470     if (e->IsControlEdge()) {
471       graph_->AddControlEdge(lowered_while_executed_, e->dst());
472     } else {
473       if (IsResource(e->src_output())) {
474         OutputTensor resource;
475         TF_RETURN_IF_ERROR(
476             enter_nodes_[e->src_output()]->input_tensor(0, &resource));
477         graph_->AddEdge(resource.node, resource.index, e->dst(),
478                         e->dst_input());
479       } else {
480         // Feed the outputs directly from the exit nodes so that downstream ops
481         // can start before all the outputs have been computed.
482         int exit_node_index = op_input_output_to_lowered_node_[e->src_output()];
483         if (exit_node_index < 0) {
484           return errors::Internal(
485               "Expecting an Exit node for a Resource tensor.");
486         }
487         graph_->AddEdge(exit_nodes_[exit_node_index], 0, e->dst(),
488                         e->dst_input());
489       }
490     }
491   }
492   return Status::OK();
493 }
494 
// Builds "<while_name>/<infix>" and lets the graph uniquify it with a suffix.
string LowerWhileHelper::NewName(const string& infix) {
  string candidate = strings::StrCat(name_, "/", infix);
  return graph_->NewName(candidate);
}
498 
IsResource(int index)499 bool LowerWhileHelper::IsResource(int index) {
500   return while_op_->input_type(index) == DT_RESOURCE;
501 }
502 
503 }  // namespace
504 
RewriteWhileNode(Node * n,Graph * g,bool keep_node_fetchable)505 Status RewriteWhileNode(Node* n, Graph* g,
506                         bool keep_node_fetchable) {
507   VLOG(2) << "Lower While node (keep_node_fetchable=" << keep_node_fetchable
508           << "): " << SummarizeNode(*n);
509 
510   const AttrValue* cond_attr = n->attrs().Find("cond");
511   if (cond_attr == nullptr) {
512     return errors::InvalidArgument("While cond function missing");
513   }
514   const AttrValue* body_attr = n->attrs().Find("body");
515   if (body_attr == nullptr) {
516     return errors::InvalidArgument("While body function missing");
517   }
518   const AttrValue* parallel_iterations_attr =
519       n->attrs().Find("parallel_iterations");
520   if (parallel_iterations_attr == nullptr) {
521     return errors::InvalidArgument("parallel_iterations attr missing");
522   }
523 
524   TF_RETURN_IF_ERROR(LowerWhileHelper::Run(
525       n, cond_attr->func(), body_attr->func(), parallel_iterations_attr->i(), g,
526       keep_node_fetchable));
527   g->RemoveNode(n);
528 
529   return Status::OK();
530 }
531 
532 }  // namespace tensorflow
533