• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/functionalize_while.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <stack>
21 #include <unordered_set>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "absl/strings/match.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
28 #include "tensorflow/compiler/tf2xla/functionalize_cond.h"
29 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/union_find.h"
32 #include "tensorflow/core/common_runtime/function.h"
33 #include "tensorflow/core/framework/graph_to_functiondef.h"
34 #include "tensorflow/core/framework/node_def_builder.h"
35 #include "tensorflow/core/graph/algorithm.h"
36 #include "tensorflow/core/graph/control_flow.h"
37 #include "tensorflow/core/graph/node_builder.h"
38 #include "tensorflow/core/lib/strings/strcat.h"
39 #include "tensorflow/core/util/dump_graph.h"
40 
41 namespace tensorflow {
42 namespace {
43 
44 using xla::StatusOr;
45 
46 // Copies a subgraph from `graph` to `output` by performing a reverse DFS
47 // starting at nodes in vector `stack`.
48 // `node_map` is a vector indexed by source node ID to dest nodes.
49 // Does not traverse into nodes in `node_map`, so by adding nodes to `node_map`
50 // before the traversal clients can cut the graph. If a frame is provided (frame
51 // != nullptr), then this functions will return an error if the
52 // traversal leaves 'frame'; the client must add enough nodes to `node_map` to
53 // cut the graph and prevent the traversal from escaping.
54 //
55 // `squash_src_outputs` contains a bool for each source node ID. If true, then
56 // the source output on that node will be replaced by zero when copied. This is
57 // used when replacing a Switch node with an _Arg node. The output we are
58 // taking from the Switch node was not necessarily the first output, but _Arg
59 // nodes only have one output. By adding the Switch node to `squash_src_outputs`
60 // we rewrite the src_output of the corresponding edge to be 0.
CopySubgraph(const Graph & graph,const WhileLoopFrame * frame,std::vector<Node * > stack,const std::vector<bool> & squash_src_outputs,std::vector<Node * > * node_map,Graph * output)61 Status CopySubgraph(const Graph& graph, const WhileLoopFrame* frame,
62                     std::vector<Node*> stack,
63                     const std::vector<bool>& squash_src_outputs,
64                     std::vector<Node*>* node_map, Graph* output) {
65   VLOG(3) << "Stack: " << NodesToString(stack);
66   std::vector<bool> visited(graph.num_node_ids(), false);
67   while (!stack.empty()) {
68     Node* n = stack.back();
69     stack.pop_back();
70 
71     VLOG(5) << "Copying node " << n->name();
72 
73     if (visited[n->id()]) continue;
74     visited[n->id()] = true;
75 
76     // Sort "n->in_edges()" to make sure nodes are copied in a deterministic
77     // order.
78     std::vector<const Edge*> sorted_edges(n->in_edges().begin(),
79                                           n->in_edges().end());
80     std::sort(sorted_edges.begin(), sorted_edges.end(),
81               [](const Edge* a, const Edge* b) {
82                 int a_src_output = a->src_output(),
83                     b_src_output = b->src_output();
84                 StringPiece a_name(a->src()->name()), b_name(b->src()->name());
85                 return std::tie(a_src_output, a_name) <
86                        std::tie(b_src_output, b_name);
87               });
88     for (const Edge* e : sorted_edges) {
89       Node* src = e->src();
90       if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) {
91         // We traversed out of the loop frame, without encountering a cut node.
92         return errors::Internal("Graph traversal of loop frame ", frame->name,
93                                 " escaped frame at ", src->name(),
94                                 " without encountering an argument node.");
95       }
96       if ((*node_map)[src->id()] == nullptr) {
97         (*node_map)[src->id()] = output->CopyNode(src);
98         stack.push_back(src);
99       }
100       Node* src_copy = (*node_map)[e->src()->id()];
101       int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge()
102                            ? 0
103                            : e->src_output();
104       Node* dst_copy = (*node_map)[e->dst()->id()];
105       output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
106     }
107   }
108   return Status::OK();
109 }
110 
BuildArgNode(Graph * graph,DataType type,int index)111 StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
112   const char* const kArgOp = "_Arg";
113   NodeDef arg_def;
114   NodeDefBuilder builder(absl::StrCat(kArgOp, index), kArgOp);
115   builder.Attr("T", type);
116   builder.Attr("index", index);
117   TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
118   return AddNodeDefToGraph(arg_def, graph);
119 }
120 
121 // Builds a graph for the loop condition.
BuildLoopCondition(const Graph & graph,WhileLoopFrame * frame,std::unique_ptr<Graph> * cond_output)122 Status BuildLoopCondition(const Graph& graph, WhileLoopFrame* frame,
123                           std::unique_ptr<Graph>* cond_output) {
124   VLOG(2) << "Building loop condition for " << frame->name;
125   *cond_output = absl::make_unique<Graph>(graph.op_registry());
126   Graph* output = cond_output->get();
127 
128   // Map from nodes in the original graph to the condition graph.
129   std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
130   std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
131 
132   // Build one _Arg node for each Enter node.
133   for (int i = 0, end = frame->args.size(); i < end; ++i) {
134     const WhileLoopArg& arg = frame->args[i];
135 
136     TF_ASSIGN_OR_RETURN(Node * arg_node,
137                         BuildArgNode(output, arg.enter->input_type(0), i));
138     if (arg.is_loop_invariant) {
139       node_map[arg.enter->id()] = arg_node;
140     } else {
141       node_map[arg.merge->id()] = arg_node;
142     }
143   }
144 
145   // Build a Retval node for the loop condition. The LoopCond nodes are always
146   // boolean because of the type constraints on the LoopCond op.
147   TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()],
148                       BuildRetvalNode(output, DT_BOOL, 0));
149 
150   // Performs a reverse DFS, copying nodes and edges to the output graph.
151   // The _Arg and _Retval nodes were added unconditionally above, so we are
152   // guaranteed to get the correct function signature.
153   return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs,
154                       &node_map, output);
155 }
156 
157 // Builds a graph for the loop body.
BuildLoopBody(const Graph & graph,WhileLoopFrame * frame,DataTypeVector * arg_types,std::unique_ptr<Graph> * body_output)158 Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame,
159                      DataTypeVector* arg_types,
160                      std::unique_ptr<Graph>* body_output) {
161   VLOG(2) << "Building loop body for " << frame->name;
162   *body_output = absl::make_unique<Graph>(graph.op_registry());
163   Graph* output = body_output->get();
164 
165   // Map from nodes in the original graph to the body graph.
166   std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
167   std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
168 
169   // Build one _Arg node for each Enter node.
170   std::vector<Node*> next_iterations;
171   next_iterations.reserve(frame->args.size());
172   arg_types->reserve(frame->args.size());
173   for (int i = 0, end = frame->args.size(); i < end; ++i) {
174     const WhileLoopArg& arg = frame->args[i];
175 
176     DataType dtype = arg.enter->input_type(0);
177     arg_types->push_back(dtype);
178 
179     TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i));
180     TF_ASSIGN_OR_RETURN(Node * retval_node, BuildRetvalNode(output, dtype, i));
181     if (arg.is_loop_invariant) {
182       // Argument is loop-invariant. Forward it from the Arg to the Retval.
183       node_map[arg.enter->id()] = arg_node;
184       output->AddEdge(arg_node, 0, retval_node, 0);
185     } else {
186       // Argument is loop-varying.
187       if (dtype == DT_RESOURCE) {
188         // DT_RESOURCE arguments should always be loop-invariant in the graphs
189         // generated from TF.
190         return errors::Unimplemented("Loop-varying DT_RESOURCE Enter node ",
191                                      arg.enter->name(), " is currently not",
192                                      " supported.");
193       }
194       node_map[arg.switch_node->id()] = arg_node;
195       // The Switch node has two outputs, but _Arg only has one. This tells
196       // the CopySubgraph function to rewrite the output number of edges from
197       // the _Arg node to be 0 rather than copying the output number from the
198       // Switch node.
199       squash_src_outputs[arg.switch_node->id()] = true;
200       node_map[arg.next_iteration->id()] = retval_node;
201       next_iterations.push_back(arg.next_iteration);
202     }
203   }
204 
205   // Performs a reverse DFS, copying nodes and edges to the output graph.
206   // The _Arg and _Retval nodes were added unconditionally above, so we are
207   // guaranteed to get the correct function signature.
208   TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations),
209                                   squash_src_outputs, &node_map, output));
210 
211   return Status::OK();
212 }
213 
FunctionalizeLoop(Graph * graph,WhileLoopFrame * frame,FunctionLibraryDefinition * library,const NodeFilter & node_filter)214 Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
215                          FunctionLibraryDefinition* library,
216                          const NodeFilter& node_filter) {
217   if (node_filter && !frame->should_be_functionalized) {
218     VLOG(2) << "Skipping functionalization for frame " << frame->name
219             << " because it has control flow nodes that are filtered out by "
220                "the specified node filter.";
221     return Status::OK();
222   }
223   VLOG(2) << "Frame " << frame->name << " before: "
224           << DumpGraphToFile("functionalize_before", *graph, library);
225 
226   // Split loop-varying Enter nodes with multiple successors. If the same
227   // Tensor is fed as input to multiple loop arguments, we may end up with a
228   // shared Enter node. We clone Enter nodes with multiple successors to
229   // maintain the invariant of a unique Enter node per argument of the final
230   // loop.
231   std::vector<WhileLoopArg> args;
232   for (const WhileLoopArg& arg : frame->args) {
233     if (arg.is_loop_invariant) {
234       args.push_back(arg);
235     } else {
236       std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
237                                      arg.enter->out_edges().end());
238       for (int i = 0, end = edges.size(); i < end; ++i) {
239         if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) {
240           continue;
241         }
242         TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name();
243         WhileLoopArg new_arg;
244         new_arg.is_loop_invariant = false;
245         if (i == 0) {
246           new_arg.enter = arg.enter;
247         } else {
248           new_arg.enter = graph->CopyNode(arg.enter);
249           frame->nodes.insert(new_arg.enter);
250           for (Edge const* e : arg.enter->in_edges()) {
251             graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
252                            e->IsControlEdge() ? Graph::kControlSlot : 0);
253           }
254           Node* dst = edges[i]->dst();
255           int dst_input = edges[i]->dst_input();
256           graph->RemoveEdge(edges[i]);
257           graph->AddEdge(new_arg.enter, 0, dst, dst_input);
258         }
259         args.push_back(new_arg);
260       }
261     }
262   }
263   frame->args = std::move(args);
264 
265   std::sort(frame->args.begin(), frame->args.end(),
266             [](const WhileLoopArg& a, const WhileLoopArg& b) {
267               return NodeCmpByNameResourcesLast()(a.enter, b.enter);
268             });
269 
270   if (frame->loop_cond == nullptr) {
271     return errors::InvalidArgument("Loop ", frame->name,
272                                    " has no LoopCond node");
273   }
274 
275   // Find the set of Switch nodes that are successors of the LoopCond.
276   std::unordered_set<Node*> switches;
277   for (const Edge* edge : frame->loop_cond->out_edges()) {
278     if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
279         edge->dst_input() == 1) {
280       switches.insert(edge->dst());
281     }
282   }
283 
284   // For each non-constant argument, looks for the following pattern of nodes:
285   // Enter ----> Merge  -------->  Switch  --> Exit
286   //               ^                  ^
287   //               |                  |
288   //         NextIteration         LoopCond
289   //               ^                  ^
290   //               |                  |
291   //              ...                ...
292   for (WhileLoopArg& arg : frame->args) {
293     if (!arg.is_loop_invariant) {
294       // Follow the edge from the Enter to Merge.
295       const Edge* enter_merge = nullptr;
296       for (const Edge* e : arg.enter->out_edges()) {
297         // Ignore control-edges to the sink node. These are allowed by the
298         // graph invariants, although probably they should have been stripped
299         // off earlier.
300         if (e->IsControlEdge() && e->dst()->IsSink()) {
301           continue;
302         }
303         if (enter_merge != nullptr) {
304           return errors::Internal("Enter node for loop-varying argument ",
305                                   FormatNodeForError(*arg.enter),
306                                   " has multiple successors: ",
307                                   FormatNodeForError(*enter_merge->dst()),
308                                   " and ", FormatNodeForError(*e->dst()));
309         }
310         enter_merge = e;
311       }
312       if (enter_merge == nullptr) {
313         return errors::Internal("Enter node for loop-varying argument ",
314                                 FormatNodeForError(*arg.enter),
315                                 " has zero successors");
316       }
317       arg.merge = enter_merge->dst();
318       if (!IsMerge(arg.merge)) {
319         return errors::InvalidArgument(
320             "Successor of Enter node for loop-varying argument ",
321             FormatNodeForError(*arg.merge),
322             " is not a Merge node; got: ", arg.merge->type_string());
323       }
324 
325       // Find the NextIteration from the merge. There should be two inputs to
326       // the Merge and the NextIteration should be the other input.
327       if (arg.merge->input_types().size() != 2) {
328         return errors::InvalidArgument(
329             "Unexpected number of inputs to Merge node for loop-varying "
330             "argument ",
331             FormatNodeForError(*arg.merge), "; expected 2, got ",
332             arg.merge->input_types().size());
333       }
334       TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
335                                                &arg.next_iteration));
336       if (!IsNextIteration(arg.next_iteration)) {
337         return errors::InvalidArgument(
338             "Expected NextIteration node as input to Merge node; got node ",
339             FormatNodeForError(*arg.next_iteration), " with kind ",
340             arg.next_iteration->type_string());
341       }
342 
343       // Find the Switch successor of the Merge. There should be exactly one
344       // Switch node that is a successor of both the Merge and the LoopCond.
345       for (const Edge* edge : arg.merge->out_edges()) {
346         if (edge->dst_input() == 0 && IsSwitch(edge->dst()) &&
347             switches.find(edge->dst()) != switches.end()) {
348           if (arg.switch_node != nullptr) {
349             return errors::InvalidArgument("Duplicate Switch successors to ",
350                                            FormatNodeForError(*arg.merge));
351           }
352           arg.switch_node = edge->dst();
353         }
354       }
355       if (arg.switch_node == nullptr) {
356         return errors::InvalidArgument("Missing Switch successor to ",
357                                        FormatNodeForError(*arg.merge));
358       }
359       // Loop over the switch node's output to:
360       // - Find the Exit successor.
361       // - Set the sharding on all Identity outputs of the switch. These
362       //   identity nodes are values used by the loop body or condition.
363       //   The Identity node may have the wrong device so copy the device from
364       //   one of its outputs instead.
365       std::deque<const Edge*> possible_exit;
366       for (const Edge* edge : arg.switch_node->out_edges()) {
367         if (edge->src_output() == 0) {
368           possible_exit.push_back(edge);
369         }
370         if (IsIdentity(edge->dst())) {
371           TF_RETURN_IF_ERROR(
372               SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
373         }
374       }
375       // TODO(b/67425339): Allow general graph between switch and exit.
376       while (!possible_exit.empty()) {
377         const Edge* edge = possible_exit.front();
378         possible_exit.pop_front();
379         if (IsExit(edge->dst())) {
380           if (arg.exit != nullptr) {
381             return errors::InvalidArgument(
382                 "Duplicate Exit successors to ",
383                 FormatNodeForError(*arg.switch_node));
384           }
385           arg.exit = edge->dst();
386         } else {
387           if (!IsIdentity(edge->dst())) {
388             return errors::Unimplemented("General graph between switch (",
389                                          FormatNodeForError(*arg.switch_node),
390                                          ") and exit node of frame ",
391                                          frame->name, " not supported yet.");
392           }
393           for (const Edge* out : edge->dst()->out_edges()) {
394             possible_exit.push_back(out);
395           }
396         }
397       }
398     }
399   }
400 
401   // Builds the condition and body functions. Notice that we call
402   // FunctionalizeCond() on cond_graph and body_graph because we might have
403   // unfunctionalized "if" in cond_graph and body_graph. Functionalize them
404   // before they are encapsulated in FunctionDef.
405   std::unique_ptr<Graph> cond_graph;
406   TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
407   FixupSourceAndSinkEdges(cond_graph.get());
408   TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library, node_filter));
409   DataTypeVector arg_types;
410   std::unique_ptr<Graph> body_graph;
411   TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
412   FixupSourceAndSinkEdges(body_graph.get());
413   TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library, node_filter));
414 
415   VLOG(2) << "Frame " << frame->name << " condition: "
416           << DumpGraphToFile("loop_condition", *cond_graph, library)
417           << " body: " << DumpGraphToFile("loop_body", *body_graph);
418 
419   NameAttrList cond_name;
420   cond_name.set_name(library->UniqueFunctionName("_functionalize_cond_"));
421   NameAttrList body_name;
422   body_name.set_name(library->UniqueFunctionName("_functionalize_body_"));
423   FunctionDef cond_fdef;
424   TF_RETURN_IF_ERROR(
425       GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
426   FunctionDef body_fdef;
427   TF_RETURN_IF_ERROR(
428       GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));
429 
430   TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
431   TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
432 
433   // Builds a While operator.
434   NodeDef while_def;
435   NodeDefBuilder builder(frame->loop_cond->name(), "While", library);
436   builder.Attr("T", arg_types);
437   builder.Attr("cond", cond_name);
438   builder.Attr("body", body_name);
439   // Add some internal attributes which need to be propagated.
440   // TODO(b/160275126): attributes shouldn't be hard-coded here
441   for (const char* attr_name :
442        {kXlaFrontendAttributesAttrName, kXlaOutsideCompilationAttrName,
443         kTpuReplicateAttrName}) {
444     string attr_val;
445     if (GetNodeAttr(frame->loop_cond->def(), attr_name, &attr_val).ok()) {
446       builder.Attr(attr_name, attr_val);
447     }
448   }
449   std::vector<NodeDefBuilder::NodeOut> inputs;
450   for (int i = 0, end = frame->args.size(); i < end; ++i) {
451     const WhileLoopArg& arg = frame->args[i];
452     const Edge* in_edge;
453     TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
454     if (in_edge->IsControlEdge()) {
455       builder.ControlInput(in_edge->src()->name());
456     } else {
457       inputs.push_back(NodeDefBuilder::NodeOut(
458           in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
459     }
460   }
461   builder.Input(inputs);
462   TF_RETURN_IF_ERROR(builder.Finalize(&while_def));
463   TF_ASSIGN_OR_RETURN(Node * while_node, AddNodeDefToGraph(while_def, graph));
464 
465   // Copies edges to the Enter nodes and from the Exit nodes onto the While.
466   for (int i = 0, end = frame->args.size(); i < end; ++i) {
467     const WhileLoopArg& arg = frame->args[i];
468     const Edge* in_edge;
469     TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
470     if (in_edge->IsControlEdge()) {
471       graph->AddControlEdge(in_edge->src(), while_node);
472     } else {
473       graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
474     }
475 
476     if (!arg.is_loop_invariant) {
477       // Add output edges if the output of the loop is consumed.
478       if (arg.exit != nullptr) {
479         std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
480                                        arg.exit->out_edges().end());
481         for (const Edge* edge : edges) {
482           Node* dst = edge->dst();
483           int dst_input = edge->dst_input();
484           graph->RemoveEdge(edge);
485 
486           if (dst_input == Graph::kControlSlot) {
487             graph->AddControlEdge(while_node, dst);
488           } else {
489             graph->AddEdge(while_node, i, dst, dst_input);
490           }
491         }
492       }
493     }
494   }
495 
496   // Remove the old nodes from the graph, and add the while node to the parent
497   // frame.
498   for (Node* node : frame->nodes) {
499     VLOG(2) << "Removing obsolete node " << node->name();
500     graph->RemoveNode(node);
501   }
502   frame->nodes.clear();
503   frame->parent->nodes.insert(while_node);
504 
505   VLOG(2) << "Frame " << frame->name << " after: "
506           << DumpGraphToFile("functionalize_after", *graph, library);
507 
508   return Status::OK();
509 }
510 }  // namespace
511 
FunctionalizeWhileLoop(Graph * graph,FunctionLibraryDefinition * library,const NodeFilter & node_filter)512 Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library,
513                               const NodeFilter& node_filter) {
514   // Note: BuildControlFlowInfo() requires that the graph's source node is
515   // connected to all source nodes in the graph. Many graphs violate this
516   // invariant.
517   std::vector<ControlFlowInfo> cf_info;
518   std::vector<string> unreachable_nodes;
519   TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes));
520   if (!unreachable_nodes.empty()) {
521     return errors::InvalidArgument(
522         "The following nodes are unreachable from the source in the graph: ",
523         errors::FormatNodeNamesForError(unreachable_nodes));
524   }
525 
526   // Builds Frames, indexed by name.
527   std::unordered_map<string, WhileLoopFrame> frames;
528   TF_RETURN_IF_ERROR(
529       ExtractWhileLoopFrames(cf_info, graph, &frames, node_filter));
530 
531   // Adds frames with no children (i.e., the innermost frames) to a worklist.
532   std::deque<WhileLoopFrame*> worklist;
533   for (auto& frame : frames) {
534     if (frame.second.num_children == 0) {
535       worklist.push_back(&frame.second);
536     }
537   }
538 
539   // Eliminate loops from innermost to outermost. Note that the precondition for
540   // `node_filter` in `FunctionalizeControlFlow` makes sure that this approach
541   // works.
542   while (!worklist.empty()) {
543     WhileLoopFrame* frame = worklist.front();
544     worklist.pop_front();
545     if (frame->parent == frame) {
546       // Skip the root frame.
547       continue;
548     }
549 
550     TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library, node_filter));
551 
552     // If the parent has no remaining children, add it to the worklist.
553     --frame->parent->num_children;
554     if (frame->parent->num_children == 0) {
555       worklist.push_back(frame->parent);
556     }
557   }
558 
559   if (!node_filter) {
560     // There should be no cycle at this point, since while loops have been
561     // removed from graph. Check that the newly added While nodes don't feed
562     // into themselves.
563     for (const Node* node : graph->op_nodes()) {
564       if (node->def().op() == "While") {
565         TF_RETURN_WITH_CONTEXT_IF_ERROR(
566             CheckNodeNotInCycle(node, graph->num_node_ids()),
567             "Functionalizing loop failed.");
568       }
569     }
570   }
571 
572   return Status::OK();
573 }
574 
575 }  // namespace tensorflow
576