/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/lower_while_op.h"

#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"

namespace tensorflow {

namespace {

using NodeOut = NodeBuilder::NodeOut;

constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
    LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;

// Helper to convert a functional While op to its lowered form.
//
// Example:
//
// Input graph:
//
// loop_var -> WhileOp<cond_func, body_func> -> consumer
//
// Output graph (top-to-bottom flow):
//
//                 loop_var
//                     |
//                   Enter
//                     |
// cond_func ---<--- Merge ---<--- NextIteration
//     |               |                 |
//     V               V                 ^
//     |               |                 |
// LoopCond --->--- Switch --->--- body_func
//                     |
//                   Exit
//                     |
//                 consumer
//
// DT_RESOURCE tensors are handled specially:
//
// resource_loop_var -> Enter[is_constant=True] -> cond_func and body_func
//         |
//         V
//     consumer
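//
// For reference, the functional While op being lowered carries its loop
// functions as function-typed attributes. A rough sketch of such a node in a
// GraphDef (names are illustrative; type and output-shape attributes elided):
//
//   node {
//     name: "while"
//     op: "While"
//     input: "loop_var"
//     attr { key: "cond" value { func { name: "cond_func" } } }
//     attr { key: "body" value { func { name: "body_func" } } }
//     attr { key: "parallel_iterations" value { i: 10 } }
//   }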
class LowerWhileHelper {
 public:
  static Status Run(Node* while_op, const NameAttrList& cond_fn,
                    const NameAttrList& body_fn, int parallel_iterations,
                    Graph* graph, bool keep_node_fetchable) {
    LowerWhileHelper helper(while_op, cond_fn, body_fn, parallel_iterations,
                            graph, keep_node_fetchable);
    return helper.RunInternal();
  }

 private:
  // Creates a LowerWhileHelper to lower a While op whose cond and body
  // functions are given by `cond_fn` and `body_fn` respectively, in the given
  // graph.
  LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn,
                   const NameAttrList& body_fn, int parallel_iterations,
                   Graph* graph, bool keep_node_fetchable);

  Status RunInternal();

  void InitializeInputOutputToLoweredNodeMap();

  // Creates an Enter node for each `while_op_` input and adds them to
  // `enter_nodes_`. If the `while_op_` has an incoming control edge from a
  // `src` node, we add a control edge from `src` to each Enter node.
  Status CreateEnterNodes();

  // Creates a Merge node for each Enter node and adds it to `merge_nodes_`.
  // Initially, both inputs of a Merge node are the Enter node. The input at
  // index 1 is later updated to the output of the NextIteration node in
  // `UpdateMergeNodes`.
  Status CreateMergeNodes();

  // Creates the call node for the cond func and stores it in
  // `cond_call_node_`.
  Status CreateCondFuncCallNode();

  // Creates a Switch node for each loop var and adds it to `switch_nodes_`.
  // Output at index 1 (true) of a Switch node is fed into the loop body.
  // Output at index 0 (false) of a Switch node is fed into the Exit nodes.
  Status CreateSwitchNodes();

  // Creates the call node for the body func and stores it in
  // `body_call_node_`.
  Status CreateBodyFuncCallNode();

  // Creates an Exit node for each loop var and adds it to `exit_nodes_`.
  // These are fed into the consumers of the `while_op_`.
  Status CreateExitNodes();

  // Creates a NextIteration node for each loop var and adds it to
  // `next_iterations_nodes_`.
  Status CreateNextIterationNodes();

  // Updates the input at index 1 of each Merge node created in
  // `CreateMergeNodes` to use the output of the NextIteration node created in
  // `CreateNextIterationNodes` instead.
  Status UpdateMergeNodes();

  // Updates consumers of the original `while_op_` to instead use the outputs
  // from the Exit nodes in `exit_nodes_`. Also updates any outgoing control
  // edges to depend on `lowered_while_executed_` instead.
  Status UpdateConsumers();

  // Returns a unique name containing the name of the While op being rewritten
  // (`name_`), the given infix, and a suffix to ensure it is unique within the
  // graph.
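  // For example, for a While op named "while", `NewName("enter")` returns a
  // graph-unique name like "while/enter/_3" (the suffix comes from
  // `Graph::NewName`).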
  string NewName(const string& infix);

  // Returns whether the While op's input/output at `index` is a `DT_RESOURCE`.
  bool IsResource(int index);

  // The original While op.
  Node* while_op_;
  // The call node for the cond branch.
  Node* cond_call_node_;
  // The LoopCond node specifying the loop termination condition.
  Node* loop_cond_node_;
  // The call node for the body branch.
  Node* body_call_node_;
  // The node with the same name as the original While op:
  // (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true'.
  // (b) NoOp node with control edge from 'lowered_while_executed_' otherwise.
  Node* lowered_while_output_;
  // The NoOp node with control edges from all Exit nodes. This node is used
  // as the source of outgoing control edges from the lowered While node.
  Node* lowered_while_executed_;
  Graph* graph_;
  // Name of the `while_op_`.
  string name_;
  // Max number of parallel iterations for the while loop.
  const int parallel_iterations_;
  bool keep_node_fetchable_;

  NodeDebugInfo debug_info_;
  NodeBuilder cond_call_builder_;
  NodeBuilder body_call_builder_;

  // `Enter` nodes, one per loop input/output.
  // Note: `Enter` nodes with type `DT_RESOURCE` have attr `is_constant=True`.
  std::vector<Node*> enter_nodes_;

  // Merge/Switch/NextIteration/Exit nodes, one per non-resource loop
  // input/output.
  std::vector<Node*> merge_nodes_;
  std::vector<Node*> switch_nodes_;
  std::vector<Node*> exit_nodes_;
  std::vector<Node*> next_iterations_nodes_;
  // Maps from the loop input/output indices to the indices of their
  // corresponding Merge/Switch/NextIteration/Exit nodes. Inputs/outputs of
  // `DT_RESOURCE` type have no Merge/Switch/NextIteration/Exit nodes, in
  // which case the mapping contains -1.
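  // For example, with input types {DT_FLOAT, DT_RESOURCE, DT_INT32} this
  // vector is {0, -1, 1}.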
  std::vector<int> op_input_output_to_lowered_node_;

  size_t num_loop_inputs_;
};

LowerWhileHelper::LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn,
                                   const NameAttrList& body_fn,
                                   int parallel_iterations, Graph* graph,
                                   bool keep_node_fetchable)
    : while_op_(while_op),
      graph_(graph),
      name_(while_op->name()),
      parallel_iterations_(parallel_iterations),
      keep_node_fetchable_(keep_node_fetchable),
      debug_info_(*while_op_),
      cond_call_builder_(NewName("cond"), cond_fn.name(), graph->op_registry(),
                         &debug_info_),
      body_call_builder_(NewName("body"), body_fn.name(), graph->op_registry(),
                         &debug_info_),
      num_loop_inputs_(while_op_->num_inputs()) {
  cond_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true);
  for (const auto& i : cond_fn.attr()) {
    cond_call_builder_.Attr(i.first, i.second);
  }
  body_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true);
  for (const auto& i : body_fn.attr()) {
    body_call_builder_.Attr(i.first, i.second);
  }
  // We intentionally `resize` instead of `reserve` space in `enter_nodes_`
  // because we need to set its elements out of order in `CreateEnterNodes`.
  enter_nodes_.resize(num_loop_inputs_);
  merge_nodes_.reserve(num_loop_inputs_);
  switch_nodes_.reserve(num_loop_inputs_);
  exit_nodes_.reserve(num_loop_inputs_);
  next_iterations_nodes_.reserve(num_loop_inputs_);
  op_input_output_to_lowered_node_.resize(num_loop_inputs_, -1);
}

Status LowerWhileHelper::RunInternal() {
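  // Note on ordering: each step below consumes nodes created by earlier steps
  // (the cond call reads the Merge nodes, the Switch nodes read the LoopCond
  // node, the body call reads the Switch nodes, and the NextIteration nodes
  // read the body call's outputs), so the calls must run in this sequence.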
  InitializeInputOutputToLoweredNodeMap();
  TF_RETURN_IF_ERROR(CreateEnterNodes());
  TF_RETURN_IF_ERROR(CreateMergeNodes());
  TF_RETURN_IF_ERROR(CreateCondFuncCallNode());
  TF_RETURN_IF_ERROR(CreateSwitchNodes());
  TF_RETURN_IF_ERROR(CreateBodyFuncCallNode());
  TF_RETURN_IF_ERROR(CreateExitNodes());
  TF_RETURN_IF_ERROR(CreateNextIterationNodes());
  TF_RETURN_IF_ERROR(UpdateMergeNodes());
  TF_RETURN_IF_ERROR(UpdateConsumers());
  return Status::OK();
}

void LowerWhileHelper::InitializeInputOutputToLoweredNodeMap() {
  int counter = 0;
  for (int i = 0; i < num_loop_inputs_; i++) {
    if (!IsResource(i)) {
      op_input_output_to_lowered_node_[i] = counter++;
    }
  }
}

Status LowerWhileHelper::CreateEnterNodes() {
  // Note: `Node::input_edge` runs in O(num_inputs), so we use
  // `Node::input_edges` instead so that the loop below runs in O(num_inputs)
  // time rather than O(num_inputs^2).
  std::vector<const Edge*> edges;
  TF_RETURN_IF_ERROR(while_op_->input_edges(&edges));
  for (const Edge* edge : edges) {
    Node* enter_node;
    NodeBuilder builder =
        NodeBuilder(NewName("enter"), "Enter", graph_->op_registry(),
                    &debug_info_)
            .Input(NodeOut(edge->src(), edge->src_output()))
            .Attr("frame_name", name_)
            .Attr("parallel_iterations", parallel_iterations_)
            .Device(edge->src()->requested_device())
            .AssignedDevice(edge->src()->assigned_device_name());
    if (IsResource(edge->dst_input())) {
      builder.Attr("is_constant", true);
    }
    TF_RETURN_IF_ERROR(builder.Finalize(graph_, &enter_node));
    enter_nodes_[edge->dst_input()] = enter_node;
  }
  // Create a NoOp node that takes the incoming control inputs of the original
  // While op as control inputs, and use it as a control input for all Enter
  // nodes.
  std::vector<Node*> control_inputs;
  for (const Edge* e : while_op_->in_edges()) {
    if (e->IsControlEdge()) {
      control_inputs.push_back(e->src());
    }
  }
  if (!control_inputs.empty()) {
    Node* incoming_control_node;
    TF_RETURN_IF_ERROR(NodeBuilder(NewName("LoopControlInputs"), "NoOp",
                                   graph_->op_registry(), &debug_info_)
                           .ControlInputs(control_inputs)
                           .Device(while_op_->requested_device())
                           .Finalize(graph_, &incoming_control_node));
    for (Node* n : enter_nodes_) {
      graph_->AddControlEdge(incoming_control_node, n);
    }
  }
  return Status::OK();
}

Status LowerWhileHelper::CreateMergeNodes() {
  for (Node* enter_node : enter_nodes_) {
    if (enter_node->output_type(0) == DT_RESOURCE) {
      continue;
    }
    Node* merge_node;
    TF_RETURN_IF_ERROR(
        NodeBuilder(NewName("merge"), "Merge", graph_->op_registry(),
                    &debug_info_)
            .Input({NodeOut(enter_node, 0), NodeOut(enter_node, 0)})
            .Device(enter_node->requested_device())
            .AssignedDevice(enter_node->assigned_device_name())
            .Finalize(graph_, &merge_node));
    merge_nodes_.emplace_back(merge_node);
  }
  return Status::OK();
}

Status LowerWhileHelper::CreateCondFuncCallNode() {
  for (int i = 0; i < num_loop_inputs_; i++) {
    if (IsResource(i)) {
      cond_call_builder_.Input(NodeOut(enter_nodes_[i], 0));
    } else {
      cond_call_builder_.Input(
          NodeOut(merge_nodes_[op_input_output_to_lowered_node_[i]], 0));
    }
  }
  cond_call_builder_.Device(while_op_->requested_device());
  TF_RETURN_IF_ERROR(cond_call_builder_.Finalize(graph_, &cond_call_node_));
  // Add a control edge to make sure the Const nodes in the cond function
  // are in the same frame as the rest of the function, otherwise
  // `BuildControlFlowInfo` throws an error.
  graph_->AddControlEdge(merge_nodes_[0], cond_call_node_);
  TF_RETURN_IF_ERROR(NodeBuilder(NewName("LoopCond"), "LoopCond",
                                 graph_->op_registry(), &debug_info_)
                         .Input(NodeOut(cond_call_node_, 0))
                         .Device(while_op_->requested_device())
                         .Finalize(graph_, &loop_cond_node_));
  return Status::OK();
}

Status LowerWhileHelper::CreateSwitchNodes() {
  for (int i = 0; i < num_loop_inputs_; i++) {
    if (IsResource(i)) {
      continue;
    }
    string op_name;
    {
      const Node* input_node;
      TF_RETURN_IF_ERROR(while_op_->input_node(i, &input_node));
      op_name = strings::StrCat(input_node->name(), "_switch");
    }
    Node* merge_node = merge_nodes_[op_input_output_to_lowered_node_[i]];
    Node* switch_node;
    string op_type = "Switch";
    if (IsRefType(merge_node->output_type(0))) {
      op_type = "RefSwitch";
    }
    TF_RETURN_IF_ERROR(NodeBuilder(NewName(op_name), op_type,
                                   graph_->op_registry(), &debug_info_)
                           .Input(NodeOut(merge_node, 0))
                           .Input(NodeOut(loop_cond_node_, 0))
                           .Device(merge_node->requested_device())
                           .AssignedDevice(merge_node->assigned_device_name())
                           .Finalize(graph_, &switch_node));
    switch_nodes_.emplace_back(switch_node);
  }
  return Status::OK();
}

Status LowerWhileHelper::CreateBodyFuncCallNode() {
  for (int i = 0; i < num_loop_inputs_; i++) {
    if (IsResource(i)) {
      body_call_builder_.Input(NodeOut(enter_nodes_[i], 0));
    } else {
      body_call_builder_.Input(
          NodeOut(switch_nodes_[op_input_output_to_lowered_node_[i]], 1));
    }
  }
  body_call_builder_.Device(while_op_->requested_device());
  TF_RETURN_IF_ERROR(body_call_builder_.Finalize(graph_, &body_call_node_));
  // Add a control edge to make sure the Const nodes in the body function
  // are in the same frame as the rest of the function, otherwise
  // `BuildControlFlowInfo` throws an error.
  // TODO(srbs): The choice of input at index 0 seems arbitrary (is it?);
  // however, this is how tf.while_loop does it. Can this affect performance if
  // the 0th node is not the first one to be ready? Can we speed that case up
  // using some sort of multi-input Merge?
  Node* body_control_node;
  string op_type = "Identity";
  if (IsRefType(switch_nodes_[0]->output_type(1))) {
    op_type = "RefIdentity";
  }
  TF_RETURN_IF_ERROR(NodeBuilder(NewName("loop_body_control"), op_type,
                                 graph_->op_registry(), &debug_info_)
                         .Input(NodeOut(switch_nodes_[0], 1))
                         .Device(while_op_->requested_device())
                         .Finalize(graph_, &body_control_node));
  graph_->AddControlEdge(body_control_node, body_call_node_);
  return Status::OK();
}

Status LowerWhileHelper::CreateExitNodes() {
  std::vector<NodeOut> outputs;
  outputs.reserve(num_loop_inputs_);
  for (int i = 0; i < num_loop_inputs_; i++) {
    if (IsResource(i)) {
      // Note(srbs): A resource output of this While should never be used, but
      // we need this for the IdentityN node below.
      OutputTensor resource_tensor;
      TF_RETURN_IF_ERROR(enter_nodes_[i]->input_tensor(0, &resource_tensor));
      outputs.emplace_back(resource_tensor);
    } else {
      Node* exit_node;
      TF_RETURN_IF_ERROR(
          NodeBuilder(NewName("exit"), "Exit", graph_->op_registry(),
                      &debug_info_)
              .Input(NodeOut(switch_nodes_[op_input_output_to_lowered_node_[i]],
                             0))
              .Device(switch_nodes_[op_input_output_to_lowered_node_[i]]
                          ->requested_device())
              .AssignedDevice(switch_nodes_[op_input_output_to_lowered_node_[i]]
                                  ->assigned_device_name())
              .Finalize(graph_, &exit_node));
      exit_nodes_.emplace_back(exit_node);
      outputs.emplace_back(NodeOut(exit_node, 0));
    }
  }

  // We split the data and control outputs of the lowered While op, because
  // otherwise, after lowering of a multi-device loop body, we might end up
  // with DT_RESOURCE inputs from multiple devices coming into IdentityN.

  // Add a NoOp node that has control edges from all Exit nodes. This node is
  // used for rewriting control edges with the original While op as src.
  TF_RETURN_IF_ERROR(NodeBuilder(NewName("LoopExecuted"), "NoOp",
                                 OpRegistry::Global(), &debug_info_)
                         .ControlInputs(exit_nodes_)
                         .Device(while_op_->requested_device())
                         .Finalize(graph_, &lowered_while_executed_));

  if (keep_node_fetchable_) {
    // Add an IdentityN node that has the same outputs and the same name as the
    // original functional While op. This is used for fetching the output of
    // the While node by name in calls to sess.run.
    TF_RETURN_IF_ERROR(
        NodeBuilder(name_, "IdentityN", OpRegistry::Global(), &debug_info_)
            .Input(outputs)
            .Device(while_op_->requested_device())
            .Finalize(graph_, &lowered_while_output_));
  } else {
    // Even if we don't plan to fetch tensors from the lowered While op, we
    // must keep it a valid source of control edges, because it might be a part
    // of the function's control output set.
    TF_RETURN_IF_ERROR(
        NodeBuilder(name_, "NoOp", OpRegistry::Global(), &debug_info_)
            .ControlInput(lowered_while_executed_)
            .Device(while_op_->requested_device())
            .Finalize(graph_, &lowered_while_output_));
  }

  return Status::OK();
}

Status LowerWhileHelper::CreateNextIterationNodes() {
  for (int i = 0; i < num_loop_inputs_; i++) {
    if (IsResource(i)) {
      continue;
    }
    Node* next_iteration;
    Node* merge_node = merge_nodes_[op_input_output_to_lowered_node_[i]];
    TF_RETURN_IF_ERROR(NodeBuilder(NewName("next_iteration"), "NextIteration",
                                   graph_->op_registry(), &debug_info_)
                           .Input(NodeOut(body_call_node_, i))
                           .ControlInput(body_call_node_)
                           .Device(merge_node->requested_device())
                           .AssignedDevice(merge_node->assigned_device_name())
                           .Finalize(graph_, &next_iteration));
    next_iterations_nodes_.emplace_back(next_iteration);
  }
  return Status::OK();
}

Status LowerWhileHelper::UpdateMergeNodes() {
  for (int i = 0; i < merge_nodes_.size(); i++) {
    TF_RETURN_IF_ERROR(
        graph_->UpdateEdge(next_iterations_nodes_[i], 0, merge_nodes_[i], 1));
  }
  return Status::OK();
}

Status LowerWhileHelper::UpdateConsumers() {
  for (const Edge* e : while_op_->out_edges()) {
    if (e->IsControlEdge()) {
      graph_->AddControlEdge(lowered_while_executed_, e->dst());
    } else {
      if (IsResource(e->src_output())) {
        OutputTensor resource;
        TF_RETURN_IF_ERROR(
            enter_nodes_[e->src_output()]->input_tensor(0, &resource));
        graph_->AddEdge(resource.node, resource.index, e->dst(),
                        e->dst_input());
      } else {
        // Feed the outputs directly from the exit nodes so that downstream ops
        // can start before all the outputs have been computed.
        int exit_node_index =
            op_input_output_to_lowered_node_[e->src_output()];
        if (exit_node_index < 0) {
          return errors::Internal(
              "Expecting an Exit node for a non-Resource tensor.");
        }
        graph_->AddEdge(exit_nodes_[exit_node_index], 0, e->dst(),
                        e->dst_input());
      }
    }
  }
  return Status::OK();
}

string LowerWhileHelper::NewName(const string& infix) {
  return graph_->NewName(strings::StrCat(name_, "/", infix));
}

bool LowerWhileHelper::IsResource(int index) {
  return while_op_->input_type(index) == DT_RESOURCE;
}

}  // namespace

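// Example usage (a sketch, not the actual caller): a lowering pass would
// collect candidate While nodes before rewriting, because `RewriteWhileNode`
// removes the node from the graph:
//
//   std::vector<Node*> while_nodes;
//   for (Node* n : graph->op_nodes()) {
//     if (n->type_string() == "While") while_nodes.push_back(n);
//   }
//   for (Node* n : while_nodes) {
//     TF_RETURN_IF_ERROR(
//         RewriteWhileNode(n, graph, /*keep_node_fetchable=*/true));
//   }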
Status RewriteWhileNode(Node* n, Graph* g, bool keep_node_fetchable) {
  VLOG(2) << "Lower While node (keep_node_fetchable=" << keep_node_fetchable
          << "): " << SummarizeNode(*n);

  const AttrValue* cond_attr = n->attrs().Find("cond");
  if (cond_attr == nullptr) {
    return errors::InvalidArgument("While cond function missing");
  }
  const AttrValue* body_attr = n->attrs().Find("body");
  if (body_attr == nullptr) {
    return errors::InvalidArgument("While body function missing");
  }
  const AttrValue* parallel_iterations_attr =
      n->attrs().Find("parallel_iterations");
  if (parallel_iterations_attr == nullptr) {
    return errors::InvalidArgument("parallel_iterations attr missing");
  }

  TF_RETURN_IF_ERROR(LowerWhileHelper::Run(
      n, cond_attr->func(), body_attr->func(), parallel_iterations_attr->i(),
      g, keep_node_fetchable));
  g->RemoveNode(n);

  return Status::OK();
}

}  // namespace tensorflow