/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/strings/match.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

SessionOptionsWithInlining()38 SessionOptions SessionOptionsWithInlining() {
39   SessionOptions session_options;
40   session_options.config.mutable_graph_options()
41       ->mutable_optimizer_options()
42       ->set_do_function_inlining(true);
43   return session_options;
44 }
45 
Rewrite(std::unique_ptr<Graph> * graph)46 Status Rewrite(std::unique_ptr<Graph>* graph) {
47   FunctionLibraryDefinition flib_def((*graph)->flib_def());
48   GraphOptimizationPassOptions opt_options;
49   SessionOptions session_options = SessionOptionsWithInlining();
50   opt_options.session_options = &session_options;
51   opt_options.graph = graph;
52   opt_options.flib_def = &flib_def;
53   LowerFunctionalOpsPass pass;
54   return pass.Run(opt_options);
55 }
56 
TEST(LowerWhileOpTest,Simple)57 TEST(LowerWhileOpTest, Simple) {
58   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
59 
60   // Add test functions for cond and body.
61   FunctionDefLibrary f_lib_proto;
62   *f_lib_proto.add_function() = test::function::XTimesTwo();
63   *f_lib_proto.add_function() = test::function::LessThanOrEqualToN(8);
64 
65   Scope root = Scope::NewRootScope().ExitOnError();
66   TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
67   auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
68   Node* while_node;
69   std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
70   AttrValue cond_func;
71   cond_func.mutable_func()->set_name("LessThanOrEqualToN");
72   AttrValue body_func;
73   body_func.mutable_func()->set_name("XTimesTwo");
74   TF_ASSERT_OK(
75       NodeBuilder("while", "While", &root.graph()->flib_def())
76           .Input(inputs)
77           .Attr("T", {DT_INT32})
78           .Attr("cond", cond_func)
79           .Attr("body", body_func)
80           .Attr("parallel_iterations", 100)
81           .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
82           .Finalize(root.graph(), &while_node));
83   auto c = ops::Identity(
84       root.WithOpName("C").WithControlDependencies(Output(while_node)),
85       Output(while_node));
86   TF_ASSERT_OK(root.DoShapeInference(while_node));
87   TF_ASSERT_OK(root.ToGraph(graph.get()));
88 
89   // The input graph has no lower level control flow primitives.
90   int node_called_while_count = 0;
91   for (const auto* op : graph->op_nodes()) {
92     ASSERT_FALSE(op->IsEnter());
93     ASSERT_FALSE(op->IsExit());
94     ASSERT_FALSE(op->IsSwitch());
95     ASSERT_FALSE(op->IsMerge());
96     ASSERT_FALSE(op->IsNextIteration());
97     ASSERT_FALSE(op->IsLoopCond());
98     if (op->name() == "while") {
99       node_called_while_count++;
100     }
101   }
102   ASSERT_EQ(node_called_while_count, 1);
103 
104   TF_ASSERT_OK(Rewrite(&graph));
105 
106   int enter_count = 0;
107   int exit_count = 0;
108   int switch_count = 0;
109   int merge_count = 0;
110   int next_iteration_count = 0;
111   node_called_while_count = 0;
112   int less_than_or_equan_to_n_count = 0;
113   int x_times_two_count = 0;
114 
115   for (const auto* op : graph->op_nodes()) {
116     if (op->IsEnter()) {
117       ++enter_count;
118       ASSERT_EQ(op->attrs().Find("parallel_iterations")->i(), 100);
119     }
120     if (op->IsExit()) {
121       ++exit_count;
122     }
123     if (op->IsSwitch()) {
124       ++switch_count;
125     }
126     if (op->IsMerge()) {
127       ++merge_count;
128     }
129     if (op->IsNextIteration()) {
130       ++next_iteration_count;
131     }
132     if (op->name() == "while") {
133       node_called_while_count++;
134     }
135     if (op->type_string() == "LessThanOrEqualToN") {
136       less_than_or_equan_to_n_count++;
137     }
138     if (op->type_string() == "XTimesTwo") {
139       x_times_two_count++;
140     }
141     if (op->name() == "C") {
142       ASSERT_EQ(op->in_edges().size(), 2);
143     }
144     ASSERT_NE(op->type_string(), "While");
145   }
146   // One node per loop input.
147   ASSERT_EQ(enter_count, 1);
148   ASSERT_EQ(exit_count, 1);
149   ASSERT_EQ(switch_count, 1);
150   ASSERT_EQ(merge_count, 1);
151   ASSERT_EQ(next_iteration_count, 1);
152   ASSERT_EQ(node_called_while_count, 1);
153 
154   // Verify execution.
155   ClientSession session(root, SessionOptionsWithInlining());
156   {
157     ClientSession::FeedType feeds;
158     feeds.emplace(Output(a.node()), Input::Initializer(1));
159     std::vector<Tensor> out_tensors;
160     TF_ASSERT_OK(session.Run(feeds, {Output(while_node)}, &out_tensors));
161     ASSERT_EQ(out_tensors.size(), 1);
162     EXPECT_EQ(out_tensors[0].scalar<int>()(), 16);
163   }
164   {
165     ClientSession::FeedType feeds;
166     feeds.emplace(Output(a.node()), Input::Initializer(3));
167     std::vector<Tensor> out_tensors;
168     TF_ASSERT_OK(session.Run(feeds, {Output(while_node)}, &out_tensors));
169     ASSERT_EQ(out_tensors.size(), 1);
170     EXPECT_EQ(out_tensors[0].scalar<int>()(), 12);
171   }
172 }
173 
// Verifies that the *assigned* device of the While op's input placeholder is
// forwarded to each of the lowered control flow nodes
// (Enter -> Merge -> NextIteration/Switch -> Exit), following the data edges
// from the placeholder through the lowered loop.
TEST(LowerWhileOpTest, ForwardAssignedInputDevice) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  // Add test functions for cond and body.
  FunctionDefLibrary f_lib_proto;
  *f_lib_proto.add_function() = test::function::XTimesTwo();
  *f_lib_proto.add_function() = test::function::LessThanOrEqualToN(8);

  TF_ASSERT_OK(graph->AddFunctionLibrary(f_lib_proto));
  auto type = DT_FLOAT;
  Node* placeholder;
  TF_CHECK_OK(NodeBuilder("placed_node", "Placeholder")
                  .Attr("dtype", type)
                  .Finalize(graph.get(), &placeholder));
  // Assign (not merely request) a device for the loop-variable input; the
  // assertions below check this exact string on every lowered node.
  const string assigned_device_name = "/job:localhost/replica:0/task:0/gpu:0";
  placeholder->set_assigned_device_name(assigned_device_name);
  Node* while_node;
  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(placeholder)});
  AttrValue cond_func;
  cond_func.mutable_func()->set_name("LessThanOrEqualToN");
  AttrValue body_func;
  body_func.mutable_func()->set_name("XTimesTwo");
  TF_ASSERT_OK(
      NodeBuilder("while", "While", &graph->flib_def())
          .Input(inputs)
          .Attr("T", {type})
          .Attr("cond", cond_func)
          .Attr("body", body_func)
          .Attr("parallel_iterations", 100)
          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
          .Finalize(graph.get(), &while_node));
  TF_ASSERT_OK(Rewrite(&graph));

  // Locate the placeholder in the rewritten graph.
  const Node* placeholder_node = nullptr;
  for (const auto* op : graph->op_nodes()) {
    if (op->name() == "placed_node") {
      placeholder_node = op;
    }
  }
  ASSERT_NE(placeholder_node, nullptr);
  // Verify the assigned device of the Enter node.
  int enter_consumers = 0;
  const Node* enter_node = nullptr;
  for (const Node* consumer : placeholder_node->out_nodes()) {
    if (consumer->type_string() == "Enter") {
      enter_consumers += 1;
      enter_node = consumer;
      ASSERT_EQ(consumer->assigned_device_name(), assigned_device_name);
    }
  }
  // Exactly one Enter consumer also guarantees enter_node is non-null below.
  ASSERT_EQ(enter_consumers, 1);
  // Verify the assigned device of the Merge node.
  int merge_consumers = 0;
  const Node* merge_node = nullptr;
  for (const Node* consumer : enter_node->out_nodes()) {
    if (consumer->type_string() == "Merge") {
      merge_consumers += 1;
      merge_node = consumer;
      ASSERT_EQ(consumer->assigned_device_name(), assigned_device_name);
    }
  }
  ASSERT_EQ(merge_consumers, 1);
  // Verify the assigned device of the NextIteration node (an *input* of the
  // Merge node, closing the loop's back edge).
  int next_iteration_consumers = 0;
  for (const Node* consumer : merge_node->in_nodes()) {
    if (consumer->type_string() == "NextIteration") {
      next_iteration_consumers += 1;
      ASSERT_EQ(consumer->assigned_device_name(), assigned_device_name);
    }
  }
  ASSERT_EQ(next_iteration_consumers, 1);
  // Verify the assigned device of the Switch node.
  int switch_consumers = 0;
  const Node* switch_node = nullptr;
  for (const Node* consumer : merge_node->out_nodes()) {
    if (consumer->type_string() == "Switch") {
      switch_consumers += 1;
      switch_node = consumer;
      ASSERT_EQ(consumer->assigned_device_name(), assigned_device_name);
    }
  }
  ASSERT_EQ(switch_consumers, 1);
  // Verify the assigned device of the Exit node.
  int exit_consumers = 0;
  for (const Node* consumer : switch_node->out_nodes()) {
    if (consumer->type_string() == "Exit") {
      exit_consumers += 1;
      ASSERT_EQ(consumer->assigned_device_name(), assigned_device_name);
    }
  }
  ASSERT_EQ(exit_consumers, 1);
}
266 
// Verifies that the *requested* device of the loop-variable input is
// forwarded to the lowered Enter/Merge/NextIteration/Switch/Exit nodes,
// while the control-related nodes created by the lowering
// (LoopControlInputs / LoopExecuted) take the While op's own requested
// device instead.
TEST(LowerWhileOpTest, ForwardRequestedInputDevice) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  // Add test functions for cond and body.
  FunctionDefLibrary f_lib_proto;
  *f_lib_proto.add_function() = test::function::XTimesTwo();
  *f_lib_proto.add_function() = test::function::LessThanOrEqualToN(8);

  TF_ASSERT_OK(graph->AddFunctionLibrary(f_lib_proto));
  auto type = DT_FLOAT;
  // We will place the loop var on the gpu:0.
  const string gpu_0_device = "/job:localhost/replica:0/task:0/gpu:0";
  // We will place loop's control input on the gpu:1.
  const string gpu_1_device = "/job:localhost/replica:0/task:0/gpu:1";
  // We will place While op on gpu:2.
  const string gpu_2_device = "/job:localhost/replica:0/task:0/gpu:2";
  Node* gpu_0_ph;
  TF_CHECK_OK(NodeBuilder("placed_node", "Placeholder")
                  .Attr("dtype", type)
                  .Device(gpu_0_device)
                  .Finalize(graph.get(), &gpu_0_ph));
  Node* control_in;
  // Add a control input to the While op to trigger the creation of a
  // LoopExecuted node.
  TF_CHECK_OK(NodeBuilder("control_in", "Placeholder")
                  .Attr("dtype", type)
                  .Device(gpu_1_device)
                  .Finalize(graph.get(), &control_in));
  Node* while_node;
  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(gpu_0_ph)});
  AttrValue cond_func;
  cond_func.mutable_func()->set_name("LessThanOrEqualToN");
  AttrValue body_func;
  body_func.mutable_func()->set_name("XTimesTwo");
  TF_ASSERT_OK(
      NodeBuilder("while", "While", &graph->flib_def())
          .Input(inputs)
          .ControlInput(control_in)
          .Device(gpu_2_device)
          .Attr("T", {type})
          .Attr("cond", cond_func)
          .Attr("body", body_func)
          .Attr("parallel_iterations", 100)
          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
          .Finalize(graph.get(), &while_node));

  // Create an empty Const node with control dep from the While op.
  // This triggers the creation of a LoopExecuted node.
  Node* control_out;
  TensorProto proto;
  proto.set_dtype(DT_FLOAT);
  TensorShape empty_shape({0});
  empty_shape.AsProto(proto.mutable_tensor_shape());
  TF_ASSERT_OK(NodeBuilder("control_out", "Const")
                   .ControlInput(while_node)
                   .Attr("dtype", DT_FLOAT)
                   .Attr("value", proto)
                   .Finalize(graph.get(), &control_out));

  TF_ASSERT_OK(Rewrite(&graph));

  // Locate the loop-variable placeholder in the rewritten graph.
  const Node* placeholder_node = nullptr;
  for (const auto* op : graph->op_nodes()) {
    if (op->name() == "placed_node") {
      placeholder_node = op;
    }
  }
  ASSERT_NE(placeholder_node, nullptr);
  // Verify the requested device of the Enter node.
  int enter_consumers = 0;
  const Node* enter_node = nullptr;
  for (const Node* consumer : placeholder_node->out_nodes()) {
    if (consumer->type_string() == "Enter") {
      enter_consumers += 1;
      enter_node = consumer;
      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
    }
  }
  // Exactly one Enter consumer also guarantees enter_node is non-null below.
  ASSERT_EQ(enter_consumers, 1);
  // Verify the requested device of the Merge node.
  int merge_consumers = 0;
  const Node* merge_node = nullptr;
  for (const Node* consumer : enter_node->out_nodes()) {
    if (consumer->type_string() == "Merge") {
      merge_consumers += 1;
      merge_node = consumer;
      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(merge_consumers, 1);
  // Verify the requested device of the NextIteration node (input of the
  // Merge node, closing the loop's back edge).
  int next_iteration_consumers = 0;
  for (const Node* consumer : merge_node->in_nodes()) {
    if (consumer->type_string() == "NextIteration") {
      next_iteration_consumers += 1;
      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(next_iteration_consumers, 1);
  // Verify the requested device of the Switch node.
  int switch_consumers = 0;
  const Node* switch_node = nullptr;
  for (const Node* consumer : merge_node->out_nodes()) {
    if (consumer->type_string() == "Switch") {
      switch_consumers += 1;
      switch_node = consumer;
      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(switch_consumers, 1);
  // Verify the requested device of the Exit node.
  int exit_consumers = 0;
  for (const Node* consumer : switch_node->out_nodes()) {
    if (consumer->type_string() == "Exit") {
      exit_consumers += 1;
      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(exit_consumers, 1);
  // Verify the requested device of LoopControlInputs: it should follow the
  // While op's device (gpu:2), not the control input's (gpu:1).
  const Node* loop_control_inputs_node = nullptr;
  for (const auto* op : graph->op_nodes()) {
    if (absl::StrContains(op->name(), "LoopControlInputs")) {
      loop_control_inputs_node = op;
    }
  }
  ASSERT_NE(loop_control_inputs_node, nullptr);
  ASSERT_EQ(loop_control_inputs_node->requested_device(), gpu_2_device);
  // Verify the requested device of LoopExecuted: also the While op's device.
  const Node* loop_executed_node = nullptr;
  for (const auto* op : graph->op_nodes()) {
    if (absl::StrContains(op->name(), "LoopExecuted")) {
      loop_executed_node = op;
    }
  }
  ASSERT_NE(loop_executed_node, nullptr);
  ASSERT_EQ(loop_executed_node->requested_device(), gpu_2_device);
}
405 
// Lowers a While op with two loop variables: each variable must get its own
// Enter/Merge/Switch/NextIteration/Exit set (counts of 2 below), the
// cond/body function call nodes must be gone after the rewrite, and both
// loop outputs must still compute the expected values.
TEST(LowerWhileOpTest, MultipleInputs) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  // Add test functions for cond and body.
  FunctionDefLibrary f_lib_proto;
  *(f_lib_proto.add_function()) = test::function::XPlusOneXTimesY();
  *(f_lib_proto.add_function()) = test::function::XYXLessThanOrEqualToN(4);

  Scope root = Scope::NewRootScope().ExitOnError();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
  auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
  auto b = ops::Placeholder(root.WithOpName("B"), DT_INT32);
  Node* while_node;
  std::vector<NodeBuilder::NodeOut> inputs(
      {NodeBuilder::NodeOut(a.node()), NodeBuilder::NodeOut(b.node())});
  AttrValue cond_func;
  cond_func.mutable_func()->set_name("XYXLessThanOrEqualToN");
  AttrValue body_func;
  body_func.mutable_func()->set_name("XPlusOneXTimesY");
  TF_ASSERT_OK(
      NodeBuilder("while", "While", &root.graph()->flib_def())
          .Input(inputs)
          .Attr("T", {DT_INT32, DT_INT32})
          .Attr("cond", cond_func)
          .Attr("body", body_func)
          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
          .Finalize(root.graph(), &while_node));
  TF_ASSERT_OK(root.DoShapeInference(while_node));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // The input graph has no lower level control flow primitives.
  for (const auto* op : graph->op_nodes()) {
    ASSERT_FALSE(op->IsEnter());
    ASSERT_FALSE(op->IsExit());
    ASSERT_FALSE(op->IsSwitch());
    ASSERT_FALSE(op->IsMerge());
    ASSERT_FALSE(op->IsNextIteration());
    ASSERT_FALSE(op->IsLoopCond());
  }

  TF_ASSERT_OK(Rewrite(&graph));

  int enter_count = 0;
  int exit_count = 0;
  int switch_count = 0;
  int merge_count = 0;
  int next_iteration_count = 0;
  int x_plus_one_x_times_y_count = 0;
  int x_y_x_less_than_equal_to_n_count = 0;

  for (const auto* op : graph->op_nodes()) {
    if (op->IsEnter()) {
      ++enter_count;
    }
    if (op->IsExit()) {
      ++exit_count;
    }
    if (op->IsSwitch()) {
      ++switch_count;
    }
    if (op->IsMerge()) {
      ++merge_count;
    }
    if (op->IsNextIteration()) {
      ++next_iteration_count;
    }
    if (op->type_string() == "XPlusOneXTimesY") {
      x_plus_one_x_times_y_count++;
    }
    if (op->type_string() == "XYXLessThanOrEqualToN") {
      x_y_x_less_than_equal_to_n_count++;
    }
    ASSERT_NE(op->type_string(), "While");
  }
  // Two nodes per loop input.
  ASSERT_EQ(enter_count, 2);
  ASSERT_EQ(exit_count, 2);
  ASSERT_EQ(switch_count, 2);
  ASSERT_EQ(merge_count, 2);
  ASSERT_EQ(next_iteration_count, 2);
  // The cond/body function call nodes must not survive the rewrite.
  ASSERT_EQ(x_plus_one_x_times_y_count, 0);
  ASSERT_EQ(x_y_x_less_than_equal_to_n_count, 0);

  // Verify execution.
  ClientSession session(root, SessionOptionsWithInlining());
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(a.node()), Input::Initializer(1));
    feeds.emplace(Output(b.node()), Input::Initializer(1));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(
        feeds, {Output(while_node, 0), Output(while_node, 1)}, &out_tensors));
    ASSERT_EQ(out_tensors.size(), 2);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 5);
    EXPECT_EQ(out_tensors[1].scalar<int>()(), 24);
  }
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(a.node()), Input::Initializer(3));
    feeds.emplace(Output(b.node()), Input::Initializer(5));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(
        feeds, {Output(while_node, 0), Output(while_node, 1)}, &out_tensors));
    ASSERT_EQ(out_tensors.size(), 2);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 5);
    EXPECT_EQ(out_tensors[1].scalar<int>()(), 60);
  }
}
514 
TEST(LowerWhileOpTest,DoNotInlineLoweredFunctions)515 TEST(LowerWhileOpTest, DoNotInlineLoweredFunctions) {
516   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
517 
518   FunctionDef x_times_two = test::function::XTimesTwo();
519   FunctionDef less_than_or_eq = test::function::LessThanOrEqualToN(8);
520 
521   // While loop `cond` and `body` nodes can't be inlined.
522   (*x_times_two.mutable_attr())["_noinline"].set_b(true);
523   (*less_than_or_eq.mutable_attr())["_noinline"].set_b(true);
524 
525   // Add test functions for cond and body.
526   FunctionDefLibrary f_lib_proto;
527   *f_lib_proto.add_function() = x_times_two;
528   *f_lib_proto.add_function() = less_than_or_eq;
529 
530   Scope root = Scope::NewRootScope().ExitOnError();
531   TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
532   auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
533   Node* while_node;
534   std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
535   AttrValue cond_func;
536   cond_func.mutable_func()->set_name("LessThanOrEqualToN");
537   AttrValue body_func;
538   body_func.mutable_func()->set_name("XTimesTwo");
539   TF_ASSERT_OK(
540       NodeBuilder("while", "While", &root.graph()->flib_def())
541           .Input(inputs)
542           .Attr("T", {DT_INT32})
543           .Attr("cond", cond_func)
544           .Attr("body", body_func)
545           .Attr("parallel_iterations", 100)
546           .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
547           .Finalize(root.graph(), &while_node));
548   TF_ASSERT_OK(root.DoShapeInference(while_node));
549   TF_ASSERT_OK(root.ToGraph(graph.get()));
550 
551   TF_ASSERT_OK(Rewrite(&graph));
552 
553   // Verify that while node was lowered but functions were not inlined.
554   int x_times_two_count = 0;
555   int less_than_or_eq_count = 0;
556 
557   for (const auto* op : graph->op_nodes()) {
558     if (op->type_string() == x_times_two.signature().name()) {
559       x_times_two_count++;
560     }
561     if (op->type_string() == less_than_or_eq.signature().name()) {
562       less_than_or_eq_count++;
563     }
564     ASSERT_NE(op->type_string(), "While");
565   }
566 
567   ASSERT_EQ(x_times_two_count, 1);
568   ASSERT_EQ(less_than_or_eq_count, 1);
569 
570   // Verify execution.
571   ClientSession session(root, SessionOptionsWithInlining());
572   {
573     ClientSession::FeedType feeds;
574     feeds.emplace(Output(a.node()), Input::Initializer(1));
575     std::vector<Tensor> out_tensors;
576     TF_ASSERT_OK(session.Run(feeds, {Output(while_node)}, &out_tensors));
577     ASSERT_EQ(out_tensors.size(), 1);
578     EXPECT_EQ(out_tensors[0].scalar<int>()(), 16);
579   }
580   {
581     ClientSession::FeedType feeds;
582     feeds.emplace(Output(a.node()), Input::Initializer(3));
583     std::vector<Tensor> out_tensors;
584     TF_ASSERT_OK(session.Run(feeds, {Output(while_node)}, &out_tensors));
585     ASSERT_EQ(out_tensors.size(), 1);
586     EXPECT_EQ(out_tensors[0].scalar<int>()(), 12);
587   }
588 }

}  // namespace
}  // namespace tensorflow