1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/functionalize_while.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <stack>
21 #include <unordered_set>
22 #include <vector>
23
24 #include "absl/memory/memory.h"
25 #include "absl/strings/match.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
28 #include "tensorflow/compiler/tf2xla/functionalize_cond.h"
29 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/union_find.h"
32 #include "tensorflow/core/common_runtime/function.h"
33 #include "tensorflow/core/framework/graph_to_functiondef.h"
34 #include "tensorflow/core/framework/node_def_builder.h"
35 #include "tensorflow/core/graph/algorithm.h"
36 #include "tensorflow/core/graph/control_flow.h"
37 #include "tensorflow/core/graph/node_builder.h"
38 #include "tensorflow/core/lib/strings/strcat.h"
39 #include "tensorflow/core/util/dump_graph.h"
40
41 namespace tensorflow {
42 namespace {
43
44 // Copies a subgraph from `graph` to `output` by performing a reverse DFS
45 // starting at nodes in vector `stack`.
46 // `node_map` is a vector indexed by source node ID to dest nodes.
47 // Does not traverse into nodes in `node_map`, so by adding nodes to `node_map`
48 // before the traversal clients can cut the graph. If a frame is provided (frame
49 // != nullptr), then this functions will return an error if the
50 // traversal leaves 'frame'; the client must add enough nodes to `node_map` to
51 // cut the graph and prevent the traversal from escaping.
52 //
53 // `squash_src_outputs` contains a bool for each source node ID. If true, then
54 // the source output on that node will be replaced by zero when copied. This is
55 // used when replacing a Switch node with an _Arg node. The output we are
56 // taking from the Switch node was not necessarily the first output, but _Arg
57 // nodes only have one output. By adding the Switch node to `squash_src_outputs`
58 // we rewrite the src_output of the corresponding edge to be 0.
CopySubgraph(const Graph & graph,const WhileLoopFrame * frame,std::vector<Node * > stack,const std::vector<bool> & squash_src_outputs,std::vector<Node * > * node_map,Graph * output)59 Status CopySubgraph(const Graph& graph, const WhileLoopFrame* frame,
60 std::vector<Node*> stack,
61 const std::vector<bool>& squash_src_outputs,
62 std::vector<Node*>* node_map, Graph* output) {
63 VLOG(3) << "Stack: " << NodesToString(stack);
64 std::vector<bool> visited(graph.num_node_ids(), false);
65 while (!stack.empty()) {
66 Node* n = stack.back();
67 stack.pop_back();
68
69 VLOG(5) << "Copying node " << n->name();
70
71 if (visited[n->id()]) continue;
72 visited[n->id()] = true;
73
74 // Sort "n->in_edges()" to make sure nodes are copied in a deterministic
75 // order.
76 std::vector<const Edge*> sorted_edges(n->in_edges().begin(),
77 n->in_edges().end());
78 std::sort(sorted_edges.begin(), sorted_edges.end(),
79 [](const Edge* a, const Edge* b) {
80 int a_src_output = a->src_output(),
81 b_src_output = b->src_output();
82 StringPiece a_name(a->src()->name()), b_name(b->src()->name());
83 return std::tie(a_src_output, a_name) <
84 std::tie(b_src_output, b_name);
85 });
86 for (const Edge* e : sorted_edges) {
87 Node* src = e->src();
88 if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) {
89 // We traversed out of the loop frame, without encountering a cut node.
90 return errors::Internal("Graph traversal of loop frame ", frame->name,
91 " escaped frame at ", src->name(),
92 " without encountering an argument node.");
93 }
94 if ((*node_map)[src->id()] == nullptr) {
95 (*node_map)[src->id()] = output->CopyNode(src);
96 stack.push_back(src);
97 }
98 Node* src_copy = (*node_map)[e->src()->id()];
99 int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge()
100 ? 0
101 : e->src_output();
102 Node* dst_copy = (*node_map)[e->dst()->id()];
103 output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
104 }
105 }
106 return Status::OK();
107 }
108
BuildArgNode(Graph * graph,DataType type,int index)109 StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
110 const char* const kArgOp = "_Arg";
111 NodeDef arg_def;
112 NodeDefBuilder builder(absl::StrCat(kArgOp, index), kArgOp);
113 builder.Attr("T", type);
114 builder.Attr("index", index);
115 TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
116 return AddNodeDefToGraph(arg_def, graph);
117 }
118
119 // Builds a graph for the loop condition.
BuildLoopCondition(const Graph & graph,WhileLoopFrame * frame,std::unique_ptr<Graph> * cond_output)120 Status BuildLoopCondition(const Graph& graph, WhileLoopFrame* frame,
121 std::unique_ptr<Graph>* cond_output) {
122 VLOG(2) << "Building loop condition for " << frame->name;
123 *cond_output = absl::make_unique<Graph>(graph.op_registry());
124 Graph* output = cond_output->get();
125
126 // Map from nodes in the original graph to the condition graph.
127 std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
128 std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
129
130 // Build one _Arg node for each Enter node.
131 for (int i = 0, end = frame->args.size(); i < end; ++i) {
132 const WhileLoopArg& arg = frame->args[i];
133
134 TF_ASSIGN_OR_RETURN(Node * arg_node,
135 BuildArgNode(output, arg.enter->input_type(0), i));
136 if (arg.is_loop_invariant) {
137 node_map[arg.enter->id()] = arg_node;
138 } else {
139 node_map[arg.merge->id()] = arg_node;
140 }
141 }
142
143 // Build a Retval node for the loop condition. The LoopCond nodes are always
144 // boolean because of the type constraints on the LoopCond op.
145 TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()],
146 BuildRetvalNode(output, DT_BOOL, 0));
147
148 // Performs a reverse DFS, copying nodes and edges to the output graph.
149 // The _Arg and _Retval nodes were added unconditionally above, so we are
150 // guaranteed to get the correct function signature.
151 return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs,
152 &node_map, output);
153 }
154
155 // Builds a graph for the loop body.
BuildLoopBody(const Graph & graph,WhileLoopFrame * frame,DataTypeVector * arg_types,std::unique_ptr<Graph> * body_output)156 Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame,
157 DataTypeVector* arg_types,
158 std::unique_ptr<Graph>* body_output) {
159 VLOG(2) << "Building loop body for " << frame->name;
160 *body_output = absl::make_unique<Graph>(graph.op_registry());
161 Graph* output = body_output->get();
162
163 // Map from nodes in the original graph to the body graph.
164 std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
165 std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
166
167 // Build one _Arg node for each Enter node.
168 std::vector<Node*> next_iterations;
169 next_iterations.reserve(frame->args.size());
170 arg_types->reserve(frame->args.size());
171 for (int i = 0, end = frame->args.size(); i < end; ++i) {
172 const WhileLoopArg& arg = frame->args[i];
173
174 DataType dtype = arg.enter->input_type(0);
175 arg_types->push_back(dtype);
176
177 TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i));
178 TF_ASSIGN_OR_RETURN(Node * retval_node, BuildRetvalNode(output, dtype, i));
179 if (arg.is_loop_invariant) {
180 // Argument is loop-invariant. Forward it from the Arg to the Retval.
181 node_map[arg.enter->id()] = arg_node;
182 output->AddEdge(arg_node, 0, retval_node, 0);
183 } else {
184 // Argument is loop-varying.
185 if (dtype == DT_RESOURCE) {
186 // DT_RESOURCE arguments should always be loop-invariant in the graphs
187 // generated from TF.
188 return errors::Unimplemented("Loop-varying DT_RESOURCE Enter node ",
189 arg.enter->name(), " is currently not",
190 " supported.");
191 }
192 node_map[arg.switch_node->id()] = arg_node;
193 // The Switch node has two outputs, but _Arg only has one. This tells
194 // the CopySubgraph function to rewrite the output number of edges from
195 // the _Arg node to be 0 rather than copying the output number from the
196 // Switch node.
197 squash_src_outputs[arg.switch_node->id()] = true;
198 node_map[arg.next_iteration->id()] = retval_node;
199 next_iterations.push_back(arg.next_iteration);
200 }
201 }
202
203 // Performs a reverse DFS, copying nodes and edges to the output graph.
204 // The _Arg and _Retval nodes were added unconditionally above, so we are
205 // guaranteed to get the correct function signature.
206 TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations),
207 squash_src_outputs, &node_map, output));
208
209 return Status::OK();
210 }
211
FunctionalizeLoop(Graph * graph,WhileLoopFrame * frame,FunctionLibraryDefinition * library,const NodeFilter & node_filter)212 Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame,
213 FunctionLibraryDefinition* library,
214 const NodeFilter& node_filter) {
215 if (node_filter && !frame->should_be_functionalized) {
216 VLOG(2) << "Skipping functionalization for frame " << frame->name
217 << " because it has control flow nodes that are filtered out by "
218 "the specified node filter.";
219 return Status::OK();
220 }
221 VLOG(2) << "Frame " << frame->name << " before: "
222 << DumpGraphToFile("functionalize_before", *graph, library);
223
224 // Split loop-varying Enter nodes with multiple successors. If the same
225 // Tensor is fed as input to multiple loop arguments, we may end up with a
226 // shared Enter node. We clone Enter nodes with multiple successors to
227 // maintain the invariant of a unique Enter node per argument of the final
228 // loop.
229 std::vector<WhileLoopArg> args;
230 for (const WhileLoopArg& arg : frame->args) {
231 if (arg.is_loop_invariant) {
232 args.push_back(arg);
233 } else {
234 std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
235 arg.enter->out_edges().end());
236 for (int i = 0, end = edges.size(); i < end; ++i) {
237 if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) {
238 continue;
239 }
240 TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name();
241 WhileLoopArg new_arg;
242 new_arg.is_loop_invariant = false;
243 if (i == 0) {
244 new_arg.enter = arg.enter;
245 } else {
246 new_arg.enter = graph->CopyNode(arg.enter);
247 frame->nodes.insert(new_arg.enter);
248 for (Edge const* e : arg.enter->in_edges()) {
249 graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
250 e->IsControlEdge() ? Graph::kControlSlot : 0);
251 }
252 Node* dst = edges[i]->dst();
253 int dst_input = edges[i]->dst_input();
254 graph->RemoveEdge(edges[i]);
255 graph->AddEdge(new_arg.enter, 0, dst, dst_input);
256 }
257 args.push_back(new_arg);
258 }
259 }
260 }
261 frame->args = std::move(args);
262
263 std::sort(frame->args.begin(), frame->args.end(),
264 [](const WhileLoopArg& a, const WhileLoopArg& b) {
265 return NodeCmpByNameResourcesLast()(a.enter, b.enter);
266 });
267
268 if (frame->loop_cond == nullptr) {
269 return errors::InvalidArgument("Loop ", frame->name,
270 " has no LoopCond node");
271 }
272
273 // Find the set of Switch nodes that are successors of the LoopCond.
274 std::unordered_set<Node*> switches;
275 for (const Edge* edge : frame->loop_cond->out_edges()) {
276 if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
277 edge->dst_input() == 1) {
278 switches.insert(edge->dst());
279 }
280 }
281
282 // For each non-constant argument, looks for the following pattern of nodes:
283 // Enter ----> Merge --------> Switch --> Exit
284 // ^ ^
285 // | |
286 // NextIteration LoopCond
287 // ^ ^
288 // | |
289 // ... ...
290 for (WhileLoopArg& arg : frame->args) {
291 if (!arg.is_loop_invariant) {
292 // Follow the edge from the Enter to Merge.
293 const Edge* enter_merge = nullptr;
294 for (const Edge* e : arg.enter->out_edges()) {
295 // Ignore control-edges to the sink node. These are allowed by the
296 // graph invariants, although probably they should have been stripped
297 // off earlier.
298 if (e->IsControlEdge() && e->dst()->IsSink()) {
299 continue;
300 }
301 if (enter_merge != nullptr) {
302 return errors::Internal("Enter node for loop-varying argument ",
303 FormatNodeForError(*arg.enter),
304 " has multiple successors: ",
305 FormatNodeForError(*enter_merge->dst()),
306 " and ", FormatNodeForError(*e->dst()));
307 }
308 enter_merge = e;
309 }
310 if (enter_merge == nullptr) {
311 return errors::Internal("Enter node for loop-varying argument ",
312 FormatNodeForError(*arg.enter),
313 " has zero successors");
314 }
315 arg.merge = enter_merge->dst();
316 if (!IsMerge(arg.merge)) {
317 return errors::InvalidArgument(
318 "Successor of Enter node for loop-varying argument ",
319 FormatNodeForError(*arg.merge),
320 " is not a Merge node; got: ", arg.merge->type_string());
321 }
322
323 // Find the NextIteration from the merge. There should be two inputs to
324 // the Merge and the NextIteration should be the other input.
325 if (arg.merge->input_types().size() != 2) {
326 return errors::InvalidArgument(
327 "Unexpected number of inputs to Merge node for loop-varying "
328 "argument ",
329 FormatNodeForError(*arg.merge), "; expected 2, got ",
330 arg.merge->input_types().size());
331 }
332 TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(),
333 &arg.next_iteration));
334 if (!IsNextIteration(arg.next_iteration)) {
335 return errors::InvalidArgument(
336 "Expected NextIteration node as input to Merge node; got node ",
337 FormatNodeForError(*arg.next_iteration), " with kind ",
338 arg.next_iteration->type_string());
339 }
340
341 // Find the Switch successor of the Merge. There should be exactly one
342 // Switch node that is a successor of both the Merge and the LoopCond.
343 for (const Edge* edge : arg.merge->out_edges()) {
344 if (edge->dst_input() == 0 && IsSwitch(edge->dst()) &&
345 switches.find(edge->dst()) != switches.end()) {
346 if (arg.switch_node != nullptr) {
347 return errors::InvalidArgument("Duplicate Switch successors to ",
348 FormatNodeForError(*arg.merge));
349 }
350 arg.switch_node = edge->dst();
351 }
352 }
353 if (arg.switch_node == nullptr) {
354 return errors::InvalidArgument("Missing Switch successor to ",
355 FormatNodeForError(*arg.merge));
356 }
357 // Loop over the switch node's output to:
358 // - Find the Exit successor.
359 // - Set the sharding on all Identity outputs of the switch. These
360 // identity nodes are values used by the loop body or condition.
361 // The Identity node may have the wrong device so copy the device from
362 // one of its outputs instead.
363 std::deque<const Edge*> possible_exit;
364 for (const Edge* edge : arg.switch_node->out_edges()) {
365 if (edge->src_output() == 0) {
366 possible_exit.push_back(edge);
367 }
368 if (IsIdentity(edge->dst())) {
369 TF_RETURN_IF_ERROR(
370 SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
371 }
372 }
373 // TODO(b/67425339): Allow general graph between switch and exit.
374 while (!possible_exit.empty()) {
375 const Edge* edge = possible_exit.front();
376 possible_exit.pop_front();
377 if (IsExit(edge->dst())) {
378 if (arg.exit != nullptr) {
379 return errors::InvalidArgument(
380 "Duplicate Exit successors to ",
381 FormatNodeForError(*arg.switch_node));
382 }
383 arg.exit = edge->dst();
384 } else {
385 if (!IsIdentity(edge->dst())) {
386 return errors::Unimplemented("General graph between switch (",
387 FormatNodeForError(*arg.switch_node),
388 ") and exit node of frame ",
389 frame->name, " not supported yet.");
390 }
391 for (const Edge* out : edge->dst()->out_edges()) {
392 possible_exit.push_back(out);
393 }
394 }
395 }
396 }
397 }
398
399 // Builds the condition and body functions. Notice that we call
400 // FunctionalizeCond() on cond_graph and body_graph because we might have
401 // unfunctionalized "if" in cond_graph and body_graph. Functionalize them
402 // before they are encapsulated in FunctionDef.
403 std::unique_ptr<Graph> cond_graph;
404 TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
405 FixupSourceAndSinkEdges(cond_graph.get());
406 TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library, node_filter));
407 DataTypeVector arg_types;
408 std::unique_ptr<Graph> body_graph;
409 TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
410 FixupSourceAndSinkEdges(body_graph.get());
411 TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library, node_filter));
412
413 VLOG(2) << "Frame " << frame->name << " condition: "
414 << DumpGraphToFile("loop_condition", *cond_graph, library)
415 << " body: " << DumpGraphToFile("loop_body", *body_graph);
416
417 NameAttrList cond_name;
418 cond_name.set_name(library->UniqueFunctionName("_functionalize_cond_"));
419 NameAttrList body_name;
420 body_name.set_name(library->UniqueFunctionName("_functionalize_body_"));
421 FunctionDef cond_fdef;
422 TF_RETURN_IF_ERROR(
423 GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
424 FunctionDef body_fdef;
425 TF_RETURN_IF_ERROR(
426 GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));
427
428 TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
429 TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
430
431 // Builds a While operator.
432 NodeDef while_def;
433 NodeDefBuilder builder(frame->loop_cond->name(), "While", library);
434 builder.Attr("T", arg_types);
435 builder.Attr("cond", cond_name);
436 builder.Attr("body", body_name);
437 // Add some internal attributes which need to be propagated.
438 // TODO(b/160275126): attributes shouldn't be hard-coded here
439 for (const char* attr_name :
440 {kXlaFrontendAttributesAttrName, kXlaOutsideCompilationAttrName,
441 kTpuReplicateAttrName}) {
442 string attr_val;
443 if (GetNodeAttr(frame->loop_cond->def(), attr_name, &attr_val).ok()) {
444 builder.Attr(attr_name, attr_val);
445 }
446 }
447 std::vector<NodeDefBuilder::NodeOut> inputs;
448 for (int i = 0, end = frame->args.size(); i < end; ++i) {
449 const WhileLoopArg& arg = frame->args[i];
450 const Edge* in_edge;
451 TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
452 if (in_edge->IsControlEdge()) {
453 builder.ControlInput(in_edge->src()->name());
454 } else {
455 inputs.push_back(NodeDefBuilder::NodeOut(
456 in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
457 }
458 }
459 builder.Input(inputs);
460 TF_RETURN_IF_ERROR(builder.Finalize(&while_def));
461 TF_ASSIGN_OR_RETURN(Node * while_node, AddNodeDefToGraph(while_def, graph));
462
463 // Copies edges to the Enter nodes and from the Exit nodes onto the While.
464 for (int i = 0, end = frame->args.size(); i < end; ++i) {
465 const WhileLoopArg& arg = frame->args[i];
466 const Edge* in_edge;
467 TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
468 if (in_edge->IsControlEdge()) {
469 graph->AddControlEdge(in_edge->src(), while_node);
470 } else {
471 graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
472 }
473
474 if (!arg.is_loop_invariant) {
475 // Add output edges if the output of the loop is consumed.
476 if (arg.exit != nullptr) {
477 std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
478 arg.exit->out_edges().end());
479 for (const Edge* edge : edges) {
480 Node* dst = edge->dst();
481 int dst_input = edge->dst_input();
482 graph->RemoveEdge(edge);
483
484 if (dst_input == Graph::kControlSlot) {
485 graph->AddControlEdge(while_node, dst);
486 } else {
487 graph->AddEdge(while_node, i, dst, dst_input);
488 }
489 }
490 }
491 }
492 }
493
494 // Remove the old nodes from the graph, and add the while node to the parent
495 // frame.
496 for (Node* node : frame->nodes) {
497 VLOG(2) << "Removing obsolete node " << node->name();
498 graph->RemoveNode(node);
499 }
500 frame->nodes.clear();
501 frame->parent->nodes.insert(while_node);
502
503 VLOG(2) << "Frame " << frame->name << " after: "
504 << DumpGraphToFile("functionalize_after", *graph, library);
505
506 return Status::OK();
507 }
508 } // namespace
509
FunctionalizeWhileLoop(Graph * graph,FunctionLibraryDefinition * library,const NodeFilter & node_filter)510 Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library,
511 const NodeFilter& node_filter) {
512 // Note: BuildControlFlowInfo() requires that the graph's source node is
513 // connected to all source nodes in the graph. Many graphs violate this
514 // invariant.
515 std::vector<ControlFlowInfo> cf_info;
516 std::vector<string> unreachable_nodes;
517 TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes));
518 if (!unreachable_nodes.empty()) {
519 return errors::InvalidArgument(
520 "The following nodes are unreachable from the source in the graph: ",
521 errors::FormatNodeNamesForError(unreachable_nodes));
522 }
523
524 // Builds Frames, indexed by name.
525 std::unordered_map<string, WhileLoopFrame> frames;
526 TF_RETURN_IF_ERROR(
527 ExtractWhileLoopFrames(cf_info, graph, &frames, node_filter));
528
529 // Adds frames with no children (i.e., the innermost frames) to a worklist.
530 std::deque<WhileLoopFrame*> worklist;
531 for (auto& frame : frames) {
532 if (frame.second.num_children == 0) {
533 worklist.push_back(&frame.second);
534 }
535 }
536
537 // Eliminate loops from innermost to outermost. Note that the precondition for
538 // `node_filter` in `FunctionalizeControlFlow` makes sure that this approach
539 // works.
540 while (!worklist.empty()) {
541 WhileLoopFrame* frame = worklist.front();
542 worklist.pop_front();
543 if (frame->parent == frame) {
544 // Skip the root frame.
545 continue;
546 }
547
548 TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library, node_filter));
549
550 // If the parent has no remaining children, add it to the worklist.
551 --frame->parent->num_children;
552 if (frame->parent->num_children == 0) {
553 worklist.push_back(frame->parent);
554 }
555 }
556
557 if (!node_filter) {
558 // There should be no cycle at this point, since while loops have been
559 // removed from graph. Check that the newly added While nodes don't feed
560 // into themselves.
561 for (const Node* node : graph->op_nodes()) {
562 if (node->def().op() == "While") {
563 TF_RETURN_WITH_CONTEXT_IF_ERROR(
564 CheckNodeNotInCycle(node, graph->num_node_ids()),
565 "Functionalizing loop failed.");
566 }
567 }
568 }
569
570 return Status::OK();
571 }
572
573 } // namespace tensorflow
574