• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
17 
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
20 #include "tensorflow/cc/ops/function_ops.h"
21 #include "tensorflow/cc/ops/functional_ops.h"
22 #include "tensorflow/cc/ops/resource_variable_ops.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
25 #include "tensorflow/compiler/tf2xla/test_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/framework/function.h"
30 #include "tensorflow/core/framework/graph_to_functiondef.h"
31 #include "tensorflow/core/framework/node_def_util.h"
32 #include "tensorflow/core/framework/op.h"
33 #include "tensorflow/core/graph/graph_def_builder.h"
34 #include "tensorflow/core/graph/validate.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/platform/test.h"
37 #include "tensorflow/core/public/version.h"
38 #include "tensorflow/core/util/dump_graph.h"
39 #include "tensorflow/core/util/equal_graph_def.h"
40 
41 namespace tensorflow {
42 namespace {
43 
44 // Returns the names of the "then" and "else" functions for the If node in a
45 // graph.
FindIfThenAndElse(const GraphDef & graph,string * op_name,NameAttrList * then_fn,NameAttrList * else_fn)46 Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
47                          NameAttrList* then_fn, NameAttrList* else_fn) {
48   for (const NodeDef& node : graph.node()) {
49     if (node.op() == "If") {
50       *op_name = node.name();
51       const NameAttrList* result;
52       TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result));
53       *then_fn = *result;
54       TF_RETURN_IF_ERROR(GetNodeAttr(node, "else_branch", &result));
55       *else_fn = *result;
56       return Status::OK();
57     }
58   }
59   return errors::NotFound("No If node found in graph");
60 }
61 
62 // Graph:
63 // x = array_ops.placeholder(dtypes.int32)
64 // y = array_ops.placeholder(dtypes.int32)
65 // z = control_flow_ops.cond(
66 //     math_ops.less(y, x), lambda: math_ops.multiply(y, 17),
67 //     lambda: math_ops.add(x, 23))
68 //
69 // Tests different node filters and functionalization inside of a function.
70 class ConditionalTestFixture
71     : public ::testing::TestWithParam<std::tuple<bool, bool>> {
72  protected:
SetUp()73   void SetUp() override {
74     restrict_to_tpu_nodes_ = std::get<0>(GetParam());
75     wrap_condition_in_function_ = std::get<1>(GetParam());
76   }
77   void RunTest();
78 
79  private:
80   void BuildCondGraph(Graph* cond_graph);
81   void CheckGraphDef(const GraphDef& graph_def,
82                      const FunctionLibraryDefinition& library);
83 
84   bool restrict_to_tpu_nodes_ = false;
85   bool wrap_condition_in_function_ = false;
86 };
87 
TEST_P(ConditionalTestFixture,ConditionalTests)88 TEST_P(ConditionalTestFixture, ConditionalTests) { RunTest(); }
89 
90 INSTANTIATE_TEST_SUITE_P(
91     FunctionalizeControlFlow, ConditionalTestFixture,
92     ::testing::Combine(::testing::Bool(), ::testing::Bool()),
93     [](const ::testing::TestParamInfo<ConditionalTestFixture::ParamType>&
__anon7e28f50b0202(const ::testing::TestParamInfo<ConditionalTestFixture::ParamType>& info) 94            info) {
95       bool restrict_to_tpu_nodes = std::get<0>(info.param);
96       bool wrap_cond_in_function = std::get<1>(info.param);
97       string name =
98           absl::StrCat(restrict_to_tpu_nodes ? "with_filter" : "without_filter",
99                        wrap_cond_in_function ? "_in_function" : "_in_graph");
100       return name;
101     });
102 
BuildCondGraph(Graph * cond_graph)103 void ConditionalTestFixture::BuildCondGraph(Graph* cond_graph) {
104   {
105     Scope scope = Scope::NewRootScope().ExitOnError();
106 
107     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
108     auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
109     auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
110     auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), less, less);
111 
112     auto identity_t =
113         ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_true);
114     auto seventeen = ops::Const<int32>(
115         scope.WithOpName("cond").WithControlDependencies(identity_t), 17);
116     auto switch_2 = ops::Switch(scope.WithOpName("cond/Switch"), y, less);
117     auto mul = ops::Multiply(scope.WithOpName("cond/Mul"), switch_2.output_true,
118                              seventeen);
119 
120     auto identity_f =
121         ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_false);
122     auto twenty_three = ops::Const<int32>(
123         scope.WithOpName("cond").WithControlDependencies(identity_f), 23);
124     auto switch_3 = ops::Switch(scope.WithOpName("cond/Switch"), x, less);
125     auto add = ops::Add(scope.WithOpName("cond/false/add"),
126                         switch_3.output_false, twenty_three);
127 
128     auto merge = ops::Merge(scope.WithOpName("cond/Merge"),
129                             std::initializer_list<Input>{add, mul});
130 
131     TF_EXPECT_OK(scope.ToGraph(cond_graph));
132 
133     // Set `_tpu_replicate` attribute for all nodes.
134     for (Node* n : cond_graph->nodes()) {
135       n->AddAttr("_tpu_replicate", "cluster");
136     }
137   }
138 }
139 
CheckGraphDef(const GraphDef & graph_def,const FunctionLibraryDefinition & library)140 void ConditionalTestFixture::CheckGraphDef(
141     const GraphDef& graph_def, const FunctionLibraryDefinition& library) {
142   string op_name;
143   NameAttrList then_fn;
144   NameAttrList else_fn;
145   TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn));
146   InstantiationResultForTest else_result;
147   TF_EXPECT_OK(
148       InstantiateFunctionForTest(else_fn.name(), library, &else_result));
149 
150   // Outer graph
151   {
152     Scope scope = Scope::NewRootScope().ExitOnError();
153     auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
154     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
155     auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
156     auto if_op =
157         ops::If(scope.WithOpName(op_name), less,
158                 std::initializer_list<Input>{less, y, x}, {DT_INT32}, then_fn,
159                 else_fn, ops::If::OutputShapes({PartialTensorShape()}));
160     auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]);
161     GraphDef expected;
162     TF_EXPECT_OK(scope.ToGraphDef(&expected));
163     TF_EXPECT_GRAPH_EQ(expected, graph_def);
164   }
165 
166   // then body.
167   {
168     Scope scope = Scope::NewRootScope().ExitOnError();
169     auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0);
170     auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
171     auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
172     auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0);
173     auto cond = ops::Const(
174         scope.WithOpName("cond").WithControlDependencies(identity), 17);
175     auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond);
176     auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), mul, 0);
177 
178     GraphDef expected;
179     TF_EXPECT_OK(scope.ToGraphDef(&expected));
180 
181     InstantiationResultForTest result;
182     TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result));
183 
184     EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
185     EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
186     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
187   }
188 
189   // else body.
190   {
191     Scope scope = Scope::NewRootScope().ExitOnError();
192     auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0);
193     auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
194     auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
195     auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0);
196     auto cond_1 = ops::Const(
197         scope.WithOpName("cond_1").WithControlDependencies(identity), 23);
198     auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1);
199     auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
200 
201     GraphDef expected;
202     TF_EXPECT_OK(scope.ToGraphDef(&expected));
203 
204     InstantiationResultForTest result;
205     TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result));
206 
207     EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
208     EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
209     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
210   }
211 }
212 
RunTest()213 void ConditionalTestFixture::RunTest() {
214   Graph graph(OpRegistry::Global());
215   if (wrap_condition_in_function_) {
216     // Wrap condition in a function which is called from `graph`.
217     Scope scope = Scope::NewRootScope().ExitOnError();
218     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
219 
220     Graph cond_graph(OpRegistry::Global());
221     BuildCondGraph(&cond_graph);
222 
223     FunctionDef cond_fdef;
224     TF_ASSERT_OK(GraphToFunctionDef(cond_graph, "cond_fn", &cond_fdef));
225 
226     FunctionDefLibrary fdef_lib;
227     *(fdef_lib.add_function()) = cond_fdef;
228     TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
229     NodeDef cond_fn;
230     cond_fn.set_name("cond_node");
231     cond_fn.set_op("cond_fn");
232     *(cond_fn.add_input()) = "source";
233     Status status;
234     scope.graph()->AddNode(cond_fn, &status);
235     TF_ASSERT_OK(status);
236     TF_ASSERT_OK(scope.ToGraph(&graph));
237   } else {
238     // Build condition in `graph`.
239     BuildCondGraph(&graph);
240   }
241   FunctionLibraryDefinition library(graph.flib_def());
242   // If `restrict_to_tpu_nodes_` is true let filter function return true for
243   // `_tpu_replicate` nodes.
244   NodeFilter node_filter =
245       restrict_to_tpu_nodes_
246           ? [](const Node* n) { return n->attrs().Find("_tpu_replicate"); }
247           : NodeFilter{};
248 
249   GraphDef optimized_graph_def;
250   graph.ToGraphDef(&optimized_graph_def);
251   TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef(
252       &optimized_graph_def, &library, node_filter,
253       /*include_functions=*/wrap_condition_in_function_));
254   TF_ASSERT_OK(FunctionalizeControlFlow(
255       &graph, &library, node_filter,
256       /*include_functions=*/wrap_condition_in_function_));
257 
258   if (wrap_condition_in_function_) {
259     // Check if function body was functionalized.
260     auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
261         /*device_mgr=*/nullptr, tensorflow::Env::Default(),
262         /*config=*/nullptr, TF_GRAPH_DEF_VERSION, &library,
263         tensorflow::OptimizerOptions());
264     FunctionLibraryRuntime* flr =
265         pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
266     FunctionLibraryRuntime::Handle handle;
267 
268     // Functionalized function name is the type string of `cond_node`.
269     string func_name;
270     for (Node* n : graph.nodes()) {
271       if (n->name() == "cond_node") {
272         func_name = n->type_string();
273         break;
274       }
275     }
276     TF_ASSERT_OK(flr->Instantiate(func_name, AttrSlice(), &handle));
277     const FunctionBody* body = flr->GetFunctionBody(handle);
278     GraphDef graph_def;
279     body->graph->ToGraphDef(&graph_def);
280     CheckGraphDef(graph_def, library);
281   } else {
282     // Check if graphs were functionalized.
283     CheckGraphDef(optimized_graph_def, library);
284     GraphDef converted_graph_def;
285     graph.ToGraphDef(&converted_graph_def);
286     CheckGraphDef(converted_graph_def, library);
287   }
288 }
289 
290 // Returns the names of the "cond" and "body" functions for the While node
291 // in a graph.
FindWhileCondAndBody(const GraphDef & graph,NameAttrList * cond,NameAttrList * body)292 Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
293                             NameAttrList* body) {
294   for (const NodeDef& node : graph.node()) {
295     if (node.op() == "While") {
296       const NameAttrList* result;
297       TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result));
298       *cond = *result;
299       TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result));
300       *body = *result;
301       return Status::OK();
302     }
303   }
304   return errors::NotFound("No While node found in graph");
305 }
306 
307 // Graph:
308 // x = array_ops.placeholder(dtypes.int32)
309 // y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
TEST(FunctionalizeControlFlow,OneLoopVar)310 TEST(FunctionalizeControlFlow, OneLoopVar) {
311   Graph graph(OpRegistry::Global());
312   {
313     Scope scope = Scope::NewRootScope().ExitOnError();
314 
315     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
316 
317     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
318     auto enter =
319         ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
320     // Add an unused Enter node. These should be ignored.
321     auto enter2 =
322         ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop");
323     auto merge = ops::Merge(scope.WithOpName("while/Merge"),
324                             std::initializer_list<Input>{enter, dummy});
325     auto ten = ops::Const<int32>(
326         scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
327         10);
328     auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
329     auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
330     auto switch_ =
331         ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
332     auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
333                                     switch_.output_false);
334     auto identity =
335         ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
336     auto one = ops::Const<int32>(
337         scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
338     auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
339     auto next_iteration =
340         ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
341 
342     auto sink = ops::Identity(scope.WithOpName("sink"), exit);
343 
344     // Remove the dummy node and add the loop backedge.
345     scope.graph()->RemoveNode(dummy.node());
346     scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
347 
348     TF_EXPECT_OK(scope.ToGraph(&graph));
349   }
350 
351   // Regression test: control edges from an Enter node to the graph sink should
352   // be ignored.
353   for (Node* n : graph.nodes()) {
354     if (n->name() == "while/Enter") {
355       graph.AddControlEdge(n, graph.sink_node());
356     }
357   }
358 
359   FunctionLibraryDefinition library(OpRegistry::Global(), {});
360   GraphDef optimized_graph_def;
361   graph.ToGraphDef(&optimized_graph_def);
362   TF_ASSERT_OK(
363       FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
364   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
365   GraphDef converted_graph_def;
366   graph.ToGraphDef(&converted_graph_def);
367 
368   for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
369     NameAttrList cond_fn, body_fn;
370     TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
371 
372     // Outer graph
373     {
374       Scope scope = Scope::NewRootScope().ExitOnError();
375       auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
376       auto while_op =
377           ops::While(scope.WithOpName("while/LoopCond"),
378                      std::initializer_list<Input>{source}, cond_fn, body_fn);
379       auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
380       GraphDef expected;
381       TF_EXPECT_OK(scope.ToGraphDef(&expected));
382       TF_EXPECT_GRAPH_EQ(expected, graph_def);
383     }
384 
385     // Condition graph
386     {
387       Scope scope = Scope::NewRootScope().ExitOnError();
388       auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
389       auto ten = ops::Const<int32>(
390           scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
391       auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
392       auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
393 
394       GraphDef expected;
395       TF_EXPECT_OK(scope.ToGraphDef(&expected));
396 
397       InstantiationResultForTest result;
398       TF_EXPECT_OK(
399           InstantiateFunctionForTest(cond_fn.name(), library, &result));
400 
401       EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
402       EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
403       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
404     }
405 
406     // Body graph.
407     {
408       Scope scope = Scope::NewRootScope().ExitOnError();
409       auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
410       auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
411       auto one = ops::Const<int32>(
412           scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
413       auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
414       auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
415 
416       GraphDef expected;
417       TF_EXPECT_OK(scope.ToGraphDef(&expected));
418 
419       InstantiationResultForTest result;
420       TF_EXPECT_OK(
421           InstantiateFunctionForTest(body_fn.name(), library, &result));
422 
423       EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
424       EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
425       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
426     }
427   }
428 }
429 
GetNoinlineFunctionDef()430 FunctionDef GetNoinlineFunctionDef() {
431   FunctionDef fdef = FunctionDefHelper::Create(
432       "increment_fn", {"x:int32"}, {"add:int32"}, {},
433       {
434           {{"add/y"}, "Const", {}, {{"dtype", DT_INT32}}},
435           {{"add_0"}, "Add", {"x", "add/y:output:0"}, {{"T", DT_INT32}}},
436       },
437       {{"add", "add_0:z:0"}});
438   (*fdef.mutable_attr())["_noinline"].set_b(true);
439   return fdef;
440 }
441 
442 // @function.Defun(noinline=True)
443 // def increment_fn(x):
444 //   return [x + 1]
445 // Define the above function, and add it to the given graph. It's used as the
446 // while loop body in NoinlineLoopBody test.
AddNoinlineFunctionToGraph(const string & node_name,Graph * graph)447 Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) {
448   FunctionDefLibrary fdef_lib;
449   *(fdef_lib.add_function()) = GetNoinlineFunctionDef();
450   TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib));
451   NodeDef increment_fn;
452   increment_fn.set_name(node_name);
453   increment_fn.set_op("increment_fn");
454   *increment_fn.add_input() = "while/Identity";
455   *increment_fn.add_input() = "^while/Identity";
456   Status status;
457   graph->AddNode(increment_fn, &status);
458   return status;
459 }
460 
461 // Graph:
462 // x = array_ops.placeholder(dtypes.int32)
463 // y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x])
TEST(FunctionalizeControlFlow,NoinlineLoopBody)464 TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
465   const string& noinline_node_name = "while/increment_fn";
466   Graph graph(OpRegistry::Global());
467   {
468     Scope scope = Scope::NewRootScope().ExitOnError();
469     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
470     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
471     auto enter = ops::internal::Enter(scope.WithOpName("while/Enter"), source,
472                                       "while/while_context");
473     auto merge = ops::Merge(scope.WithOpName("while/Merge"),
474                             std::initializer_list<Input>{enter, dummy});
475     auto ten = ops::Const<int32>(
476         scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
477         10);
478     auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
479     auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
480     auto switch_ =
481         ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
482     auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
483                                     switch_.output_false);
484     auto identity =
485         ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
486 
487     TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
488 
489     NodeDef next_iter;
490     next_iter.set_name("while/NextIteration");
491     next_iter.set_op("NextIteration");
492     *next_iter.add_input() = noinline_node_name;
493     (*next_iter.mutable_attr())["T"].set_type(DT_INT32);
494 
495     Status status;
496     Node* n = scope.graph()->AddNode(next_iter, &status);
497     TF_ASSERT_OK(status);
498 
499     // Remove the dummy node and add the loop backedge.
500     scope.graph()->RemoveNode(dummy.node());
501     scope.graph()->AddEdge(n, 0, merge.output.node(), 1);
502     TF_ASSERT_OK(scope.ToGraph(&graph));
503   }
504 
505   FunctionLibraryDefinition library(graph.flib_def());
506   GraphDef optimized_graph_def;
507   graph.ToGraphDef(&optimized_graph_def);
508 
509   *(optimized_graph_def.mutable_library()->add_function()) =
510       GetNoinlineFunctionDef();
511 
512   TF_ASSERT_OK(
513       FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
514   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
515   GraphDef converted_graph_def;
516   graph.ToGraphDef(&converted_graph_def);
517 
518   for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
519     NameAttrList cond_fn, body_fn;
520     TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
521 
522     // Outer graph
523     {
524       Scope scope = Scope::NewRootScope().ExitOnError();
525       auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
526       auto while_op =
527           ops::While(scope.WithOpName("while/LoopCond"),
528                      std::initializer_list<Input>{source}, cond_fn, body_fn);
529       GraphDef expected;
530       TF_ASSERT_OK(scope.ToGraphDef(&expected));
531       TF_EXPECT_GRAPH_EQ(expected, graph_def);
532     }
533 
534     // Body graph.
535     {
536       Scope scope = Scope::NewRootScope().ExitOnError();
537       auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
538       TF_ASSERT_OK(
539           AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
540       auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
541       NodeDef retval;
542       retval.set_name("retval0_RetVal");
543       retval.set_op(FunctionLibraryDefinition::kRetOp);
544       *retval.add_input() = noinline_node_name;
545       (*retval.mutable_attr())["T"].set_type(DT_INT32);
546       (*retval.mutable_attr())["index"].set_i(0);
547       Status status;
548       scope.graph()->AddNode(retval, &status);
549       TF_ASSERT_OK(status);
550 
551       GraphDef expected;
552       TF_ASSERT_OK(scope.ToGraphDef(&expected));
553 
554       InstantiationResultForTest result;
555       // Verify that increment_fn has been copied to library.
556       TF_EXPECT_OK(
557           InstantiateFunctionForTest(body_fn.name(), library, &result));
558 
559       EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
560       EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
561       // Ignore the function library when comparing the graphs.
562       expected.clear_library();
563       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
564     }
565   }
566 }
567 
// A GraphDef that still contains a node whose op is `increment_fn`, but whose
// embedded function library has been cleared, must fail functionalization
// with NOT_FOUND.
TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) {
  // Held by value: the original `const string&` bound a reference to a
  // temporary built from the literal.
  const string noinline_node_name = "while/increment_fn";
  Graph graph(OpRegistry::Global());
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
    // Named "while/Identity" because AddNoinlineFunctionToGraph lists an input
    // with that name on the call node it creates.
    auto identity = ops::Identity(scope.WithOpName("while/Identity"), source);
    TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
    TF_ASSERT_OK(scope.ToGraph(&graph));
  }

  FunctionLibraryDefinition library(graph.flib_def());
  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);
  // Drop the library embedded in the GraphDef. The test expects NOT_FOUND
  // even though `library` itself was built from the graph's library.
  graph_def.clear_library();

  Status status = FunctionalizeControlFlowForGraphDef(&graph_def, &library);
  EXPECT_EQ(tensorflow::error::NOT_FOUND, status.code()) << status;
}
587 
588 // Tests functionalizing OneLoopVar where the loop value is not used post the
589 // loop.
590 // Graph:
591 // x = array_ops.placeholder(dtypes.int32)
592 // control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
TEST(FunctionalizeControlFlow,OneLoopVarWithoutExit)593 TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) {
594   Graph graph(OpRegistry::Global());
595   {
596     Scope scope = Scope::NewRootScope().ExitOnError();
597 
598     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
599 
600     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
601     auto enter =
602         ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
603     auto merge = ops::Merge(scope.WithOpName("while/Merge"),
604                             std::initializer_list<Input>{enter, dummy});
605     auto ten = ops::Const<int32>(
606         scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
607         10);
608     auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
609     auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
610     auto switch_ =
611         ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
612     auto identity =
613         ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
614     auto one = ops::Const<int32>(
615         scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
616     auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
617     auto next_iteration =
618         ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
619 
620     // Remove the dummy node and add the loop backedge.
621     scope.graph()->RemoveNode(dummy.node());
622     scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
623 
624     TF_EXPECT_OK(scope.ToGraph(&graph));
625   }
626 
627   FunctionLibraryDefinition library(OpRegistry::Global(), {});
628   GraphDef optimized_graph_def;
629   graph.ToGraphDef(&optimized_graph_def);
630   TF_ASSERT_OK(
631       FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
632   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
633   GraphDef converted_graph_def;
634   graph.ToGraphDef(&converted_graph_def);
635 
636   for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
637     NameAttrList cond_fn, body_fn;
638     TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
639 
640     // Outer graph
641     {
642       Scope scope = Scope::NewRootScope().ExitOnError();
643       auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
644       auto while_op =
645           ops::While(scope.WithOpName("while/LoopCond"),
646                      std::initializer_list<Input>{source}, cond_fn, body_fn);
647       GraphDef expected;
648       TF_EXPECT_OK(scope.ToGraphDef(&expected));
649       TF_EXPECT_GRAPH_EQ(expected, graph_def);
650     }
651 
652     // Condition graph
653     {
654       Scope scope = Scope::NewRootScope().ExitOnError();
655       auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
656       auto ten = ops::Const<int32>(
657           scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
658       auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
659       auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
660 
661       GraphDef expected;
662       TF_EXPECT_OK(scope.ToGraphDef(&expected));
663 
664       InstantiationResultForTest result;
665       TF_EXPECT_OK(
666           InstantiateFunctionForTest(cond_fn.name(), library, &result));
667 
668       EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
669       EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
670       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
671     }
672 
673     // Body graph.
674     {
675       Scope scope = Scope::NewRootScope().ExitOnError();
676       auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
677       auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
678       auto one = ops::Const<int32>(
679           scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
680       auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
681       auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
682 
683       GraphDef expected;
684       TF_EXPECT_OK(scope.ToGraphDef(&expected));
685 
686       InstantiationResultForTest result;
687       TF_EXPECT_OK(
688           InstantiateFunctionForTest(body_fn.name(), library, &result));
689 
690       EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
691       EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
692       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
693     }
694   }
695 }
696 
697 // Graph:
698 // x = array_ops.placeholder(dtypes.int32)
699 // y = array_ops.placeholder(dtypes.int32)
700 // cond = lambda (i, j): i + 3 < 10
701 // body = lambda (i, j): (i < 10, j * 2)
702 // z = control_flow_ops.while_loop(cond, body, [x, y])
TEST(FunctionalizeControlFlow,TwoLoopVars)703 TEST(FunctionalizeControlFlow, TwoLoopVars) {
704   Graph graph(OpRegistry::Global());
705   {
706     Scope scope = Scope::NewRootScope().ExitOnError();
707 
708     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
709 
710     auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
711     auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
712     auto enter_x =
713         ops::internal::Enter(scope.WithOpName("while/Enter/x"), x, "aloop");
714     auto enter_y =
715         ops::internal::Enter(scope.WithOpName("while/Enter/y"), y, "aloop");
716     auto merge_x = ops::Merge(scope.WithOpName("while/Merge/x"),
717                               std::initializer_list<Input>{enter_x, dummy});
718     auto merge_y = ops::Merge(scope.WithOpName("while/Merge/y"),
719                               std::initializer_list<Input>{enter_y, dummy});
720 
721     // Loop condition
722     auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
723                                        .WithControlDependencies(merge_x.output),
724                                    3);
725     auto cond_add =
726         ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three);
727     auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
728                                      .WithControlDependencies(merge_x.output),
729                                  10);
730     auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
731     auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
732 
733     auto switch_x = ops::Switch(scope.WithOpName("while/Switch/x"),
734                                 merge_x.output, loop_cond);
735     auto switch_y = ops::Switch(scope.WithOpName("while/Switch/y"),
736                                 merge_y.output, loop_cond);
737 
738     auto exit_x = ops::internal::Exit(scope.WithOpName("while/Exit/x"),
739                                       switch_x.output_false);
740     auto exit_y = ops::internal::Exit(scope.WithOpName("while/Exit/y"),
741                                       switch_y.output_false);
742 
743     auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"),
744                                     switch_x.output_true);
745     auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"),
746                                     switch_y.output_true);
747 
748     auto one = ops::Const<int32>(
749         scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
750         1);
751     auto two = ops::Const<int32>(
752         scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
753         2);
754 
755     auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
756     auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
757     auto next_iteration_x =
758         ops::NextIteration(scope.WithOpName("while/NextIteration/x"), add);
759     auto next_iteration_y =
760         ops::NextIteration(scope.WithOpName("while/NextIteration/y"), mul);
761 
762     auto sink_x = ops::Identity(scope.WithOpName("sink_x"), exit_x);
763     auto sink_y = ops::Identity(scope.WithOpName("sink_y"), exit_y);
764 
765     // Remove the dummy node and add the loop backedges.
766     scope.graph()->RemoveNode(dummy.node());
767     scope.graph()->AddEdge(next_iteration_x.node(), 0, merge_x.output.node(),
768                            1);
769     scope.graph()->AddEdge(next_iteration_y.node(), 0, merge_y.output.node(),
770                            1);
771 
772     TF_EXPECT_OK(scope.ToGraph(&graph));
773   }
774 
775   FunctionLibraryDefinition library(OpRegistry::Global(), {});
776   GraphDef optimized_graph_def;
777   graph.ToGraphDef(&optimized_graph_def);
778   TF_ASSERT_OK(
779       FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
780   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
781   GraphDef converted_graph_def;
782   graph.ToGraphDef(&converted_graph_def);
783 
784   for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
785     NameAttrList cond_fn, body_fn;
786     TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
787 
788     // Outer graph.
789     {
790       Scope scope = Scope::NewRootScope().ExitOnError();
791       auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
792       auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
793       auto while_op =
794           ops::While(scope.WithOpName("while/LoopCond"),
795                      std::initializer_list<Input>{x, y}, cond_fn, body_fn);
796       auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]);
797       auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]);
798       GraphDef expected;
799       TF_EXPECT_OK(scope.ToGraphDef(&expected));
800       TF_EXPECT_GRAPH_EQ(expected, graph_def);
801     }
802 
803     // Condition graph.
804     {
805       Scope scope = Scope::NewRootScope().ExitOnError();
806       auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
807       auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
808       auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
809                                          .WithControlDependencies(arg0.output),
810                                      3);
811       auto cond_add =
812           ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three);
813       auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
814                                        .WithControlDependencies(arg0.output),
815                                    10);
816       auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
817       auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
818 
819       GraphDef expected;
820       TF_EXPECT_OK(scope.ToGraphDef(&expected));
821 
822       InstantiationResultForTest result;
823       TF_EXPECT_OK(
824           InstantiateFunctionForTest(cond_fn.name(), library, &result));
825 
826       EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
827       EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
828       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
829     }
830 
831     // Body graph.
832     {
833       Scope scope = Scope::NewRootScope().ExitOnError();
834       auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
835       auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
836 
837       auto identity_x =
838           ops::Identity(scope.WithOpName("while/Identity/x"), arg0);
839       auto identity_y =
840           ops::Identity(scope.WithOpName("while/Identity/y"), arg1);
841 
842       auto one = ops::Const<int32>(
843           scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
844           1);
845       auto two = ops::Const<int32>(
846           scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
847           2);
848 
849       auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
850       auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
851       auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
852       auto retval1 = ops::_Retval(scope.WithOpName("retval1_RetVal"), mul, 1);
853 
854       GraphDef expected;
855       TF_EXPECT_OK(scope.ToGraphDef(&expected));
856 
857       InstantiationResultForTest result;
858       TF_EXPECT_OK(
859           InstantiateFunctionForTest(body_fn.name(), library, &result));
860 
861       EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
862       EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types);
863       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
864     }
865   }
866 }
867 
868 // More complex example with nesting, loop-invariant arguments, and resource
869 // variables. Used for multiple tests with different node filters.
870 class ComplexTestFixture
871     : public ::testing::TestWithParam<std::tuple<bool, bool, bool>> {
872  protected:
SetUp()873   void SetUp() override {
874     restrict_to_tpu_nodes_ = std::get<0>(GetParam());
875     mark_inner_loop_tpu_ = std::get<1>(GetParam());
876     mark_outer_loop_tpu_ = std::get<2>(GetParam());
877   }
878   void RunTest();
879 
880  private:
881   void CheckOuterNodesFunctionalized(const GraphDef& graph_def,
882                                      const FunctionLibraryDefinition& library,
883                                      NameAttrList& inner_cond_fn,
884                                      NameAttrList& inner_body_fn);
885   void CheckInnerNodesFunctionalized(const GraphDef& graph_def,
886                                      const FunctionLibraryDefinition& library,
887                                      const NameAttrList& inner_cond_fn,
888                                      const NameAttrList& inner_body_fn);
889 
890   bool restrict_to_tpu_nodes_ = false;
891   bool mark_inner_loop_tpu_ = false;
892   bool mark_outer_loop_tpu_ = false;
893 };
894 
// Runs the nested-loop scenario for every parameter combination.
TEST_P(ComplexTestFixture, ComplexTests) {
  RunTest();
}
896 
897 INSTANTIATE_TEST_SUITE_P(
898     FunctionalizeControlFlow, ComplexTestFixture,
899     ::testing::Combine(::testing::Bool(), ::testing::Bool(), ::testing::Bool()),
__anon7e28f50b0402(const ::testing::TestParamInfo<ComplexTestFixture::ParamType>& info) 900     [](const ::testing::TestParamInfo<ComplexTestFixture::ParamType>& info) {
901       bool restrict_to_tpu_nodes = std::get<0>(info.param);
902       bool mark_inner_loop_tpu = std::get<1>(info.param);
903       bool mark_outer_loop_tpu = std::get<2>(info.param);
904 
905       string node_string;
906       if (mark_inner_loop_tpu && mark_outer_loop_tpu)
907         node_string = "both_loops_tpu";
908       else if (!mark_inner_loop_tpu && !mark_outer_loop_tpu)
909         node_string = "no_loop_tpu";
910       else
911         node_string = mark_inner_loop_tpu ? "inner_loop_tpu" : "outer_loop_tpu";
912 
913       string name = absl::StrCat(
914           restrict_to_tpu_nodes ? "restricted_" : "unrestricted_", node_string);
915       return name;
916     });
917 
RunTest()918 void ComplexTestFixture::RunTest() {
919   // Graph:
920   //
921   // accum = resource_variable_ops.ResourceVariable(1)
922   // x = array_ops.placeholder(2, dtype=dtypes.int32)
923   // y = 3 + x
924   //
925   // def inner_body(j, k):
926   //   add = state_ops.assign_add(accum, k * j + x)
927   //   with ops.control_dependencies([add]):
928   //     return [j + 1, k]
929   //
930   // def body(i):
931   //   m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body,
932   //                                   [1, y], name="inner")
933   //   with ops.control_dependencies(m):
934   //     return [i + 1]
935   //
936   // z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer")
937   Graph graph(OpRegistry::Global());
938   {
939     Scope scope = Scope::NewRootScope().ExitOnError();
940 
941     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
942 
943     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
944     auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
945     auto y = ops::Add(scope.WithOpName("y"), x, three);
946 
947     auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
948                                 TensorShape({}));
949 
950     // Outer loop
951     auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
952     auto enter_i =
953         ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer");
954     auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"),
955                               std::initializer_list<Input>{enter_i, dummy});
956     auto ten = ops::Const<int32>(scope.WithOpName("outer/Less/y")
957                                      .WithControlDependencies(merge_i.output),
958                                  10);
959     auto less_i =
960         ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten);
961     auto outer_loop_cond =
962         ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_i);
963     auto switch_i = ops::Switch(scope.WithOpName("outer/Switch"),
964                                 merge_i.output, outer_loop_cond);
965     auto exit_i = ops::internal::Exit(scope.WithOpName("outer/Exit"),
966                                       switch_i.output_false);
967     auto identity_i =
968         ops::Identity(scope.WithOpName("outer/Identity"), switch_i.output_true);
969 
970     auto enter_x_outer =
971         ops::internal::Enter(scope.WithOpName("outer/Enter_x"), x, "outer",
972                              ops::internal::Enter::Attrs().IsConstant(true));
973     auto enter_k_outer =
974         ops::internal::Enter(scope.WithOpName("outer/Enter_k"), y, "outer",
975                              ops::internal::Enter::Attrs().IsConstant(true));
976     auto enter_var_outer =
977         ops::internal::Enter(scope.WithOpName("outer/Enter_var"), var, "outer",
978                              ops::internal::Enter::Attrs().IsConstant(true));
979 
980     // Inner loop
981     auto one_j = ops::Const<int32>(
982         scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
983     auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"),
984                                         one_j, "inner");
985     auto enter_k =
986         ops::internal::Enter(scope.WithOpName("outer/inner/Enter_k")
987                                  .WithControlDependencies(identity_i),
988                              enter_k_outer, "inner");
989     auto enter_x = ops::internal::Enter(
990         scope.WithOpName("outer/inner/Enter_x"), enter_x_outer, "inner",
991         ops::internal::Enter::Attrs().IsConstant(true));
992     auto enter_var = ops::internal::Enter(
993         scope.WithOpName("outer/inner/Enter_var"), enter_var_outer, "inner",
994         ops::internal::Enter::Attrs().IsConstant(true));
995 
996     auto merge_j = ops::Merge(scope.WithOpName("outer/inner/Merge_j"),
997                               std::initializer_list<Input>{enter_j, dummy});
998     auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"),
999                               std::initializer_list<Input>{enter_k, dummy});
1000 
1001     auto five = ops::Const<int32>(scope.WithOpName("outer/inner/Five")
1002                                       .WithControlDependencies(merge_j.output),
1003                                   5);
1004     auto less_j =
1005         ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five);
1006     auto loop_cond =
1007         ops::LoopCond(scope.WithOpName("outer/inner/LoopCond"), less_j);
1008 
1009     auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"),
1010                                 merge_j.output, loop_cond);
1011     auto switch_k = ops::Switch(scope.WithOpName("outer/inner/Switch_k"),
1012                                 merge_k.output, loop_cond);
1013     auto exit_j = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_j"),
1014                                       switch_j.output_false);
1015     auto exit_k = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_k"),
1016                                       switch_k.output_false);
1017     auto identity_j = ops::Identity(scope.WithOpName("outer/inner/Identity_j"),
1018                                     switch_j.output_true);
1019     auto identity_k = ops::Identity(scope.WithOpName("outer/inner/Identity_k"),
1020                                     switch_k.output_true);
1021 
1022     // Variable update
1023     auto mul_jk =
1024         ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
1025     auto add_jkx =
1026         ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, enter_x);
1027     auto assign = ops::AssignAddVariableOp(
1028         scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx);
1029 
1030     auto one = ops::Const<int32>(
1031         scope.WithOpName("outer/inner/One")
1032             .WithControlDependencies(
1033                 absl::Span<const Operation>{assign.operation}),
1034         1);
1035     auto add_j =
1036         ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
1037 
1038     auto next_iteration_j = ops::NextIteration(
1039         scope.WithOpName("outer/inner/NextIteration_j"), add_j);
1040     auto next_iteration_k = ops::NextIteration(
1041         scope.WithOpName("outer/inner/NextIteration_k"), identity_k);
1042 
1043     // Body and backedge for outer loop.
1044     auto one_outer = ops::Const<int32>(
1045         scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
1046     auto add_i =
1047         ops::Add(scope.WithOpName("outer/add")
1048                      .WithControlDependencies(absl::Span<const Operation>{
1049                          exit_j.output.op(), exit_k.output.op()}),
1050                  identity_i, one_outer);
1051     auto next_iteration_i =
1052         ops::NextIteration(scope.WithOpName("outer/NextIteration"), add_i);
1053 
1054     auto sink = ops::Identity(scope.WithOpName("sink"), exit_i);
1055 
1056     // Remove the dummy node and add the loop backedge.
1057     scope.graph()->RemoveNode(dummy.node());
1058     scope.graph()->AddEdge(next_iteration_i.node(), 0, merge_i.output.node(),
1059                            1);
1060     scope.graph()->AddEdge(next_iteration_j.node(), 0, merge_j.output.node(),
1061                            1);
1062     scope.graph()->AddEdge(next_iteration_k.node(), 0, merge_k.output.node(),
1063                            1);
1064 
1065     TF_EXPECT_OK(scope.ToGraph(&graph));
1066   }
1067   // Add '_tpu_replicate' attributes as specified.
1068   for (Node* n : graph.nodes()) {
1069     string name = n->name();
1070     bool is_inner_node = name.find("outer/inner/") != string::npos;
1071     bool is_outer_node = !is_inner_node && name.find("outer/") != string::npos;
1072     if ((is_inner_node && mark_inner_loop_tpu_) ||
1073         (is_outer_node && mark_outer_loop_tpu_)) {
1074       n->AddAttr("_tpu_replicate", "cluster");
1075     }
1076   }
1077 
1078   FunctionLibraryDefinition library(OpRegistry::Global(), {});
1079   GraphDef orig_graph_def, optimized_graph_def;
1080   graph.ToGraphDef(&orig_graph_def);
1081   optimized_graph_def = orig_graph_def;
1082   // If `restrict_to_tpu_nodes_` is true let filter function return true for
1083   // `_tpu_replicate` nodes, otherwise don't set filter.
1084   NodeFilter node_filter =
1085       restrict_to_tpu_nodes_
1086           ? [](const Node* n) { return n->attrs().Find("_tpu_replicate"); }
1087           : NodeFilter{};
1088 
1089   Status status1 = FunctionalizeControlFlowForGraphDef(&optimized_graph_def,
1090                                                        &library, node_filter);
1091   Status status2 = FunctionalizeControlFlow(&graph, &library, node_filter);
1092   ASSERT_EQ(status1, status2);
1093   if (restrict_to_tpu_nodes_ && mark_outer_loop_tpu_ && !mark_inner_loop_tpu_) {
1094     // This case violates the precondition of `FunctionalizeControlFlow`, we
1095     // expect an internal error.
1096     ASSERT_EQ(errors::IsInternal(status1), true);
1097     return;
1098   } else {
1099     // Supported cases, no error expected.
1100     TF_ASSERT_OK(status1);
1101   }
1102 
1103   GraphDef optimized_converted_graph_def;
1104   graph.ToGraphDef(&optimized_converted_graph_def);
1105   for (const GraphDef& graph_def :
1106        {optimized_graph_def, optimized_converted_graph_def}) {
1107     NameAttrList inner_cond_fn, inner_body_fn;
1108     if (!restrict_to_tpu_nodes_ ||
1109         (restrict_to_tpu_nodes_ && mark_outer_loop_tpu_ &&
1110          mark_inner_loop_tpu_)) {
1111       // We expect that both inner and outer nodes have been functionalized.
1112       CheckOuterNodesFunctionalized(graph_def, library, inner_cond_fn,
1113                                     inner_body_fn);
1114       CheckInnerNodesFunctionalized(graph_def, library, inner_cond_fn,
1115                                     inner_body_fn);
1116     } else /*restrict_to_tpu_nodes_ == true*/ {
1117       if (!mark_outer_loop_tpu_ && !mark_inner_loop_tpu_) {
1118         // Graph has no TPU nodes so we expect no functionalization.
1119         TF_EXPECT_GRAPH_EQ(orig_graph_def, graph_def);
1120       } else if (!mark_outer_loop_tpu_ && mark_inner_loop_tpu_) {
1121         // We expect that only inner nodes have been functionalized.
1122         TF_EXPECT_OK(
1123             FindWhileCondAndBody(graph_def, &inner_cond_fn, &inner_body_fn));
1124         CheckInnerNodesFunctionalized(graph_def, library, inner_cond_fn,
1125                                       inner_body_fn);
1126       }
1127     }
1128   }
1129 }
1130 
CheckOuterNodesFunctionalized(const GraphDef & graph_def,const FunctionLibraryDefinition & library,NameAttrList & inner_cond_fn,NameAttrList & inner_body_fn)1131 void ComplexTestFixture::CheckOuterNodesFunctionalized(
1132     const GraphDef& graph_def, const FunctionLibraryDefinition& library,
1133     NameAttrList& inner_cond_fn, NameAttrList& inner_body_fn) {
1134   NameAttrList outer_cond_fn, outer_body_fn;
1135   TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn));
1136 
1137   // Outer graph.
1138   {
1139     Scope scope = Scope::NewRootScope().ExitOnError();
1140     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
1141     auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
1142     auto y = ops::Add(scope.WithOpName("y"), x, three);
1143 
1144     auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
1145                                 TensorShape({}));
1146 
1147     auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
1148 
1149     auto while_op = ops::While(scope.WithOpName("outer/LoopCond"),
1150                                std::initializer_list<Input>{zero, y, x, var},
1151                                outer_cond_fn, outer_body_fn);
1152     auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
1153     GraphDef expected;
1154     TF_EXPECT_OK(scope.ToGraphDef(&expected));
1155     TF_EXPECT_GRAPH_EQ(expected, graph_def);
1156   }
1157 
1158   // Outer condition graph.
1159   {
1160     Scope scope = Scope::NewRootScope().ExitOnError();
1161     auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1162     auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
1163     auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
1164     auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
1165 
1166     auto ten = ops::Const<int32>(
1167         scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output),
1168         10);
1169     auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten);
1170     auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
1171 
1172     GraphDef expected;
1173     TF_EXPECT_OK(scope.ToGraphDef(&expected));
1174 
1175     InstantiationResultForTest result;
1176     TF_EXPECT_OK(
1177         InstantiateFunctionForTest(outer_cond_fn.name(), library, &result));
1178 
1179     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1180               result.arg_types);
1181     EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
1182     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
1183   }
1184 
1185   // Outer body graph.
1186   {
1187     InstantiationResultForTest result;
1188     TF_EXPECT_OK(
1189         InstantiateFunctionForTest(outer_body_fn.name(), library, &result));
1190 
1191     // Find the inner condition and body names.
1192     TF_EXPECT_OK(
1193         FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn));
1194 
1195     Scope scope = Scope::NewRootScope().ExitOnError();
1196     auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1197     auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
1198     auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
1199     auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
1200 
1201     auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0);
1202     auto one_j = ops::Const<int32>(
1203         scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
1204     auto while_op =
1205         ops::While(scope.WithOpName("outer/inner/LoopCond"),
1206                    std::initializer_list<Input>{one_j, arg1, arg2, arg3},
1207                    inner_cond_fn, inner_body_fn);
1208 
1209     auto one_outer = ops::Const<int32>(
1210         scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
1211     auto add_i =
1212         ops::Add(scope.WithOpName("outer/add")
1213                      .WithControlDependencies(absl::Span<const Operation>{
1214                          while_op[0].op(), while_op[1].op()}),
1215                  identity_i, one_outer);
1216 
1217     auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add_i, 0);
1218     auto retval1 = ops::_Retval(scope.WithOpName("retval1_RetVal"), arg1, 1);
1219     auto retval2 = ops::_Retval(scope.WithOpName("retval2_RetVal"), arg2, 2);
1220     auto retval3 = ops::_Retval(scope.WithOpName("retval3_RetVal"), arg3, 3);
1221 
1222     GraphDef expected;
1223     TF_EXPECT_OK(scope.ToGraphDef(&expected));
1224 
1225     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1226               result.arg_types);
1227     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1228               result.ret_types);
1229     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
1230   }
1231 }
1232 
CheckInnerNodesFunctionalized(const GraphDef & graph_def,const FunctionLibraryDefinition & library,const NameAttrList & inner_cond_fn,const NameAttrList & inner_body_fn)1233 void ComplexTestFixture::CheckInnerNodesFunctionalized(
1234     const GraphDef& graph_def, const FunctionLibraryDefinition& library,
1235     const NameAttrList& inner_cond_fn, const NameAttrList& inner_body_fn) {
1236   // Inner condition graph.
1237   {
1238     Scope scope = Scope::NewRootScope().ExitOnError();
1239     auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1240     auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
1241     auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
1242     auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
1243 
1244     auto five = ops::Const<int32>(
1245         scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5);
1246     auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five);
1247     auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less_j, 0);
1248 
1249     GraphDef expected;
1250     TF_EXPECT_OK(scope.ToGraphDef(&expected));
1251 
1252     InstantiationResultForTest result;
1253     TF_EXPECT_OK(
1254         InstantiateFunctionForTest(inner_cond_fn.name(), library, &result));
1255 
1256     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1257               result.arg_types);
1258     EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
1259     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
1260   }
1261 
1262   // Inner body graph.
1263   {
1264     Scope scope = Scope::NewRootScope().ExitOnError();
1265     auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1266     auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
1267     auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
1268     auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
1269 
1270     auto identity_j =
1271         ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0);
1272     auto identity_k =
1273         ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1);
1274 
1275     auto mul_jk =
1276         ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
1277     auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2);
1278     auto assign = ops::AssignAddVariableOp(
1279         scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);
1280 
1281     auto one = ops::Const<int32>(
1282         scope.WithOpName("outer/inner/One")
1283             .WithControlDependencies(
1284                 absl::Span<const Operation>{assign.operation}),
1285         1);
1286     auto add_j =
1287         ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
1288 
1289     auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add_j, 0);
1290     auto retval1 =
1291         ops::_Retval(scope.WithOpName("retval1_RetVal"), identity_k, 1);
1292     auto retval2 = ops::_Retval(scope.WithOpName("retval2_RetVal"), arg2, 2);
1293     auto retval3 = ops::_Retval(scope.WithOpName("retval3_RetVal"), arg3, 3);
1294 
1295     GraphDef expected;
1296     TF_EXPECT_OK(scope.ToGraphDef(&expected));
1297 
1298     InstantiationResultForTest result;
1299     TF_EXPECT_OK(
1300         InstantiateFunctionForTest(inner_body_fn.name(), library, &result));
1301 
1302     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1303               result.arg_types);
1304     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1305               result.ret_types);
1306     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
1307   }
1308 }
1309 
1310 }  // namespace
1311 }  // namespace tensorflow
1312