• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/functionalize_while.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <stack>
21 #include <unordered_set>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/match.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
28 #include "tensorflow/compiler/tf2xla/functionalize_cond.h"
29 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/union_find.h"
32 #include "tensorflow/core/common_runtime/function.h"
33 #include "tensorflow/core/framework/graph_to_functiondef.h"
34 #include "tensorflow/core/framework/node_def_builder.h"
35 #include "tensorflow/core/graph/algorithm.h"
36 #include "tensorflow/core/graph/control_flow.h"
37 #include "tensorflow/core/graph/node_builder.h"
38 #include "tensorflow/core/lib/strings/strcat.h"
39 #include "tensorflow/core/util/dump_graph.h"
40 
41 namespace tensorflow {
42 namespace {
43 
44 // Copies a subgraph from `graph` to `output` by performing a reverse DFS
45 // starting at nodes in vector `stack`.
46 // `node_map` is a vector indexed by source node ID to dest nodes.
47 // Does not traverse into nodes in `node_map`, so by adding nodes to `node_map`
48 // before the traversal clients can cut the graph. If a frame is provided (frame
49 // != nullptr), then this functions will return an error if the
50 // traversal leaves 'frame'; the client must add enough nodes to `node_map` to
51 // cut the graph and prevent the traversal from escaping.
52 //
53 // `squash_src_outputs` contains a bool for each source node ID. If true, then
54 // the source output on that node will be replaced by zero when copied. This is
55 // used when replacing a Switch node with an _Arg node. The output we are
56 // taking from the Switch node was not necessarily the first output, but _Arg
57 // nodes only have one output. By adding the Switch node to `squash_src_outputs`
58 // we rewrite the src_output of the corresponding edge to be 0.
CopySubgraph(const Graph & graph,const WhileLoopFrame * frame,std::vector<Node * > stack,const std::vector<bool> & squash_src_outputs,std::vector<Node * > * node_map,Graph * output)59 Status CopySubgraph(const Graph& graph, const WhileLoopFrame* frame,
60                     std::vector<Node*> stack,
61                     const std::vector<bool>& squash_src_outputs,
62                     std::vector<Node*>* node_map, Graph* output) {
63   VLOG(3) << "Stack: " << NodesToString(stack);
64   std::vector<bool> visited(graph.num_node_ids(), false);
65   while (!stack.empty()) {
66     Node* n = stack.back();
67     stack.pop_back();
68 
69     VLOG(5) << "Copying node " << n->name();
70 
71     if (visited[n->id()]) continue;
72     visited[n->id()] = true;
73 
74     // Sort "n->in_edges()" to make sure nodes are copied in a deterministic
75     // order.
76     std::vector<const Edge*> sorted_edges(n->in_edges().begin(),
77                                           n->in_edges().end());
78     std::sort(sorted_edges.begin(), sorted_edges.end(),
79               [](const Edge* a, const Edge* b) {
80                 int a_src_output = a->src_output(),
81                     b_src_output = b->src_output();
82                 StringPiece a_name(a->src()->name()), b_name(b->src()->name());
83                 return std::tie(a_src_output, a_name) <
84                        std::tie(b_src_output, b_name);
85               });
86     for (const Edge* e : sorted_edges) {
87       Node* src = e->src();
88       if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) {
89         // We traversed out of the loop frame, without encountering a cut node.
90         return errors::Internal("Graph traversal of loop frame ", frame->name,
91                                 " escaped frame at ", src->name(),
92                                 " without encountering an argument node.");
93       }
94       if ((*node_map)[src->id()] == nullptr) {
95         (*node_map)[src->id()] = output->CopyNode(src);
96         stack.push_back(src);
97       }
98       Node* src_copy = (*node_map)[e->src()->id()];
99       int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge()
100                            ? 0
101                            : e->src_output();
102       Node* dst_copy = (*node_map)[e->dst()->id()];
103       output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
104     }
105   }
106   return Status::OK();
107 }
108 
BuildArgNode(Graph * graph,DataType type,int index)109 StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
110   const char* const kArgOp = "_Arg";
111   NodeDef arg_def;
112   NodeDefBuilder builder(absl::StrCat(kArgOp, index), kArgOp);
113   builder.Attr("T", type);
114   builder.Attr("index", index);
115   TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
116   return AddNodeDefToGraph(arg_def, graph);
117 }
118 
119 // Builds a graph for the loop condition.
BuildLoopCondition(const Graph & graph,WhileLoopFrame * frame,std::unique_ptr<Graph> * cond_output)120 Status BuildLoopCondition(const Graph& graph, WhileLoopFrame* frame,
121                           std::unique_ptr<Graph>* cond_output) {
122   VLOG(2) << "Building loop condition for " << frame->name;
123   *cond_output = absl::make_unique<Graph>(graph.op_registry());
124   Graph* output = cond_output->get();
125 
126   // Map from nodes in the original graph to the condition graph.
127   std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
128   std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
129 
130   // Build one _Arg node for each Enter node.
131   for (int i = 0, end = frame->args.size(); i < end; ++i) {
132     const WhileLoopArg& arg = frame->args[i];
133 
134     TF_ASSIGN_OR_RETURN(Node * arg_node,
135                         BuildArgNode(output, arg.enter->input_type(0), i));
136     if (arg.is_loop_invariant) {
137       node_map[arg.enter->id()] = arg_node;
138     } else {
139       node_map[arg.merge->id()] = arg_node;
140     }
141   }
142 
143   // Build a Retval node for the loop condition. The LoopCond nodes are always
144   // boolean because of the type constraints on the LoopCond op.
145   TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()],
146                       BuildRetvalNode(output, DT_BOOL, 0));
147 
148   // Performs a reverse DFS, copying nodes and edges to the output graph.
149   // The _Arg and _Retval nodes were added unconditionally above, so we are
150   // guaranteed to get the correct function signature.
151   return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs,
152                       &node_map, output);
153 }
154 
155 // Builds a graph for the loop body.
BuildLoopBody(const Graph & graph,WhileLoopFrame * frame,DataTypeVector * arg_types,std::unique_ptr<Graph> * body_output)156 Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame,
157                      DataTypeVector* arg_types,
158                      std::unique_ptr<Graph>* body_output) {
159   VLOG(2) << "Building loop body for " << frame->name;
160   *body_output = absl::make_unique<Graph>(graph.op_registry());
161   Graph* output = body_output->get();
162 
163   // Map from nodes in the original graph to the body graph.
164   std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
165   std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
166 
167   // Build one _Arg node for each Enter node.
168   std::vector<Node*> next_iterations;
169   next_iterations.reserve(frame->args.size());
170   arg_types->reserve(frame->args.size());
171   for (int i = 0, end = frame->args.size(); i < end; ++i) {
172     const WhileLoopArg& arg = frame->args[i];
173 
174     DataType dtype = arg.enter->input_type(0);
175     arg_types->push_back(dtype);
176 
177     TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i));
178     TF_ASSIGN_OR_RETURN(Node * retval_node, BuildRetvalNode(output, dtype, i));
179     if (arg.is_loop_invariant) {
180       // Argument is loop-invariant. Forward it from the Arg to the Retval.
181       node_map[arg.enter->id()] = arg_node;
182       output->AddEdge(arg_node, 0, retval_node, 0);
183     } else {
184       // Argument is loop-varying.
185       if (dtype == DT_RESOURCE) {
186         // DT_RESOURCE arguments should always be loop-invariant in the graphs
187         // generated from TF.
188         return errors::Unimplemented("Loop-varying DT_RESOURCE Enter node ",
189                                      arg.enter->name(), " is currently not",
190                                      " supported.");
191       }
192       node_map[arg.switch_node->id()] = arg_node;
193       // The Switch node has two outputs, but _Arg only has one. This tells
194       // the CopySubgraph function to rewrite the output number of edges from
195       // the _Arg node to be 0 rather than copying the output number from the
196       // Switch node.
197       squash_src_outputs[arg.switch_node->id()] = true;
198       node_map[arg.next_iteration->id()] = retval_node;
199       next_iterations.push_back(arg.next_iteration);
200     }
201   }
202 
203   // Performs a reverse DFS, copying nodes and edges to the output graph.
204   // The _Arg and _Retval nodes were added unconditionally above, so we are
205   // guaranteed to get the correct function signature.
206   TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations),
207                                   squash_src_outputs, &node_map, output));
208 
209   return Status::OK();
210 }
211 
FunctionalizeLoop(Graph * graph,WhileLoopFrame * frame,FunctionLibraryDefinition * library,const NodeFilter & node_filter)212 Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
213                          FunctionLibraryDefinition* library,
214                          const NodeFilter& node_filter) {
215   if (node_filter && !frame->should_be_functionalized) {
216     VLOG(2) << "Skipping functionalization for frame " << frame->name
217             << " because it has control flow nodes that are filtered out by "
218                "the specified node filter.";
219     return Status::OK();
220   }
221   VLOG(2) << "Frame " << frame->name << " before: "
222           << DumpGraphToFile("functionalize_before", *graph, library);
223 
224   // Split loop-varying Enter nodes with multiple successors. If the same
225   // Tensor is fed as input to multiple loop arguments, we may end up with a
226   // shared Enter node. We clone Enter nodes with multiple successors to
227   // maintain the invariant of a unique Enter node per argument of the final
228   // loop.
229   std::vector<WhileLoopArg> args;
230   for (const WhileLoopArg& arg : frame->args) {
231     if (arg.is_loop_invariant) {
232       args.push_back(arg);
233     } else {
234       std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
235                                      arg.enter->out_edges().end());
236       for (int i = 0, end = edges.size(); i < end; ++i) {
237         if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) {
238           continue;
239         }
240         TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name();
241         WhileLoopArg new_arg;
242         new_arg.is_loop_invariant = false;
243         if (i == 0) {
244           new_arg.enter = arg.enter;
245         } else {
246           new_arg.enter = graph->CopyNode(arg.enter);
247           frame->nodes.insert(new_arg.enter);
248           for (Edge const* e : arg.enter->in_edges()) {
249             graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
250                            e->IsControlEdge() ? Graph::kControlSlot : 0);
251           }
252           Node* dst = edges[i]->dst();
253           int dst_input = edges[i]->dst_input();
254           graph->RemoveEdge(edges[i]);
255           graph->AddEdge(new_arg.enter, 0, dst, dst_input);
256         }
257         args.push_back(new_arg);
258       }
259     }
260   }
261   frame->args = std::move(args);
262 
263   std::sort(frame->args.begin(), frame->args.end(),
264             [](const WhileLoopArg& a, const WhileLoopArg& b) {
265               return NodeCmpByNameResourcesLast()(a.enter, b.enter);
266             });
267 
268   if (frame->loop_cond == nullptr) {
269     return errors::InvalidArgument("Loop ", frame->name,
270                                    " has no LoopCond node");
271   }
272 
273   // Find the set of Switch nodes that are successors of the LoopCond.
274   std::unordered_set<Node*> switches;
275   for (const Edge* edge : frame->loop_cond->out_edges()) {
276     if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
277         edge->dst_input() == 1) {
278       switches.insert(edge->dst());
279     }
280   }
281 
282   // For each non-constant argument, looks for the following pattern of nodes:
283   // Enter ----> Merge  -------->  Switch  --> Exit
284   //               ^                  ^
285   //               |                  |
286   //         NextIteration         LoopCond
287   //               ^                  ^
288   //               |                  |
289   //              ...                ...
290   for (WhileLoopArg& arg : frame->args) {
291     if (!arg.is_loop_invariant) {
292       // Follow the edge from the Enter to Merge.
293       const Edge* enter_merge = nullptr;
294       for (const Edge* e : arg.enter->out_edges()) {
295         // Ignore control-edges to the sink node. These are allowed by the
296         // graph invariants, although probably they should have been stripped
297         // off earlier.
298         if (e->IsControlEdge() && e->dst()->IsSink()) {
299           continue;
300         }
301         if (enter_merge != nullptr) {
302           return errors::Internal("Enter node for loop-varying argument ",
303                                   FormatNodeForError(*arg.enter),
304                                   " has multiple successors: ",
305                                   FormatNodeForError(*enter_merge->dst()),
306                                   " and ", FormatNodeForError(*e->dst()));
307         }
308         enter_merge = e;
309       }
310       if (enter_merge == nullptr) {
311         return errors::Internal("Enter node for loop-varying argument ",
312                                 FormatNodeForError(*arg.enter),
313                                 " has zero successors");
314       }
315       arg.merge = enter_merge->dst();
316       if (!IsMerge(arg.merge)) {
317         return errors::InvalidArgument(
318             "Successor of Enter node for loop-varying argument ",
319             FormatNodeForError(*arg.merge),
320             " is not a Merge node; got: ", arg.merge->type_string());
321       }
322 
323       // Find the NextIteration from the merge. There should be two inputs to
324       // the Merge and the NextIteration should be the other input.
325       if (arg.merge->input_types().size() != 2) {
326         return errors::InvalidArgument(
327             "Unexpected number of inputs to Merge node for loop-varying "
328             "argument ",
329             FormatNodeForError(*arg.merge), "; expected 2, got ",
330             arg.merge->input_types().size());
331       }
332       TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
333                                                &arg.next_iteration));
334       if (!IsNextIteration(arg.next_iteration)) {
335         return errors::InvalidArgument(
336             "Expected NextIteration node as input to Merge node; got node ",
337             FormatNodeForError(*arg.next_iteration), " with kind ",
338             arg.next_iteration->type_string());
339       }
340 
341       // Find the Switch successor of the Merge. There should be exactly one
342       // Switch node that is a successor of both the Merge and the LoopCond.
343       for (const Edge* edge : arg.merge->out_edges()) {
344         if (edge->dst_input() == 0 && IsSwitch(edge->dst()) &&
345             switches.find(edge->dst()) != switches.end()) {
346           if (arg.switch_node != nullptr) {
347             return errors::InvalidArgument("Duplicate Switch successors to ",
348                                            FormatNodeForError(*arg.merge));
349           }
350           arg.switch_node = edge->dst();
351         }
352       }
353       if (arg.switch_node == nullptr) {
354         return errors::InvalidArgument("Missing Switch successor to ",
355                                        FormatNodeForError(*arg.merge));
356       }
357       // Loop over the switch node's output to:
358       // - Find the Exit successor.
359       // - Set the sharding on all Identity outputs of the switch. These
360       //   identity nodes are values used by the loop body or condition.
361       //   The Identity node may have the wrong device so copy the device from
362       //   one of its outputs instead.
363       std::deque<const Edge*> possible_exit;
364       for (const Edge* edge : arg.switch_node->out_edges()) {
365         if (edge->src_output() == 0) {
366           possible_exit.push_back(edge);
367         }
368         if (IsIdentity(edge->dst())) {
369           TF_RETURN_IF_ERROR(
370               SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
371         }
372       }
373       // TODO(b/67425339): Allow general graph between switch and exit.
374       while (!possible_exit.empty()) {
375         const Edge* edge = possible_exit.front();
376         possible_exit.pop_front();
377         if (IsExit(edge->dst())) {
378           if (arg.exit != nullptr) {
379             return errors::InvalidArgument(
380                 "Duplicate Exit successors to ",
381                 FormatNodeForError(*arg.switch_node));
382           }
383           arg.exit = edge->dst();
384         } else {
385           if (!IsIdentity(edge->dst())) {
386             return errors::Unimplemented("General graph between switch (",
387                                          FormatNodeForError(*arg.switch_node),
388                                          ") and exit node of frame ",
389                                          frame->name, " not supported yet.");
390           }
391           for (const Edge* out : edge->dst()->out_edges()) {
392             possible_exit.push_back(out);
393           }
394         }
395       }
396     }
397   }
398 
399   // Builds the condition and body functions. Notice that we call
400   // FunctionalizeCond() on cond_graph and body_graph because we might have
401   // unfunctionalized "if" in cond_graph and body_graph. Functionalize them
402   // before they are encapsulated in FunctionDef.
403   std::unique_ptr<Graph> cond_graph;
404   TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
405   FixupSourceAndSinkEdges(cond_graph.get());
406   TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library, node_filter));
407   DataTypeVector arg_types;
408   std::unique_ptr<Graph> body_graph;
409   TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
410   FixupSourceAndSinkEdges(body_graph.get());
411   TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library, node_filter));
412 
413   VLOG(2) << "Frame " << frame->name << " condition: "
414           << DumpGraphToFile("loop_condition", *cond_graph, library)
415           << " body: " << DumpGraphToFile("loop_body", *body_graph);
416 
417   NameAttrList cond_name;
418   cond_name.set_name(library->UniqueFunctionName("_functionalize_cond_"));
419   NameAttrList body_name;
420   body_name.set_name(library->UniqueFunctionName("_functionalize_body_"));
421   FunctionDef cond_fdef;
422   TF_RETURN_IF_ERROR(
423       GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
424   FunctionDef body_fdef;
425   TF_RETURN_IF_ERROR(
426       GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));
427 
428   TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
429   TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
430 
431   // Builds a While operator.
432   NodeDef while_def;
433   NodeDefBuilder builder(frame->loop_cond->name(), "While", library);
434   builder.Attr("T", arg_types);
435   builder.Attr("cond", cond_name);
436   builder.Attr("body", body_name);
437   // Add some internal attributes which need to be propagated.
438   // TODO(b/160275126): attributes shouldn't be hard-coded here
439   for (const char* attr_name :
440        {kXlaFrontendAttributesAttrName, kXlaOutsideCompilationAttrName,
441         kTpuReplicateAttrName}) {
442     string attr_val;
443     if (GetNodeAttr(frame->loop_cond->def(), attr_name, &attr_val).ok()) {
444       builder.Attr(attr_name, attr_val);
445     }
446   }
447   std::vector<NodeDefBuilder::NodeOut> inputs;
448   for (int i = 0, end = frame->args.size(); i < end; ++i) {
449     const WhileLoopArg& arg = frame->args[i];
450     const Edge* in_edge;
451     TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
452     if (in_edge->IsControlEdge()) {
453       builder.ControlInput(in_edge->src()->name());
454     } else {
455       inputs.push_back(NodeDefBuilder::NodeOut(
456           in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
457     }
458   }
459   builder.Input(inputs);
460   TF_RETURN_IF_ERROR(builder.Finalize(&while_def));
461   TF_ASSIGN_OR_RETURN(Node * while_node, AddNodeDefToGraph(while_def, graph));
462 
463   // Copies edges to the Enter nodes and from the Exit nodes onto the While.
464   for (int i = 0, end = frame->args.size(); i < end; ++i) {
465     const WhileLoopArg& arg = frame->args[i];
466     const Edge* in_edge;
467     TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
468     if (in_edge->IsControlEdge()) {
469       graph->AddControlEdge(in_edge->src(), while_node);
470     } else {
471       graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
472     }
473 
474     if (!arg.is_loop_invariant) {
475       // Add output edges if the output of the loop is consumed.
476       if (arg.exit != nullptr) {
477         std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
478                                        arg.exit->out_edges().end());
479         for (const Edge* edge : edges) {
480           Node* dst = edge->dst();
481           int dst_input = edge->dst_input();
482           graph->RemoveEdge(edge);
483 
484           if (dst_input == Graph::kControlSlot) {
485             graph->AddControlEdge(while_node, dst);
486           } else {
487             graph->AddEdge(while_node, i, dst, dst_input);
488           }
489         }
490       }
491     }
492   }
493 
494   // Remove the old nodes from the graph, and add the while node to the parent
495   // frame.
496   for (Node* node : frame->nodes) {
497     VLOG(2) << "Removing obsolete node " << node->name();
498     graph->RemoveNode(node);
499   }
500   frame->nodes.clear();
501   frame->parent->nodes.insert(while_node);
502 
503   VLOG(2) << "Frame " << frame->name << " after: "
504           << DumpGraphToFile("functionalize_after", *graph, library);
505 
506   return Status::OK();
507 }
508 }  // namespace
509 
FunctionalizeWhileLoop(Graph * graph,FunctionLibraryDefinition * library,const NodeFilter & node_filter)510 Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library,
511                               const NodeFilter& node_filter) {
512   // Note: BuildControlFlowInfo() requires that the graph's source node is
513   // connected to all source nodes in the graph. Many graphs violate this
514   // invariant.
515   std::vector<ControlFlowInfo> cf_info;
516   std::vector<string> unreachable_nodes;
517   TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes));
518   if (!unreachable_nodes.empty()) {
519     return errors::InvalidArgument(
520         "The following nodes are unreachable from the source in the graph: ",
521         errors::FormatNodeNamesForError(unreachable_nodes));
522   }
523 
524   // Builds Frames, indexed by name.
525   std::unordered_map<string, WhileLoopFrame> frames;
526   TF_RETURN_IF_ERROR(
527       ExtractWhileLoopFrames(cf_info, graph, &frames, node_filter));
528 
529   // Adds frames with no children (i.e., the innermost frames) to a worklist.
530   std::deque<WhileLoopFrame*> worklist;
531   for (auto& frame : frames) {
532     if (frame.second.num_children == 0) {
533       worklist.push_back(&frame.second);
534     }
535   }
536 
537   // Eliminate loops from innermost to outermost. Note that the precondition for
538   // `node_filter` in `FunctionalizeControlFlow` makes sure that this approach
539   // works.
540   while (!worklist.empty()) {
541     WhileLoopFrame* frame = worklist.front();
542     worklist.pop_front();
543     if (frame->parent == frame) {
544       // Skip the root frame.
545       continue;
546     }
547 
548     TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library, node_filter));
549 
550     // If the parent has no remaining children, add it to the worklist.
551     --frame->parent->num_children;
552     if (frame->parent->num_children == 0) {
553       worklist.push_back(frame->parent);
554     }
555   }
556 
557   if (!node_filter) {
558     // There should be no cycle at this point, since while loops have been
559     // removed from graph. Check that the newly added While nodes don't feed
560     // into themselves.
561     for (const Node* node : graph->op_nodes()) {
562       if (node->def().op() == "While") {
563         TF_RETURN_WITH_CONTEXT_IF_ERROR(
564             CheckNodeNotInCycle(node, graph->num_node_ids()),
565             "Functionalizing loop failed.");
566       }
567     }
568   }
569 
570   return Status::OK();
571 }
572 
573 }  // namespace tensorflow
574