• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
17 
18 #include <string>
19 
20 #include "tensorflow/cc/framework/ops.h"
21 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
22 #include "tensorflow/cc/ops/function_ops.h"
23 #include "tensorflow/cc/ops/functional_ops.h"
24 #include "tensorflow/cc/ops/resource_variable_ops.h"
25 #include "tensorflow/cc/ops/standard_ops.h"
26 #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
27 #include "tensorflow/compiler/tf2xla/test_util.h"
28 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/core/common_runtime/function.h"
31 #include "tensorflow/core/common_runtime/graph_constructor.h"
32 #include "tensorflow/core/framework/function.h"
33 #include "tensorflow/core/framework/graph_to_functiondef.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/op.h"
36 #include "tensorflow/core/graph/graph_def_builder.h"
37 #include "tensorflow/core/graph/validate.h"
38 #include "tensorflow/core/lib/core/status_test_util.h"
39 #include "tensorflow/core/platform/test.h"
40 #include "tensorflow/core/public/version.h"
41 #include "tensorflow/core/util/dump_graph.h"
42 #include "tensorflow/core/util/equal_graph_def.h"
43 
44 namespace tensorflow {
45 namespace {
46 
47 // Returns the names of the "then" and "else" functions for the If node in a
48 // graph.
FindIfThenAndElse(const GraphDef & graph,string * op_name,NameAttrList * then_fn,NameAttrList * else_fn)49 Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
50                          NameAttrList* then_fn, NameAttrList* else_fn) {
51   for (const NodeDef& node : graph.node()) {
52     if (node.op() == "If") {
53       *op_name = node.name();
54       const NameAttrList* result;
55       TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result));
56       *then_fn = *result;
57       TF_RETURN_IF_ERROR(GetNodeAttr(node, "else_branch", &result));
58       *else_fn = *result;
59       return OkStatus();
60     }
61   }
62   return errors::NotFound("No If node found in graph");
63 }
64 
65 // Graph:
66 // x = array_ops.placeholder(dtypes.int32)
67 // y = array_ops.placeholder(dtypes.int32)
68 // z = control_flow_ops.cond(
69 //     math_ops.less(y, x), lambda: math_ops.multiply(y, 17),
70 //     lambda: math_ops.add(x, 23))
71 //
72 // Tests different node filters and functionalization inside of a function.
73 class ConditionalTestFixture
74     : public ::testing::TestWithParam<std::tuple<bool, bool>> {
75  protected:
SetUp()76   void SetUp() override {
77     restrict_to_tpu_nodes_ = std::get<0>(GetParam());
78     wrap_condition_in_function_ = std::get<1>(GetParam());
79   }
80   void RunTest();
81 
82  private:
83   void BuildCondGraph(Graph* cond_graph);
84   void CheckGraphDef(const GraphDef& graph_def,
85                      const FunctionLibraryDefinition& library);
86 
87   bool restrict_to_tpu_nodes_ = false;
88   bool wrap_condition_in_function_ = false;
89 };
90 
TEST_P(ConditionalTestFixture, ConditionalTests) {
  RunTest();
}

// Runs all four (filter, placement) combinations. Each instance is named
// "<with|without>_filter_in_<function|graph>".
INSTANTIATE_TEST_SUITE_P(
    FunctionalizeControlFlow, ConditionalTestFixture,
    ::testing::Combine(::testing::Bool(), ::testing::Bool()),
    [](const ::testing::TestParamInfo<ConditionalTestFixture::ParamType>&
           info) {
      const char* filter_part =
          std::get<0>(info.param) ? "with_filter" : "without_filter";
      const char* placement_part =
          std::get<1>(info.param) ? "_in_function" : "_in_graph";
      return string(absl::StrCat(filter_part, placement_part));
    });
