1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/functionalize_while.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <stack>
21 #include <unordered_set>
22 #include <vector>
23
24 #include "absl/memory/memory.h"
25 #include "absl/strings/match.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
28 #include "tensorflow/compiler/tf2xla/functionalize_cond.h"
29 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/union_find.h"
32 #include "tensorflow/core/common_runtime/function.h"
33 #include "tensorflow/core/framework/graph_to_functiondef.h"
34 #include "tensorflow/core/framework/node_def_builder.h"
35 #include "tensorflow/core/graph/algorithm.h"
36 #include "tensorflow/core/graph/control_flow.h"
37 #include "tensorflow/core/graph/node_builder.h"
38 #include "tensorflow/core/lib/strings/strcat.h"
39 #include "tensorflow/core/util/dump_graph.h"
40
41 namespace tensorflow {
42 namespace {
43
44 using xla::StatusOr;
45
46 // Copies a subgraph from `graph` to `output` by performing a reverse DFS
47 // starting at nodes in vector `stack`.
48 // `node_map` is a vector indexed by source node ID to dest nodes.
49 // Does not traverse into nodes in `node_map`, so by adding nodes to `node_map`
50 // before the traversal clients can cut the graph. If a frame is provided (frame
51 // != nullptr), then this functions will return an error if the
52 // traversal leaves 'frame'; the client must add enough nodes to `node_map` to
53 // cut the graph and prevent the traversal from escaping.
54 //
55 // `squash_src_outputs` contains a bool for each source node ID. If true, then
56 // the source output on that node will be replaced by zero when copied. This is
57 // used when replacing a Switch node with an _Arg node. The output we are
58 // taking from the Switch node was not necessarily the first output, but _Arg
59 // nodes only have one output. By adding the Switch node to `squash_src_outputs`
60 // we rewrite the src_output of the corresponding edge to be 0.
CopySubgraph(const Graph & graph,const WhileLoopFrame * frame,std::vector<Node * > stack,const std::vector<bool> & squash_src_outputs,std::vector<Node * > * node_map,Graph * output)61 Status CopySubgraph(const Graph& graph, const WhileLoopFrame* frame,
62 std::vector<Node*> stack,
63 const std::vector<bool>& squash_src_outputs,
64 std::vector<Node*>* node_map, Graph* output) {
65 VLOG(3) << "Stack: " << NodesToString(stack);
66 std::vector<bool> visited(graph.num_node_ids(), false);
67 while (!stack.empty()) {
68 Node* n = stack.back();
69 stack.pop_back();
70
71 VLOG(5) << "Copying node " << n->name();
72
73 if (visited[n->id()]) continue;
74 visited[n->id()] = true;
75
76 // Sort "n->in_edges()" to make sure nodes are copied in a deterministic
77 // order.
78 std::vector<const Edge*> sorted_edges(n->in_edges().begin(),
79 n->in_edges().end());
80 std::sort(sorted_edges.begin(), sorted_edges.end(),
81 [](const Edge* a, const Edge* b) {
82 int a_src_output = a->src_output(),
83 b_src_output = b->src_output();
84 StringPiece a_name(a->src()->name()), b_name(b->src()->name());
85 return std::tie(a_src_output, a_name) <
86 std::tie(b_src_output, b_name);
87 });
88 for (const Edge* e : sorted_edges) {
89 Node* src = e->src();
90 if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) {
91 // We traversed out of the loop frame, without encountering a cut node.
92 return errors::Internal("Graph traversal of loop frame ", frame->name,
93 " escaped frame at ", src->name(),
94 " without encountering an argument node.");
95 }
96 if ((*node_map)[src->id()] == nullptr) {
97 (*node_map)[src->id()] = output->CopyNode(src);
98 stack.push_back(src);
99 }
100 Node* src_copy = (*node_map)[e->src()->id()];
101 int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge()
102 ? 0
103 : e->src_output();
104 Node* dst_copy = (*node_map)[e->dst()->id()];
105 output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
106 }
107 }
108 return Status::OK();
109 }
110
BuildArgNode(Graph * graph,DataType type,int index)111 StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
112 const char* const kArgOp = "_Arg";
113 NodeDef arg_def;
114 NodeDefBuilder builder(absl::StrCat(kArgOp, index), kArgOp);
115 builder.Attr("T", type);
116 builder.Attr("index", index);
117 TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
118 return AddNodeDefToGraph(arg_def, graph);
119 }
120
121 // Builds a graph for the loop condition.
BuildLoopCondition(const Graph & graph,WhileLoopFrame * frame,std::unique_ptr<Graph> * cond_output)122 Status BuildLoopCondition(const Graph& graph, WhileLoopFrame* frame,
123 std::unique_ptr<Graph>* cond_output) {
124 VLOG(2) << "Building loop condition for " << frame->name;
125 *cond_output = absl::make_unique<Graph>(graph.op_registry());
126 Graph* output = cond_output->get();
127
128 // Map from nodes in the original graph to the condition graph.
129 std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
130 std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
131
132 // Build one _Arg node for each Enter node.
133 for (int i = 0, end = frame->args.size(); i < end; ++i) {
134 const WhileLoopArg& arg = frame->args[i];
135
136 TF_ASSIGN_OR_RETURN(Node * arg_node,
137 BuildArgNode(output, arg.enter->input_type(0), i));
138 if (arg.is_loop_invariant) {
139 node_map[arg.enter->id()] = arg_node;
140 } else {
141 node_map[arg.merge->id()] = arg_node;
142 }
143 }
144
145 // Build a Retval node for the loop condition. The LoopCond nodes are always
146 // boolean because of the type constraints on the LoopCond op.
147 TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()],
148 BuildRetvalNode(output, DT_BOOL, 0));
149
150 // Performs a reverse DFS, copying nodes and edges to the output graph.
151 // The _Arg and _Retval nodes were added unconditionally above, so we are
152 // guaranteed to get the correct function signature.
153 return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs,
154 &node_map, output);
155 }
156
157 // Builds a graph for the loop body.
BuildLoopBody(const Graph & graph,WhileLoopFrame * frame,DataTypeVector * arg_types,std::unique_ptr<Graph> * body_output)158 Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame,
159 DataTypeVector* arg_types,
160 std::unique_ptr<Graph>* body_output) {
161 VLOG(2) << "Building loop body for " << frame->name;
162 *body_output = absl::make_unique<Graph>(graph.op_registry());
163 Graph* output = body_output->get();
164
165 // Map from nodes in the original graph to the body graph.
166 std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
167 std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
168
169 // Build one _Arg node for each Enter node.
170 std::vector<Node*> next_iterations;
171 next_iterations.reserve(frame->args.size());
172 arg_types->reserve(frame->args.size());
173 for (int i = 0, end = frame->args.size(); i < end; ++i) {
174 const WhileLoopArg& arg = frame->args[i];
175
176 DataType dtype = arg.enter->input_type(0);
177 arg_types->push_back(dtype);
178
179 TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i));
180 TF_ASSIGN_OR_RETURN(Node * retval_node, BuildRetvalNode(output, dtype, i));
181 if (arg.is_loop_invariant) {
182 // Argument is loop-invariant. Forward it from the Arg to the Retval.
183 node_map[arg.enter->id()] = arg_node;
184 output->AddEdge(arg_node, 0, retval_node, 0);
185 } else {
186 // Argument is loop-varying.
187 if (dtype == DT_RESOURCE) {
188 // DT_RESOURCE arguments should always be loop-invariant in the graphs
189 // generated from TF.
190 return errors::Unimplemented("Loop-varying DT_RESOURCE Enter node ",
191 arg.enter->name(), " is currently not",
192 " supported.");
193 }
194 node_map[arg.switch_node->id()] = arg_node;
195 // The Switch node has two outputs, but _Arg only has one. This tells
196 // the CopySubgraph function to rewrite the output number of edges from
197 // the _Arg node to be 0 rather than copying the output number from the
198 // Switch node.
199 squash_src_outputs[arg.switch_node->id()] = true;
200 node_map[arg.next_iteration->id()] = retval_node;
201 next_iterations.push_back(arg.next_iteration);
202 }
203 }
204
205 // Performs a reverse DFS, copying nodes and edges to the output graph.
206 // The _Arg and _Retval nodes were added unconditionally above, so we are
207 // guaranteed to get the correct function signature.
208 TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations),
209 squash_src_outputs, &node_map, output));
210
211 return Status::OK();
212 }
213
FunctionalizeLoop(Graph * graph,WhileLoopFrame * frame,FunctionLibraryDefinition * library,const NodeFilter & node_filter)214 Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
215 FunctionLibraryDefinition* library,
216 const NodeFilter& node_filter) {
217 if (node_filter && !frame->should_be_functionalized) {
218 VLOG(2) << "Skipping functionalization for frame " << frame->name
219 << " because it has control flow nodes that are filtered out by "
220 "the specified node filter.";
221 return Status::OK();
222 }
223 VLOG(2) << "Frame " << frame->name << " before: "
224 << DumpGraphToFile("functionalize_before", *graph, library);
225
226 // Split loop-varying Enter nodes with multiple successors. If the same
227 // Tensor is fed as input to multiple loop arguments, we may end up with a
228 // shared Enter node. We clone Enter nodes with multiple successors to
229 // maintain the invariant of a unique Enter node per argument of the final
230 // loop.
231 std::vector<WhileLoopArg> args;
232 for (const WhileLoopArg& arg : frame->args) {
233 if (arg.is_loop_invariant) {
234 args.push_back(arg);
235 } else {
236 std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
237 arg.enter->out_edges().end());
238 for (int i = 0, end = edges.size(); i < end; ++i) {
239 if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) {
240 continue;
241 }
242 TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name();
243 WhileLoopArg new_arg;
244 new_arg.is_loop_invariant = false;
245 if (i == 0) {
246 new_arg.enter = arg.enter;
247 } else {
248 new_arg.enter = graph->CopyNode(arg.enter);
249 frame->nodes.insert(new_arg.enter);
250 for (Edge const* e : arg.enter->in_edges()) {
251 graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
252 e->IsControlEdge() ? Graph::kControlSlot : 0);
253 }
254 Node* dst = edges[i]->dst();
255 int dst_input = edges[i]->dst_input();
256 graph->RemoveEdge(edges[i]);
257 graph->AddEdge(new_arg.enter, 0, dst, dst_input);
258 }
259 args.push_back(new_arg);
260 }
261 }
262 }
263 frame->args = std::move(args);
264
265 std::sort(frame->args.begin(), frame->args.end(),
266 [](const WhileLoopArg& a, const WhileLoopArg& b) {
267 return NodeCmpByNameResourcesLast()(a.enter, b.enter);
268 });
269
270 if (frame->loop_cond == nullptr) {
271 return errors::InvalidArgument("Loop ", frame->name,
272 " has no LoopCond node");
273 }
274
275 // Find the set of Switch nodes that are successors of the LoopCond.
276 std::unordered_set<Node*> switches;
277 for (const Edge* edge : frame->loop_cond->out_edges()) {
278 if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
279 edge->dst_input() == 1) {
280 switches.insert(edge->dst());
281 }
282 }
283
284 // For each non-constant argument, looks for the following pattern of nodes:
285 // Enter ----> Merge --------> Switch --> Exit
286 // ^ ^
287 // | |
288 // NextIteration LoopCond
289 // ^ ^
290 // | |
291 // ... ...
292 for (WhileLoopArg& arg : frame->args) {
293 if (!arg.is_loop_invariant) {
294 // Follow the edge from the Enter to Merge.
295 const Edge* enter_merge = nullptr;
296 for (const Edge* e : arg.enter->out_edges()) {
297 // Ignore control-edges to the sink node. These are allowed by the
298 // graph invariants, although probably they should have been stripped
299 // off earlier.
300 if (e->IsControlEdge() && e->dst()->IsSink()) {
301 continue;
302 }
303 if (enter_merge != nullptr) {
304 return errors::Internal("Enter node for loop-varying argument ",
305 FormatNodeForError(*arg.enter),
306 " has multiple successors: ",
307 FormatNodeForError(*enter_merge->dst()),
308 " and ", FormatNodeForError(*e->dst()));
309 }
310 enter_merge = e;
311 }
312 if (enter_merge == nullptr) {
313 return errors::Internal("Enter node for loop-varying argument ",
314 FormatNodeForError(*arg.enter),
315 " has zero successors");
316 }
317 arg.merge = enter_merge->dst();
318 if (!IsMerge(arg.merge)) {
319 return errors::InvalidArgument(
320 "Successor of Enter node for loop-varying argument ",
321 FormatNodeForError(*arg.merge),
322 " is not a Merge node; got: ", arg.merge->type_string());
323 }
324
325 // Find the NextIteration from the merge. There should be two inputs to
326 // the Merge and the NextIteration should be the other input.
327 if (arg.merge->input_types().size() != 2) {
328 return errors::InvalidArgument(
329 "Unexpected number of inputs to Merge node for loop-varying "
330 "argument ",
331 FormatNodeForError(*arg.merge), "; expected 2, got ",
332 arg.merge->input_types().size());
333 }
334 TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
335 &arg.next_iteration));
336 if (!IsNextIteration(arg.next_iteration)) {
337 return errors::InvalidArgument(
338 "Expected NextIteration node as input to Merge node; got node ",
339 FormatNodeForError(*arg.next_iteration), " with kind ",
340 arg.next_iteration->type_string());
341 }
342
343 // Find the Switch successor of the Merge. There should be exactly one
344 // Switch node that is a successor of both the Merge and the LoopCond.
345 for (const Edge* edge : arg.merge->out_edges()) {
346 if (edge->dst_input() == 0 && IsSwitch(edge->dst()) &&
347 switches.find(edge->dst()) != switches.end()) {
348 if (arg.switch_node != nullptr) {
349 return errors::InvalidArgument("Duplicate Switch successors to ",
350 FormatNodeForError(*arg.merge));
351 }
352 arg.switch_node = edge->dst();
353 }
354 }
355 if (arg.switch_node == nullptr) {
356 return errors::InvalidArgument("Missing Switch successor to ",
357 FormatNodeForError(*arg.merge));
358 }
359 // Loop over the switch node's output to:
360 // - Find the Exit successor.
361 // - Set the sharding on all Identity outputs of the switch. These
362 // identity nodes are values used by the loop body or condition.
363 // The Identity node may have the wrong device so copy the device from
364 // one of its outputs instead.
365 std::deque<const Edge*> possible_exit;
366 for (const Edge* edge : arg.switch_node->out_edges()) {
367 if (edge->src_output() == 0) {
368 possible_exit.push_back(edge);
369 }
370 if (IsIdentity(edge->dst())) {
371 TF_RETURN_IF_ERROR(
372 SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
373 }
374 }
375 // TODO(b/67425339): Allow general graph between switch and exit.
376 while (!possible_exit.empty()) {
377 const Edge* edge = possible_exit.front();
378 possible_exit.pop_front();
379 if (IsExit(edge->dst())) {
380 if (arg.exit != nullptr) {
381 return errors::InvalidArgument(
382 "Duplicate Exit successors to ",
383 FormatNodeForError(*arg.switch_node));
384 }
385 arg.exit = edge->dst();
386 } else {
387 if (!IsIdentity(edge->dst())) {
388 return errors::Unimplemented("General graph between switch (",
389 FormatNodeForError(*arg.switch_node),
390 ") and exit node of frame ",
391 frame->name, " not supported yet.");
392 }
393 for (const Edge* out : edge->dst()->out_edges()) {
394 possible_exit.push_back(out);
395 }
396 }
397 }
398 }
399 }
400
401 // Builds the condition and body functions. Notice that we call
402 // FunctionalizeCond() on cond_graph and body_graph because we might have
403 // unfunctionalized "if" in cond_graph and body_graph. Functionalize them
404 // before they are encapsulated in FunctionDef.
405 std::unique_ptr<Graph> cond_graph;
406 TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
407 FixupSourceAndSinkEdges(cond_graph.get());
408 TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library, node_filter));
409 DataTypeVector arg_types;
410 std::unique_ptr<Graph> body_graph;
411 TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
412 FixupSourceAndSinkEdges(body_graph.get());
413 TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library, node_filter));
414
415 VLOG(2) << "Frame " << frame->name << " condition: "
416 << DumpGraphToFile("loop_condition", *cond_graph, library)
417 << " body: " << DumpGraphToFile("loop_body", *body_graph);
418
419 NameAttrList cond_name;
420 cond_name.set_name(library->UniqueFunctionName("_functionalize_cond_"));
421 NameAttrList body_name;
422 body_name.set_name(library->UniqueFunctionName("_functionalize_body_"));
423 FunctionDef cond_fdef;
424 TF_RETURN_IF_ERROR(
425 GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
426 FunctionDef body_fdef;
427 TF_RETURN_IF_ERROR(
428 GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));
429
430 TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
431 TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
432
433 // Builds a While operator.
434 NodeDef while_def;
435 NodeDefBuilder builder(frame->loop_cond->name(), "While", library);
436 builder.Attr("T", arg_types);
437 builder.Attr("cond", cond_name);
438 builder.Attr("body", body_name);
439 // Add some internal attributes which need to be propagated.
440 // TODO(b/160275126): attributes shouldn't be hard-coded here
441 for (const char* attr_name :
442 {kXlaFrontendAttributesAttrName, kXlaOutsideCompilationAttrName,
443 kTpuReplicateAttrName}) {
444 string attr_val;
445 if (GetNodeAttr(frame->loop_cond->def(), attr_name, &attr_val).ok()) {
446 builder.Attr(attr_name, attr_val);
447 }
448 }
449 std::vector<NodeDefBuilder::NodeOut> inputs;
450 for (int i = 0, end = frame->args.size(); i < end; ++i) {
451 const WhileLoopArg& arg = frame->args[i];
452 const Edge* in_edge;
453 TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
454 if (in_edge->IsControlEdge()) {
455 builder.ControlInput(in_edge->src()->name());
456 } else {
457 inputs.push_back(NodeDefBuilder::NodeOut(
458 in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
459 }
460 }
461 builder.Input(inputs);
462 TF_RETURN_IF_ERROR(builder.Finalize(&while_def));
463 TF_ASSIGN_OR_RETURN(Node * while_node, AddNodeDefToGraph(while_def, graph));
464
465 // Copies edges to the Enter nodes and from the Exit nodes onto the While.
466 for (int i = 0, end = frame->args.size(); i < end; ++i) {
467 const WhileLoopArg& arg = frame->args[i];
468 const Edge* in_edge;
469 TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
470 if (in_edge->IsControlEdge()) {
471 graph->AddControlEdge(in_edge->src(), while_node);
472 } else {
473 graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
474 }
475
476 if (!arg.is_loop_invariant) {
477 // Add output edges if the output of the loop is consumed.
478 if (arg.exit != nullptr) {
479 std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
480 arg.exit->out_edges().end());
481 for (const Edge* edge : edges) {
482 Node* dst = edge->dst();
483 int dst_input = edge->dst_input();
484 graph->RemoveEdge(edge);
485
486 if (dst_input == Graph::kControlSlot) {
487 graph->AddControlEdge(while_node, dst);
488 } else {
489 graph->AddEdge(while_node, i, dst, dst_input);
490 }
491 }
492 }
493 }
494 }
495
496 // Remove the old nodes from the graph, and add the while node to the parent
497 // frame.
498 for (Node* node : frame->nodes) {
499 VLOG(2) << "Removing obsolete node " << node->name();
500 graph->RemoveNode(node);
501 }
502 frame->nodes.clear();
503 frame->parent->nodes.insert(while_node);
504
505 VLOG(2) << "Frame " << frame->name << " after: "
506 << DumpGraphToFile("functionalize_after", *graph, library);
507
508 return Status::OK();
509 }
510 } // namespace
511
FunctionalizeWhileLoop(Graph * graph,FunctionLibraryDefinition * library,const NodeFilter & node_filter)512 Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library,
513 const NodeFilter& node_filter) {
514 // Note: BuildControlFlowInfo() requires that the graph's source node is
515 // connected to all source nodes in the graph. Many graphs violate this
516 // invariant.
517 std::vector<ControlFlowInfo> cf_info;
518 std::vector<string> unreachable_nodes;
519 TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes));
520 if (!unreachable_nodes.empty()) {
521 return errors::InvalidArgument(
522 "The following nodes are unreachable from the source in the graph: ",
523 errors::FormatNodeNamesForError(unreachable_nodes));
524 }
525
526 // Builds Frames, indexed by name.
527 std::unordered_map<string, WhileLoopFrame> frames;
528 TF_RETURN_IF_ERROR(
529 ExtractWhileLoopFrames(cf_info, graph, &frames, node_filter));
530
531 // Adds frames with no children (i.e., the innermost frames) to a worklist.
532 std::deque<WhileLoopFrame*> worklist;
533 for (auto& frame : frames) {
534 if (frame.second.num_children == 0) {
535 worklist.push_back(&frame.second);
536 }
537 }
538
539 // Eliminate loops from innermost to outermost. Note that the precondition for
540 // `node_filter` in `FunctionalizeControlFlow` makes sure that this approach
541 // works.
542 while (!worklist.empty()) {
543 WhileLoopFrame* frame = worklist.front();
544 worklist.pop_front();
545 if (frame->parent == frame) {
546 // Skip the root frame.
547 continue;
548 }
549
550 TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library, node_filter));
551
552 // If the parent has no remaining children, add it to the worklist.
553 --frame->parent->num_children;
554 if (frame->parent->num_children == 0) {
555 worklist.push_back(frame->parent);
556 }
557 }
558
559 if (!node_filter) {
560 // There should be no cycle at this point, since while loops have been
561 // removed from graph. Check that the newly added While nodes don't feed
562 // into themselves.
563 for (const Node* node : graph->op_nodes()) {
564 if (node->def().op() == "While") {
565 TF_RETURN_WITH_CONTEXT_IF_ERROR(
566 CheckNodeNotInCycle(node, graph->num_node_ids()),
567 "Functionalizing loop failed.");
568 }
569 }
570 }
571
572 return Status::OK();
573 }
574
575 } // namespace tensorflow
576