105 
BuildCondGraph(Graph * cond_graph)106 void ConditionalTestFixture::BuildCondGraph(Graph* cond_graph) {
107   {
108     Scope scope = Scope::NewRootScope().ExitOnError();
109 
110     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
111     auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
112     auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
113     auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), less, less);
114 
115     auto identity_t =
116         ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_true);
117     auto seventeen = ops::Const<int32>(
118         scope.WithOpName("cond").WithControlDependencies(identity_t), 17);
119     auto switch_2 = ops::Switch(scope.WithOpName("cond/Switch"), y, less);
120     auto mul = ops::Multiply(scope.WithOpName("cond/Mul"), switch_2.output_true,
121                              seventeen);
122 
123     auto identity_f =
124         ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_false);
125     auto twenty_three = ops::Const<int32>(
126         scope.WithOpName("cond").WithControlDependencies(identity_f), 23);
127     auto switch_3 = ops::Switch(scope.WithOpName("cond/Switch"), x, less);
128     auto add = ops::Add(scope.WithOpName("cond/false/add"),
129                         switch_3.output_false, twenty_three);
130 
131     auto merge = ops::Merge(scope.WithOpName("cond/Merge"),
132                             std::initializer_list<Input>{add, mul});
133 
134     TF_EXPECT_OK(scope.ToGraph(cond_graph));
135 
136     // Set all attributes that need propagation for all nodes. This is to test
137     // if propagation works. Note that this includes `_tpu_replicate`.
138     for (Node* n : cond_graph->nodes()) {
139       std::string dummy_value = "value";
140       for (absl::string_view attr_name : kAttrsToPropagate) {
141         n->AddAttr(std::string(attr_name), dummy_value);
142       }
143     }
144   }
145 }
146 
CheckGraphDef(const GraphDef & graph_def,const FunctionLibraryDefinition & library)147 void ConditionalTestFixture::CheckGraphDef(
148     const GraphDef& graph_def, const FunctionLibraryDefinition& library) {
149   string op_name;
150   NameAttrList then_fn;
151   NameAttrList else_fn;
152   TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn));
153   InstantiationResultForTest else_result;
154   TF_EXPECT_OK(
155       InstantiateFunctionForTest(else_fn.name(), library, &else_result));
156 
157   // Outer graph
158   {
159     Scope scope = Scope::NewRootScope().ExitOnError();
160     auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
161     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
162     auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
163     auto if_op =
164         ops::If(scope.WithOpName(op_name), less,
165                 std::initializer_list<Input>{less, y, x}, {DT_INT32}, then_fn,
166                 else_fn, ops::If::OutputShapes({PartialTensorShape()}));
167     auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]);
168     GraphDef expected;
169     TF_EXPECT_OK(scope.ToGraphDef(&expected));
170     TF_EXPECT_GRAPH_EQ(expected, graph_def);
171   }
172 
173   // then body.
174   {
175     Scope scope = Scope::NewRootScope().ExitOnError();
176     auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0);
177     auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
178     auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
179     auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0);
180     auto cond = ops::Const(
181         scope.WithOpName("cond").WithControlDependencies(identity), 17);
182     auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond);
183     auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), mul, 0);
184 
185     GraphDef expected;
186     TF_EXPECT_OK(scope.ToGraphDef(&expected));
187 
188     InstantiationResultForTest result;
189     TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result));
190 
191     EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
192     EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
193     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
194   }
195 
196   // else body.
197   {
198     Scope scope = Scope::NewRootScope().ExitOnError();
199     auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0);
200     auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
201     auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
202     auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0);
203     auto cond_1 = ops::Const(
204         scope.WithOpName("cond_1").WithControlDependencies(identity), 23);
205     auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1);
206     auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
207 
208     GraphDef expected;
209     TF_EXPECT_OK(scope.ToGraphDef(&expected));
210 
211     InstantiationResultForTest result;
212     TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result));
213 
214     EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
215     EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
216     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
217 
218     // Check that internal attributes were correctly propagated to `If` node
219     // (such attributes are ignored in the above graph equality check).
220     for (const NodeDef& node : graph_def.node()) {
221       if (node.op() == "If") {
222         for (absl::string_view attr_name : kAttrsToPropagate) {
223           std::string attr_val;
224           TF_EXPECT_OK(GetNodeAttr(node, attr_name, &attr_val));
225           EXPECT_EQ(attr_val, "value");
226         }
227       }
228     }
229   }
230 }
231 
RunTest()232 void ConditionalTestFixture::RunTest() {
233   Graph graph(OpRegistry::Global());
234   if (wrap_condition_in_function_) {
235     // Wrap condition in a function which is called from `graph`.
236     Scope scope = Scope::NewRootScope().ExitOnError();
237     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
238 
239     Graph cond_graph(OpRegistry::Global());
240     BuildCondGraph(&cond_graph);
241 
242     FunctionDef cond_fdef;
243     TF_ASSERT_OK(GraphToFunctionDef(cond_graph, "cond_fn", &cond_fdef));
244 
245     FunctionDefLibrary fdef_lib;
246     *(fdef_lib.add_function()) = cond_fdef;
247     TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
248     NodeDef cond_fn;
249     cond_fn.set_name("cond_node");
250     cond_fn.set_op("cond_fn");
251     *(cond_fn.add_input()) = "source";
252     Status status;
253     scope.graph()->AddNode(cond_fn, &status);
254     TF_ASSERT_OK(status);
255     TF_ASSERT_OK(scope.ToGraph(&graph));
256   } else {
257     // Build condition in `graph`.
258     BuildCondGraph(&graph);
259   }
260   FunctionLibraryDefinition library(graph.flib_def());
261   // If `restrict_to_tpu_nodes_` is true let filter function return true for
262   // `_tpu_replicate` nodes.
263   NodeFilter node_filter =
264       restrict_to_tpu_nodes_
265           ? [](const Node* n) { return n->attrs().Find("_tpu_replicate"); }
266           : NodeFilter{};
267 
268   GraphDef optimized_graph_def;
269   graph.ToGraphDef(&optimized_graph_def);
270   TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef(
271       &optimized_graph_def, &library, node_filter,
272       /*include_functions=*/wrap_condition_in_function_));
273   TF_ASSERT_OK(FunctionalizeControlFlow(
274       &graph, &library, node_filter,
275       /*include_functions=*/wrap_condition_in_function_));
276 
277   if (wrap_condition_in_function_) {
278     // Check if function body was functionalized.
279     auto pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
280         /*device_mgr=*/nullptr, tensorflow::Env::Default(),
281         /*config=*/nullptr, TF_GRAPH_DEF_VERSION, &library,
282         tensorflow::OptimizerOptions());
283     FunctionLibraryRuntime* flr =
284         pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
285     FunctionLibraryRuntime::Handle handle;
286 
287     // Functionalized function name is the type string of `cond_node`.
288     string func_name;
289     for (Node* n : graph.nodes()) {
290       if (n->name() == "cond_node") {
291         func_name = n->type_string();
292         break;
293       }
294     }
295     TF_ASSERT_OK(flr->Instantiate(func_name, AttrSlice(), &handle));
296     const FunctionBody* body = flr->GetFunctionBody(handle);
297     GraphDef graph_def;
298     body->graph->ToGraphDef(&graph_def);
299     CheckGraphDef(graph_def, library);
300   } else {
301     // Check if graphs were functionalized.
302     CheckGraphDef(optimized_graph_def, library);
303     GraphDef converted_graph_def;
304     graph.ToGraphDef(&converted_graph_def);
305     CheckGraphDef(converted_graph_def, library);
306   }
307 }
308 
309 // Returns the names of the "cond" and "body" functions for the While node
310 // in a graph.
FindWhileCondAndBody(const GraphDef & graph,NameAttrList * cond,NameAttrList * body)311 Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
312                             NameAttrList* body) {
313   for (const NodeDef& node : graph.node()) {
314     if (node.op() == "While") {
315       const NameAttrList* result;
316       TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result));
317       *cond = *result;
318       TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result));
319       *body = *result;
320       return OkStatus();
321     }
322   }
323   return errors::NotFound("No While node found in graph");
324 }
325 
326 // Graph:
327 // x = array_ops.placeholder(dtypes.int32)
328 // y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
TEST(FunctionalizeControlFlow,OneLoopVar)329 TEST(FunctionalizeControlFlow, OneLoopVar) {
330   Graph graph(OpRegistry::Global());
331   {
332     Scope scope = Scope::NewRootScope().ExitOnError();
333 
334     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
335 
336     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
337     auto enter =
338         ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
339     // Add an unused Enter node. These should be ignored.
340     auto enter2 =
341         ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop");
342     auto merge = ops::Merge(scope.WithOpName("while/Merge"),
343                             std::initializer_list<Input>{enter, dummy});
344     auto ten = ops::Const<int32>(
345         scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
346         10);
347     auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
348     auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
349     auto switch_ =
350         ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
351     auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
352                                     switch_.output_false);
353     auto identity =
354         ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
355     auto one = ops::Const<int32>(
356         scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
357     auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
358     auto next_iteration =
359         ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
360 
361     auto sink = ops::Identity(scope.WithOpName("sink"), exit);
362 
363     // Remove the dummy node and add the loop backedge.
364     scope.graph()->RemoveNode(dummy.node());
365     scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
366 
367     TF_EXPECT_OK(scope.ToGraph(&graph));
368   }
369 
370   // Regression test: control edges from an Enter node to the graph sink should
371   // be ignored.
372   for (Node* n : graph.nodes()) {
373     if (n->name() == "while/Enter") {
374       graph.AddControlEdge(n, graph.sink_node());
375     }
376   }
377 
378   FunctionLibraryDefinition library(OpRegistry::Global(), {});
379   GraphDef optimized_graph_def;
380   graph.ToGraphDef(&optimized_graph_def);
381   TF_ASSERT_OK(
382       FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
383   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
384   GraphDef converted_graph_def;
385   graph.ToGraphDef(&converted_graph_def);
386 
387   for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
388     NameAttrList cond_fn, body_fn;
389     TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
390 
391     // Outer graph
392     {
393       Scope scope = Scope::NewRootScope().ExitOnError();
394       auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
395       auto while_op =
396           ops::While(scope.WithOpName("while/LoopCond"),
397                      std::initializer_list<Input>{source}, cond_fn, body_fn);
398       auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
399       GraphDef expected;
400       TF_EXPECT_OK(scope.ToGraphDef(&expected));
401       TF_EXPECT_GRAPH_EQ(expected, graph_def);
402     }
403 
404     // Condition graph
405     {
406       Scope scope = Scope::NewRootScope().ExitOnError();
407       auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
408       auto ten = ops::Const<int32>(
409           scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
410       auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
411       auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
412 
413       GraphDef expected;
414       TF_EXPECT_OK(scope.ToGraphDef(&expected));
415 
416       InstantiationResultForTest result;
417       TF_EXPECT_OK(
418           InstantiateFunctionForTest(cond_fn.name(), library, &result));
419 
420       EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
421       EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
422       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
423     }
424 
425     // Body graph.
426     {
427       Scope scope = Scope::NewRootScope().ExitOnError();
428       auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
429       auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
430       auto one = ops::Const<int32>(
431           scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
432       auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
433       auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
434 
435       GraphDef expected;
436       TF_EXPECT_OK(scope.ToGraphDef(&expected));
437 
438       InstantiationResultForTest result;
439       TF_EXPECT_OK(
440           InstantiateFunctionForTest(body_fn.name(), library, &result));
441 
442       EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
443       EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
444       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
445     }
446   }
447 }
448 
GetNoinlineFunctionDef()449 FunctionDef GetNoinlineFunctionDef() {
450   FunctionDef fdef = FunctionDefHelper::Create(
451       "increment_fn", {"x:int32"}, {"add:int32"}, {},
452       {
453           {{"add/y"}, "Const", {}, {{"dtype", DT_INT32}}},
454           {{"add_0"}, "Add", {"x", "add/y:output:0"}, {{"T", DT_INT32}}},
455       },
456       {{"add", "add_0:z:0"}});
457   (*fdef.mutable_attr())["_noinline"].set_b(true);
458   return fdef;
459 }
460 
461 // @function.Defun(noinline=True)
462 // def increment_fn(x):
463 //   return [x + 1]
464 // Define the above function, and add it to the given graph. It's used as the
465 // while loop body in NoinlineLoopBody test.
AddNoinlineFunctionToGraph(const string & node_name,Graph * graph)466 Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) {
467   FunctionDefLibrary fdef_lib;
468   *(fdef_lib.add_function()) = GetNoinlineFunctionDef();
469   TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib));
470   NodeDef increment_fn;
471   increment_fn.set_name(node_name);
472   increment_fn.set_op("increment_fn");
473   *increment_fn.add_input() = "while/Identity";
474   *increment_fn.add_input() = "^while/Identity";
475   Status status;
476   graph->AddNode(increment_fn, &status);
477   return status;
478 }
479 
480 // Graph:
481 // x = array_ops.placeholder(dtypes.int32)
482 // y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x])
TEST(FunctionalizeControlFlow,NoinlineLoopBody)483 TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
484   const string& noinline_node_name = "while/increment_fn";
485   Graph graph(OpRegistry::Global());
486   {
487     Scope scope = Scope::NewRootScope().ExitOnError();
488     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
489     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
490     auto enter = ops::internal::Enter(scope.WithOpName("while/Enter"), source,
491                                       "while/while_context");
492     auto merge = ops::Merge(scope.WithOpName("while/Merge"),
493                             std::initializer_list<Input>{enter, dummy});
494     auto ten = ops::Const<int32>(
495         scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
496         10);
497     auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
498     auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
499     auto switch_ =
500         ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
501     auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
502                                     switch_.output_false);
503     auto identity =
504         ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
505 
506     TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
507 
508     NodeDef next_iter;
509     next_iter.set_name("while/NextIteration");
510     next_iter.set_op("NextIteration");
511     *next_iter.add_input() = noinline_node_name;
512     (*next_iter.mutable_attr())["T"].set_type(DT_INT32);
513 
514     Status status;
515     Node* n = scope.graph()->AddNode(next_iter, &status);
516     TF_ASSERT_OK(status);
517 
518     // Remove the dummy node and add the loop backedge.
519     scope.graph()->RemoveNode(dummy.node());
520     scope.graph()->AddEdge(n, 0, merge.output.node(), 1);
521     TF_ASSERT_OK(scope.ToGraph(&graph));
522   }
523 
524   FunctionLibraryDefinition library(graph.flib_def());
525   GraphDef optimized_graph_def;
526   graph.ToGraphDef(&optimized_graph_def);
527 
528   *(optimized_graph_def.mutable_library()->add_function()) =
529       GetNoinlineFunctionDef();
530 
531   TF_ASSERT_OK(
532       FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
533   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
534   GraphDef converted_graph_def;
535   graph.ToGraphDef(&converted_graph_def);
536 
537   for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
538     NameAttrList cond_fn, body_fn;
539     TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
540 
541     // Outer graph
542     {
543       Scope scope = Scope::NewRootScope().ExitOnError();
544       auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
545       auto while_op =
546           ops::While(scope.WithOpName("while/LoopCond"),
547                      std::initializer_list<Input>{source}, cond_fn, body_fn);
548       GraphDef expected;
549       TF_ASSERT_OK(scope.ToGraphDef(&expected));
550       TF_EXPECT_GRAPH_EQ(expected, graph_def);
551     }
552 
553     // Body graph.
554     {
555       Scope scope = Scope::NewRootScope().ExitOnError();
556       auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
557       TF_ASSERT_OK(
558           AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
559       auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
560       NodeDef retval;
561       retval.set_name("retval0_RetVal");
562       retval.set_op(FunctionLibraryDefinition::kRetOp);
563       *retval.add_input() = noinline_node_name;
564       (*retval.mutable_attr())["T"].set_type(DT_INT32);
565       (*retval.mutable_attr())["index"].set_i(0);
566       Status status;
567       scope.graph()->AddNode(retval, &status);
568       TF_ASSERT_OK(status);
569 
570       GraphDef expected;
571       TF_ASSERT_OK(scope.ToGraphDef(&expected));
572 
573       InstantiationResultForTest result;
574       // Verify that increment_fn has been copied to library.
575       TF_EXPECT_OK(
576           InstantiateFunctionForTest(body_fn.name(), library, &result));
577 
578       EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
579       EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
580       // Ignore the function library when comparing the graphs.
581       expected.clear_library();
582       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
583     }
584   }
585 }
586 
// Functionalizing a GraphDef whose serialized library was cleared, while one of
// its nodes calls `increment_fn`, must fail with NOT_FOUND.
TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) {
  // Held by value rather than as a const reference bound to a temporary.
  const string noinline_node_name = "while/increment_fn";
  Graph graph(OpRegistry::Global());
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
    auto identity = ops::Identity(scope.WithOpName("while/Identity"), source);
    TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
    TF_ASSERT_OK(scope.ToGraph(&graph));
  }

  FunctionLibraryDefinition library(graph.flib_def());
  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);
  // Drop the GraphDef's own function library; `library` (built from
  // graph.flib_def() above) is left untouched.
  graph_def.clear_library();

  Status status = FunctionalizeControlFlowForGraphDef(&graph_def, &library);
  EXPECT_EQ(tensorflow::error::NOT_FOUND, status.code()) << status.ToString();
}
606 
607 // Tests functionalizing OneLoopVar where the loop value is not used post the
608 // loop.
609 // Graph:
610 // x = array_ops.placeholder(dtypes.int32)
611 // control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
TEST(FunctionalizeControlFlow,OneLoopVarWithoutExit)612 TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) {
613   Graph graph(OpRegistry::Global());
614   {
615     Scope scope = Scope::NewRootScope().ExitOnError();
616 
617     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
618 
619     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
620     auto enter =
621         ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
622     auto merge = ops::Merge(scope.WithOpName("while/Merge"),
623                             std::initializer_list<Input>{enter, dummy});
624     auto ten = ops::Const<int32>(
625         scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
626         10);
627     auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
628     auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
629     auto switch_ =
630         ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
631     auto identity =
632         ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
633     auto one = ops::Const<int32>(
634         scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
635     auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
636     auto next_iteration =
637         ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
638 
639     // Remove the dummy node and add the loop backedge.
640     scope.graph()->RemoveNode(dummy.node());
641     scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
642 
643     TF_EXPECT_OK(scope.ToGraph(&graph));
644   }
645 
646   FunctionLibraryDefinition library(OpRegistry::Global(), {});
647   GraphDef optimized_graph_def;
648   graph.ToGraphDef(&optimized_graph_def);
649   TF_ASSERT_OK(
650       FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
651   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
652   GraphDef converted_graph_def;
653   graph.ToGraphDef(&converted_graph_def);
654 
655   for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
656     NameAttrList cond_fn, body_fn;
657     TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
658 
659     // Outer graph
660     {
661       Scope scope = Scope::NewRootScope().ExitOnError();
662       auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
663       auto while_op =
664           ops::While(scope.WithOpName("while/LoopCond"),
665                      std::initializer_list<Input>{source}, cond_fn, body_fn);
666       GraphDef expected;
667       TF_EXPECT_OK(scope.ToGraphDef(&expected));
668       TF_EXPECT_GRAPH_EQ(expected, graph_def);
669     }
670 
671     // Condition graph
672     {
673       Scope scope = Scope::NewRootScope().ExitOnError();
674       auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
675       auto ten = ops::Const<int32>(
676           scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
677       auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
678       auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
679 
680       GraphDef expected;
681       TF_EXPECT_OK(scope.ToGraphDef(&expected));
682 
683       InstantiationResultForTest result;
684       TF_EXPECT_OK(
685           InstantiateFunctionForTest(cond_fn.name(), library, &result));
686 
687       EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
688       EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
689       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
690     }
691 
692     // Body graph.
693     {
694       Scope scope = Scope::NewRootScope().ExitOnError();
695       auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
696       auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
697       auto one = ops::Const<int32>(
698           scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
699       auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
700       auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
701 
702       GraphDef expected;
703       TF_EXPECT_OK(scope.ToGraphDef(&expected));
704 
705       InstantiationResultForTest result;
706       TF_EXPECT_OK(
707           InstantiateFunctionForTest(body_fn.name(), library, &result));
708 
709       EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
710       EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
711       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
712     }
713   }
714 }
715 
716 // Graph:
717 // x = array_ops.placeholder(dtypes.int32)
718 // y = array_ops.placeholder(dtypes.int32)
719 // cond = lambda (i, j): i + 3 < 10
720 // body = lambda (i, j): (i < 10, j * 2)
721 // z = control_flow_ops.while_loop(cond, body, [x, y])
TEST(FunctionalizeControlFlow,TwoLoopVars)722 TEST(FunctionalizeControlFlow, TwoLoopVars) {
723   Graph graph(OpRegistry::Global());
724   {
725     Scope scope = Scope::NewRootScope().ExitOnError();
726 
727     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
728 
729     auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
730     auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
731     auto enter_x =
732         ops::internal::Enter(scope.WithOpName("while/Enter/x"), x, "aloop");
733     auto enter_y =
734         ops::internal::Enter(scope.WithOpName("while/Enter/y"), y, "aloop");
735     auto merge_x = ops::Merge(scope.WithOpName("while/Merge/x"),
736                               std::initializer_list<Input>{enter_x, dummy});
737     auto merge_y = ops::Merge(scope.WithOpName("while/Merge/y"),
738                               std::initializer_list<Input>{enter_y, dummy});
739 
740     // Loop condition
741     auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
742                                        .WithControlDependencies(merge_x.output),
743                                    3);
744     auto cond_add =
745         ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three);
746     auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
747                                      .WithControlDependencies(merge_x.output),
748                                  10);
749     auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
750     auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
751 
752     auto switch_x = ops::Switch(scope.WithOpName("while/Switch/x"),
753                                 merge_x.output, loop_cond);
754     auto switch_y = ops::Switch(scope.WithOpName("while/Switch/y"),
755                                 merge_y.output, loop_cond);
756 
757     auto exit_x = ops::internal::Exit(scope.WithOpName("while/Exit/x"),
758                                       switch_x.output_false);
759     auto exit_y = ops::internal::Exit(scope.WithOpName("while/Exit/y"),
760                                       switch_y.output_false);
761 
762     auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"),
763                                     switch_x.output_true);
764     auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"),
765                                     switch_y.output_true);
766 
767     auto one = ops::Const<int32>(
768         scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
769         1);
770     auto two = ops::Const<int32>(
771         scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
772         2);
773 
774     auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
775     auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
776     auto next_iteration_x =
777         ops::NextIteration(scope.WithOpName("while/NextIteration/x"), add);
778     auto next_iteration_y =
779         ops::NextIteration(scope.WithOpName("while/NextIteration/y"), mul);
780 
781     auto sink_x = ops::Identity(scope.WithOpName("sink_x"), exit_x);
782     auto sink_y = ops::Identity(scope.WithOpName("sink_y"), exit_y);
783 
784     // Remove the dummy node and add the loop backedges.
785     scope.graph()->RemoveNode(dummy.node());
786     scope.graph()->AddEdge(next_iteration_x.node(), 0, merge_x.output.node(),
787                            1);
788     scope.graph()->AddEdge(next_iteration_y.node(), 0, merge_y.output.node(),
789                            1);
790 
791     TF_EXPECT_OK(scope.ToGraph(&graph));
792   }
793 
794   FunctionLibraryDefinition library(OpRegistry::Global(), {});
795   GraphDef optimized_graph_def;
796   graph.ToGraphDef(&optimized_graph_def);
797   TF_ASSERT_OK(
798       FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
799   TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
800   GraphDef converted_graph_def;
801   graph.ToGraphDef(&converted_graph_def);
802 
803   for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
804     NameAttrList cond_fn, body_fn;
805     TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
806 
807     // Outer graph.
808     {
809       Scope scope = Scope::NewRootScope().ExitOnError();
810       auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
811       auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
812       auto while_op =
813           ops::While(scope.WithOpName("while/LoopCond"),
814                      std::initializer_list<Input>{x, y}, cond_fn, body_fn);
815       auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]);
816       auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]);
817       GraphDef expected;
818       TF_EXPECT_OK(scope.ToGraphDef(&expected));
819       TF_EXPECT_GRAPH_EQ(expected, graph_def);
820     }
821 
822     // Condition graph.
823     {
824       Scope scope = Scope::NewRootScope().ExitOnError();
825       auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
826       auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
827       auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
828                                          .WithControlDependencies(arg0.output),
829                                      3);
830       auto cond_add =
831           ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three);
832       auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
833                                        .WithControlDependencies(arg0.output),
834                                    10);
835       auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
836       auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
837 
838       GraphDef expected;
839       TF_EXPECT_OK(scope.ToGraphDef(&expected));
840 
841       InstantiationResultForTest result;
842       TF_EXPECT_OK(
843           InstantiateFunctionForTest(cond_fn.name(), library, &result));
844 
845       EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
846       EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
847       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
848     }
849 
850     // Body graph.
851     {
852       Scope scope = Scope::NewRootScope().ExitOnError();
853       auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
854       auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
855 
856       auto identity_x =
857           ops::Identity(scope.WithOpName("while/Identity/x"), arg0);
858       auto identity_y =
859           ops::Identity(scope.WithOpName("while/Identity/y"), arg1);
860 
861       auto one = ops::Const<int32>(
862           scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
863           1);
864       auto two = ops::Const<int32>(
865           scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
866           2);
867 
868       auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
869       auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
870       auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
871       auto retval1 = ops::_Retval(scope.WithOpName("retval1_RetVal"), mul, 1);
872 
873       GraphDef expected;
874       TF_EXPECT_OK(scope.ToGraphDef(&expected));
875 
876       InstantiationResultForTest result;
877       TF_EXPECT_OK(
878           InstantiateFunctionForTest(body_fn.name(), library, &result));
879 
880       EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
881       EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types);
882       TF_EXPECT_GRAPH_EQ(expected, result.gdef);
883     }
884   }
885 }
886 
887 // More complex example with nesting, loop-invariant arguments, and resource
888 // variables. Used for multiple tests with different node filters.
889 class ComplexTestFixture
890     : public ::testing::TestWithParam<std::tuple<bool, bool, bool>> {
891  protected:
SetUp()892   void SetUp() override {
893     restrict_to_tpu_nodes_ = std::get<0>(GetParam());
894     mark_inner_loop_tpu_ = std::get<1>(GetParam());
895     mark_outer_loop_tpu_ = std::get<2>(GetParam());
896   }
897   void RunTest();
898 
899  private:
900   void CheckOuterNodesFunctionalized(const GraphDef& graph_def,
901                                      const FunctionLibraryDefinition& library,
902                                      NameAttrList& inner_cond_fn,
903                                      NameAttrList& inner_body_fn);
904   void CheckInnerNodesFunctionalized(const GraphDef& graph_def,
905                                      const FunctionLibraryDefinition& library,
906                                      const NameAttrList& inner_cond_fn,
907                                      const NameAttrList& inner_body_fn);
908 
909   bool restrict_to_tpu_nodes_ = false;
910   bool mark_inner_loop_tpu_ = false;
911   bool mark_outer_loop_tpu_ = false;
912 };
913 
TEST_P(ComplexTestFixture, ComplexTests) { RunTest(); }

// Instantiates all 8 flag combinations. Each instance is named
// "<restricted|unrestricted>_<which loops are marked TPU>".
INSTANTIATE_TEST_SUITE_P(
    FunctionalizeControlFlow, ComplexTestFixture,
    ::testing::Combine(::testing::Bool(), ::testing::Bool(), ::testing::Bool()),
    [](const ::testing::TestParamInfo<ComplexTestFixture::ParamType>& info) {
      const bool restricted = std::get<0>(info.param);
      const bool inner_tpu = std::get<1>(info.param);
      const bool outer_tpu = std::get<2>(info.param);

      // Equal flags mean "both" or "neither"; otherwise name the marked loop.
      const char* loops = (inner_tpu == outer_tpu)
                              ? (inner_tpu ? "both_loops_tpu" : "no_loop_tpu")
                              : (inner_tpu ? "inner_loop_tpu" : "outer_loop_tpu");

      return absl::StrCat(restricted ? "restricted_" : "unrestricted_", loops);
    });
936 
RunTest()937 void ComplexTestFixture::RunTest() {
938   // Graph:
939   //
940   // accum = resource_variable_ops.ResourceVariable(1)
941   // x = array_ops.placeholder(2, dtype=dtypes.int32)
942   // y = 3 + x
943   //
944   // def inner_body(j, k):
945   //   add = state_ops.assign_add(accum, k * j + x)
946   //   with ops.control_dependencies([add]):
947   //     return [j + 1, k]
948   //
949   // def body(i):
950   //   m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body,
951   //                                   [1, y], name="inner")
952   //   with ops.control_dependencies(m):
953   //     return [i + 1]
954   //
955   // z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer")
956   Graph graph(OpRegistry::Global());
957   {
958     Scope scope = Scope::NewRootScope().ExitOnError();
959 
960     auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
961 
962     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
963     auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
964     auto y = ops::Add(scope.WithOpName("y"), x, three);
965 
966     auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
967                                 TensorShape({}));
968 
969     // Outer loop
970     auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
971     auto enter_i =
972         ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer");
973     auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"),
974                               std::initializer_list<Input>{enter_i, dummy});
975     auto ten = ops::Const<int32>(scope.WithOpName("outer/Less/y")
976                                      .WithControlDependencies(merge_i.output),
977                                  10);
978     auto less_i =
979         ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten);
980     auto outer_loop_cond =
981         ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_i);
982     auto switch_i = ops::Switch(scope.WithOpName("outer/Switch"),
983                                 merge_i.output, outer_loop_cond);
984     auto exit_i = ops::internal::Exit(scope.WithOpName("outer/Exit"),
985                                       switch_i.output_false);
986     auto identity_i =
987         ops::Identity(scope.WithOpName("outer/Identity"), switch_i.output_true);
988 
989     auto enter_x_outer =
990         ops::internal::Enter(scope.WithOpName("outer/Enter_x"), x, "outer",
991                              ops::internal::Enter::Attrs().IsConstant(true));
992     auto enter_k_outer =
993         ops::internal::Enter(scope.WithOpName("outer/Enter_k"), y, "outer",
994                              ops::internal::Enter::Attrs().IsConstant(true));
995     auto enter_var_outer =
996         ops::internal::Enter(scope.WithOpName("outer/Enter_var"), var, "outer",
997                              ops::internal::Enter::Attrs().IsConstant(true));
998 
999     // Inner loop
1000     auto one_j = ops::Const<int32>(
1001         scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
1002     auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"),
1003                                         one_j, "inner");
1004     auto enter_k =
1005         ops::internal::Enter(scope.WithOpName("outer/inner/Enter_k")
1006                                  .WithControlDependencies(identity_i),
1007                              enter_k_outer, "inner");
1008     auto enter_x = ops::internal::Enter(
1009         scope.WithOpName("outer/inner/Enter_x"), enter_x_outer, "inner",
1010         ops::internal::Enter::Attrs().IsConstant(true));
1011     auto enter_var = ops::internal::Enter(
1012         scope.WithOpName("outer/inner/Enter_var"), enter_var_outer, "inner",
1013         ops::internal::Enter::Attrs().IsConstant(true));
1014 
1015     auto merge_j = ops::Merge(scope.WithOpName("outer/inner/Merge_j"),
1016                               std::initializer_list<Input>{enter_j, dummy});
1017     auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"),
1018                               std::initializer_list<Input>{enter_k, dummy});
1019 
1020     auto five = ops::Const<int32>(scope.WithOpName("outer/inner/Five")
1021                                       .WithControlDependencies(merge_j.output),
1022                                   5);
1023     auto less_j =
1024         ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five);
1025     auto loop_cond =
1026         ops::LoopCond(scope.WithOpName("outer/inner/LoopCond"), less_j);
1027 
1028     auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"),
1029                                 merge_j.output, loop_cond);
1030     auto switch_k = ops::Switch(scope.WithOpName("outer/inner/Switch_k"),
1031                                 merge_k.output, loop_cond);
1032     auto exit_j = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_j"),
1033                                       switch_j.output_false);
1034     auto exit_k = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_k"),
1035                                       switch_k.output_false);
1036     auto identity_j = ops::Identity(scope.WithOpName("outer/inner/Identity_j"),
1037                                     switch_j.output_true);
1038     auto identity_k = ops::Identity(scope.WithOpName("outer/inner/Identity_k"),
1039                                     switch_k.output_true);
1040 
1041     // Variable update
1042     auto mul_jk =
1043         ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
1044     auto add_jkx =
1045         ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, enter_x);
1046     auto assign = ops::AssignAddVariableOp(
1047         scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx);
1048 
1049     auto one = ops::Const<int32>(
1050         scope.WithOpName("outer/inner/One")
1051             .WithControlDependencies(
1052                 absl::Span<const Operation>{assign.operation}),
1053         1);
1054     auto add_j =
1055         ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
1056 
1057     auto next_iteration_j = ops::NextIteration(
1058         scope.WithOpName("outer/inner/NextIteration_j"), add_j);
1059     auto next_iteration_k = ops::NextIteration(
1060         scope.WithOpName("outer/inner/NextIteration_k"), identity_k);
1061 
1062     // Body and backedge for outer loop.
1063     auto one_outer = ops::Const<int32>(
1064         scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
1065     auto add_i =
1066         ops::Add(scope.WithOpName("outer/add")
1067                      .WithControlDependencies(absl::Span<const Operation>{
1068                          exit_j.output.op(), exit_k.output.op()}),
1069                  identity_i, one_outer);
1070     auto next_iteration_i =
1071         ops::NextIteration(scope.WithOpName("outer/NextIteration"), add_i);
1072 
1073     auto sink = ops::Identity(scope.WithOpName("sink"), exit_i);
1074 
1075     // Remove the dummy node and add the loop backedge.
1076     scope.graph()->RemoveNode(dummy.node());
1077     scope.graph()->AddEdge(next_iteration_i.node(), 0, merge_i.output.node(),
1078                            1);
1079     scope.graph()->AddEdge(next_iteration_j.node(), 0, merge_j.output.node(),
1080                            1);
1081     scope.graph()->AddEdge(next_iteration_k.node(), 0, merge_k.output.node(),
1082                            1);
1083 
1084     TF_EXPECT_OK(scope.ToGraph(&graph));
1085   }
1086   // Add '_tpu_replicate' attributes as specified.
1087   for (Node* n : graph.nodes()) {
1088     string name = n->name();
1089     bool is_inner_node = name.find("outer/inner/") != string::npos;
1090     bool is_outer_node = !is_inner_node && name.find("outer/") != string::npos;
1091     if ((is_inner_node && mark_inner_loop_tpu_) ||
1092         (is_outer_node && mark_outer_loop_tpu_)) {
1093       n->AddAttr("_tpu_replicate", "cluster");
1094     }
1095   }
1096 
1097   FunctionLibraryDefinition library(OpRegistry::Global(), {});
1098   GraphDef orig_graph_def, optimized_graph_def;
1099   graph.ToGraphDef(&orig_graph_def);
1100   optimized_graph_def = orig_graph_def;
1101   // If `restrict_to_tpu_nodes_` is true let filter function return true for
1102   // `_tpu_replicate` nodes, otherwise don't set filter.
1103   NodeFilter node_filter =
1104       restrict_to_tpu_nodes_
1105           ? [](const Node* n) { return n->attrs().Find("_tpu_replicate"); }
1106           : NodeFilter{};
1107 
1108   Status status1 = FunctionalizeControlFlowForGraphDef(&optimized_graph_def,
1109                                                        &library, node_filter);
1110   Status status2 = FunctionalizeControlFlow(&graph, &library, node_filter);
1111   ASSERT_EQ(status1, status2);
1112   if (restrict_to_tpu_nodes_ && mark_outer_loop_tpu_ && !mark_inner_loop_tpu_) {
1113     // This case violates the precondition of `FunctionalizeControlFlow`, we
1114     // expect an internal error.
1115     ASSERT_EQ(errors::IsInternal(status1), true);
1116     return;
1117   } else {
1118     // Supported cases, no error expected.
1119     TF_ASSERT_OK(status1);
1120   }
1121 
1122   GraphDef optimized_converted_graph_def;
1123   graph.ToGraphDef(&optimized_converted_graph_def);
1124   for (const GraphDef& graph_def :
1125        {optimized_graph_def, optimized_converted_graph_def}) {
1126     NameAttrList inner_cond_fn, inner_body_fn;
1127     if (!restrict_to_tpu_nodes_ ||
1128         (restrict_to_tpu_nodes_ && mark_outer_loop_tpu_ &&
1129          mark_inner_loop_tpu_)) {
1130       // We expect that both inner and outer nodes have been functionalized.
1131       CheckOuterNodesFunctionalized(graph_def, library, inner_cond_fn,
1132                                     inner_body_fn);
1133       CheckInnerNodesFunctionalized(graph_def, library, inner_cond_fn,
1134                                     inner_body_fn);
1135     } else /*restrict_to_tpu_nodes_ == true*/ {
1136       if (!mark_outer_loop_tpu_ && !mark_inner_loop_tpu_) {
1137         // Graph has no TPU nodes so we expect no functionalization.
1138         TF_EXPECT_GRAPH_EQ(orig_graph_def, graph_def);
1139       } else if (!mark_outer_loop_tpu_ && mark_inner_loop_tpu_) {
1140         // We expect that only inner nodes have been functionalized.
1141         TF_EXPECT_OK(
1142             FindWhileCondAndBody(graph_def, &inner_cond_fn, &inner_body_fn));
1143         CheckInnerNodesFunctionalized(graph_def, library, inner_cond_fn,
1144                                       inner_body_fn);
1145       }
1146     }
1147   }
1148 }
1149 
CheckOuterNodesFunctionalized(const GraphDef & graph_def,const FunctionLibraryDefinition & library,NameAttrList & inner_cond_fn,NameAttrList & inner_body_fn)1150 void ComplexTestFixture::CheckOuterNodesFunctionalized(
1151     const GraphDef& graph_def, const FunctionLibraryDefinition& library,
1152     NameAttrList& inner_cond_fn, NameAttrList& inner_body_fn) {
1153   NameAttrList outer_cond_fn, outer_body_fn;
1154   TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn));
1155 
1156   // Outer graph.
1157   {
1158     Scope scope = Scope::NewRootScope().ExitOnError();
1159     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
1160     auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
1161     auto y = ops::Add(scope.WithOpName("y"), x, three);
1162 
1163     auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
1164                                 TensorShape({}));
1165 
1166     auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
1167 
1168     auto while_op = ops::While(scope.WithOpName("outer/LoopCond"),
1169                                std::initializer_list<Input>{zero, y, x, var},
1170                                outer_cond_fn, outer_body_fn);
1171     auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
1172     GraphDef expected;
1173     TF_EXPECT_OK(scope.ToGraphDef(&expected));
1174     TF_EXPECT_GRAPH_EQ(expected, graph_def);
1175   }
1176 
1177   // Outer condition graph.
1178   {
1179     Scope scope = Scope::NewRootScope().ExitOnError();
1180     auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1181     auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
1182     auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
1183     auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
1184 
1185     auto ten = ops::Const<int32>(
1186         scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output),
1187         10);
1188     auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten);
1189     auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
1190 
1191     GraphDef expected;
1192     TF_EXPECT_OK(scope.ToGraphDef(&expected));
1193 
1194     InstantiationResultForTest result;
1195     TF_EXPECT_OK(
1196         InstantiateFunctionForTest(outer_cond_fn.name(), library, &result));
1197 
1198     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1199               result.arg_types);
1200     EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
1201     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
1202   }
1203 
1204   // Outer body graph.
1205   {
1206     InstantiationResultForTest result;
1207     TF_EXPECT_OK(
1208         InstantiateFunctionForTest(outer_body_fn.name(), library, &result));
1209 
1210     // Find the inner condition and body names.
1211     TF_EXPECT_OK(
1212         FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn));
1213 
1214     Scope scope = Scope::NewRootScope().ExitOnError();
1215     auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1216     auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
1217     auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
1218     auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
1219 
1220     auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0);
1221     auto one_j = ops::Const<int32>(
1222         scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
1223     auto while_op =
1224         ops::While(scope.WithOpName("outer/inner/LoopCond"),
1225                    std::initializer_list<Input>{one_j, arg1, arg2, arg3},
1226                    inner_cond_fn, inner_body_fn);
1227 
1228     auto one_outer = ops::Const<int32>(
1229         scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
1230     auto add_i =
1231         ops::Add(scope.WithOpName("outer/add")
1232                      .WithControlDependencies(absl::Span<const Operation>{
1233                          while_op[0].op(), while_op[1].op()}),
1234                  identity_i, one_outer);
1235 
1236     auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add_i, 0);
1237     auto retval1 = ops::_Retval(scope.WithOpName("retval1_RetVal"), arg1, 1);
1238     auto retval2 = ops::_Retval(scope.WithOpName("retval2_RetVal"), arg2, 2);
1239     auto retval3 = ops::_Retval(scope.WithOpName("retval3_RetVal"), arg3, 3);
1240 
1241     GraphDef expected;
1242     TF_EXPECT_OK(scope.ToGraphDef(&expected));
1243 
1244     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1245               result.arg_types);
1246     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1247               result.ret_types);
1248     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
1249   }
1250 }
1251 
CheckInnerNodesFunctionalized(const GraphDef & graph_def,const FunctionLibraryDefinition & library,const NameAttrList & inner_cond_fn,const NameAttrList & inner_body_fn)1252 void ComplexTestFixture::CheckInnerNodesFunctionalized(
1253     const GraphDef& graph_def, const FunctionLibraryDefinition& library,
1254     const NameAttrList& inner_cond_fn, const NameAttrList& inner_body_fn) {
1255   // Inner condition graph.
1256   {
1257     Scope scope = Scope::NewRootScope().ExitOnError();
1258     auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1259     auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
1260     auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
1261     auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
1262 
1263     auto five = ops::Const<int32>(
1264         scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5);
1265     auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five);
1266     auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less_j, 0);
1267 
1268     GraphDef expected;
1269     TF_EXPECT_OK(scope.ToGraphDef(&expected));
1270 
1271     InstantiationResultForTest result;
1272     TF_EXPECT_OK(
1273         InstantiateFunctionForTest(inner_cond_fn.name(), library, &result));
1274 
1275     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1276               result.arg_types);
1277     EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
1278     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
1279   }
1280 
1281   // Inner body graph.
1282   {
1283     Scope scope = Scope::NewRootScope().ExitOnError();
1284     auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1285     auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
1286     auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
1287     auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
1288 
1289     auto identity_j =
1290         ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0);
1291     auto identity_k =
1292         ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1);
1293 
1294     auto mul_jk =
1295         ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
1296     auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2);
1297     auto assign = ops::AssignAddVariableOp(
1298         scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);
1299 
1300     auto one = ops::Const<int32>(
1301         scope.WithOpName("outer/inner/One")
1302             .WithControlDependencies(
1303                 absl::Span<const Operation>{assign.operation}),
1304         1);
1305     auto add_j =
1306         ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
1307 
1308     auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add_j, 0);
1309     auto retval1 =
1310         ops::_Retval(scope.WithOpName("retval1_RetVal"), identity_k, 1);
1311     auto retval2 = ops::_Retval(scope.WithOpName("retval2_RetVal"), arg2, 2);
1312     auto retval3 = ops::_Retval(scope.WithOpName("retval3_RetVal"), arg3, 3);
1313 
1314     GraphDef expected;
1315     TF_EXPECT_OK(scope.ToGraphDef(&expected));
1316 
1317     InstantiationResultForTest result;
1318     TF_EXPECT_OK(
1319         InstantiateFunctionForTest(inner_body_fn.name(), library, &result));
1320 
1321     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1322               result.arg_types);
1323     EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1324               result.ret_types);
1325     TF_EXPECT_GRAPH_EQ(expected, result.gdef);
1326   }
1327 }
1328 
1329 }  // namespace
1330 }  // namespace tensorflow
1331