1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "absl/strings/match.h"
17 #include "tensorflow/cc/client/client_session.h"
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/array_ops.h"
20 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
21 #include "tensorflow/cc/ops/function_ops.h"
22 #include "tensorflow/cc/ops/standard_ops.h"
23 #include "tensorflow/core/common_runtime/graph_constructor.h"
24 #include "tensorflow/core/common_runtime/graph_runner.h"
25 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
26 #include "tensorflow/core/framework/function_testlib.h"
27 #include "tensorflow/core/framework/node_def_util.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/types.pb.h"
30 #include "tensorflow/core/graph/graph_def_builder.h"
31 #include "tensorflow/core/lib/core/status_test_util.h"
32 #include "tensorflow/core/lib/strings/str_util.h"
33 #include "tensorflow/core/platform/test.h"
34
35 namespace tensorflow {
36 namespace {
37
SessionOptionsWithInlining()38 SessionOptions SessionOptionsWithInlining() {
39 SessionOptions session_options;
40 session_options.config.mutable_graph_options()
41 ->mutable_optimizer_options()
42 ->set_do_function_inlining(true);
43 return session_options;
44 }
45
Rewrite(std::unique_ptr<Graph> * graph)46 Status Rewrite(std::unique_ptr<Graph>* graph) {
47 FunctionLibraryDefinition flib_def((*graph)->flib_def());
48 GraphOptimizationPassOptions opt_options;
49 SessionOptions session_options = SessionOptionsWithInlining();
50 opt_options.session_options = &session_options;
51 opt_options.graph = graph;
52 opt_options.flib_def = &flib_def;
53 LowerFunctionalOpsPass pass;
54 return pass.Run(opt_options);
55 }
56
// Lowers a single-input While node into Enter/Merge/Switch/NextIteration/Exit
// primitives and verifies both the structure of the rewritten graph and the
// results of executing it (the body doubles x while the cond holds x <= 8).
TEST(LowerWhileOpTest, Simple) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  // Add test functions for cond and body.
  FunctionDefLibrary f_lib_proto;
  *f_lib_proto.add_function() = test::function::XTimesTwo();
  *f_lib_proto.add_function() = test::function::LessThanOrEqualToN(8);

  Scope root = Scope::NewRootScope().ExitOnError();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
  auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
  Node* while_node;
  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
  AttrValue cond_func;
  cond_func.mutable_func()->set_name("LessThanOrEqualToN");
  AttrValue body_func;
  body_func.mutable_func()->set_name("XTimesTwo");
  TF_ASSERT_OK(
      NodeBuilder("while", "While", &root.graph()->flib_def())
          .Input(inputs)
          .Attr("T", {DT_INT32})
          .Attr("cond", cond_func)
          .Attr("body", body_func)
          .Attr("parallel_iterations", 100)
          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
          .Finalize(root.graph(), &while_node));
  // "C" consumes the While output both as a data edge and as a control
  // dependency so we can check the rewrite preserves both edges.
  auto c = ops::Identity(
      root.WithOpName("C").WithControlDependencies(Output(while_node)),
      Output(while_node));
  TF_ASSERT_OK(root.DoShapeInference(while_node));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // The input graph has no lower level control flow primitives.
  int node_called_while_count = 0;
  for (const auto* op : graph->op_nodes()) {
    ASSERT_FALSE(op->IsEnter());
    ASSERT_FALSE(op->IsExit());
    ASSERT_FALSE(op->IsSwitch());
    ASSERT_FALSE(op->IsMerge());
    ASSERT_FALSE(op->IsNextIteration());
    ASSERT_FALSE(op->IsLoopCond());
    if (op->name() == "while") {
      node_called_while_count++;
    }
  }
  ASSERT_EQ(node_called_while_count, 1);

  TF_ASSERT_OK(Rewrite(&graph));

  int enter_count = 0;
  int exit_count = 0;
  int switch_count = 0;
  int merge_count = 0;
  int next_iteration_count = 0;
  node_called_while_count = 0;
  int less_than_or_equal_to_n_count = 0;  // Fixed typo: was "equan".
  int x_times_two_count = 0;

  for (const auto* op : graph->op_nodes()) {
    if (op->IsEnter()) {
      ++enter_count;
      // `parallel_iterations` must be forwarded to the Enter node.
      ASSERT_EQ(op->attrs().Find("parallel_iterations")->i(), 100);
    }
    if (op->IsExit()) {
      ++exit_count;
    }
    if (op->IsSwitch()) {
      ++switch_count;
    }
    if (op->IsMerge()) {
      ++merge_count;
    }
    if (op->IsNextIteration()) {
      ++next_iteration_count;
    }
    if (op->name() == "while") {
      node_called_while_count++;
    }
    if (op->type_string() == "LessThanOrEqualToN") {
      less_than_or_equal_to_n_count++;
    }
    if (op->type_string() == "XTimesTwo") {
      x_times_two_count++;
    }
    if (op->name() == "C") {
      // One data edge plus one control edge from the lowered While output.
      ASSERT_EQ(op->in_edges().size(), 2);
    }
    ASSERT_NE(op->type_string(), "While");
  }
  // One node per loop input.
  ASSERT_EQ(enter_count, 1);
  ASSERT_EQ(exit_count, 1);
  ASSERT_EQ(switch_count, 1);
  ASSERT_EQ(merge_count, 1);
  ASSERT_EQ(next_iteration_count, 1);
  ASSERT_EQ(node_called_while_count, 1);
  // With inlining enabled the lowered cond/body function calls should have
  // been inlined away (these counters were previously accumulated but never
  // checked; see DoNotInlineLoweredFunctions for the `_noinline` case).
  ASSERT_EQ(less_than_or_equal_to_n_count, 0);
  ASSERT_EQ(x_times_two_count, 0);

  // Verify execution.
  ClientSession session(root, SessionOptionsWithInlining());
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(a.node()), Input::Initializer(1));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(while_node)}, &out_tensors));
    ASSERT_EQ(out_tensors.size(), 1);
    // 1 -> 2 -> 4 -> 8 -> 16 (loop runs while x <= 8).
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 16);
  }
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(a.node()), Input::Initializer(3));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(while_node)}, &out_tensors));
    ASSERT_EQ(out_tensors.size(), 1);
    // 3 -> 6 -> 12.
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 12);
  }
}
173
// Verifies that the lowering pass propagates the *assigned* device of the
// While op's input to every control flow primitive created for that loop
// variable: Enter, Merge, NextIteration, Switch and Exit. The chain is
// walked edge-by-edge (placeholder -> Enter -> Merge -> ... ) so each
// ASSERT also pins the expected graph topology.
TEST(LowerWhileOpTest, ForwardAssignedInputDevice) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  // Add test functions for cond and body.
  FunctionDefLibrary f_lib_proto;
  *f_lib_proto.add_function() = test::function::XTimesTwo();
  *f_lib_proto.add_function() = test::function::LessThanOrEqualToN(8);

  TF_ASSERT_OK(graph->AddFunctionLibrary(f_lib_proto));
  auto type = DT_FLOAT;
  Node* placeholder;
  TF_CHECK_OK(NodeBuilder("placed_node", "Placeholder")
                  .Attr("dtype", type)
                  .Finalize(graph.get(), &placeholder));
  // Pre-assign a device to the loop input; the lowered primitives are
  // expected to inherit exactly this assigned device.
  const string assigned_device_name = "/job:localhost/replica:0/task:0/gpu:0";
  placeholder->set_assigned_device_name(assigned_device_name);
  Node* while_node;
  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(placeholder)});
  AttrValue cond_func;
  cond_func.mutable_func()->set_name("LessThanOrEqualToN");
  AttrValue body_func;
  body_func.mutable_func()->set_name("XTimesTwo");
  TF_ASSERT_OK(
      NodeBuilder("while", "While", &graph->flib_def())
          .Input(inputs)
          .Attr("T", {type})
          .Attr("cond", cond_func)
          .Attr("body", body_func)
          .Attr("parallel_iterations", 100)
          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
          .Finalize(graph.get(), &while_node));
  TF_ASSERT_OK(Rewrite(&graph));

  // Re-locate the placeholder in the rewritten graph (the rewrite may have
  // replaced node objects).
  const Node* placeholder_node = nullptr;
  for (const auto* op : graph->op_nodes()) {
    if (op->name() == "placed_node") {
      placeholder_node = op;
    }
  }
  ASSERT_NE(placeholder_node, nullptr);
  // Verify the assigned device of the Enter node.
  int enter_consumers = 0;
  const Node* enter_node = nullptr;
  for (const Node* consumer : placeholder_node->out_nodes()) {
    if (consumer->type_string() == "Enter") {
      enter_consumers += 1;
      enter_node = consumer;
      ASSERT_EQ(consumer->assigned_device_name(), assigned_device_name);
    }
  }
  // Exactly one Enter consumes the placeholder; also guarantees `enter_node`
  // is non-null before it is dereferenced below.
  ASSERT_EQ(enter_consumers, 1);
  // Verify the assigned device of the Merge node.
  int merge_consumers = 0;
  const Node* merge_node = nullptr;
  for (const Node* consumer : enter_node->out_nodes()) {
    if (consumer->type_string() == "Merge") {
      merge_consumers += 1;
      merge_node = consumer;
      ASSERT_EQ(consumer->assigned_device_name(), assigned_device_name);
    }
  }
  ASSERT_EQ(merge_consumers, 1);
  // Verify the assigned device of the NextIteration node (an *input* of the
  // Merge node, feeding the value back into the loop).
  int next_iteration_consumers = 0;
  for (const Node* consumer : merge_node->in_nodes()) {
    if (consumer->type_string() == "NextIteration") {
      next_iteration_consumers += 1;
      ASSERT_EQ(consumer->assigned_device_name(), assigned_device_name);
    }
  }
  ASSERT_EQ(next_iteration_consumers, 1);
  // Verify the assigned device of the Switch node.
  int switch_consumers = 0;
  const Node* switch_node = nullptr;
  for (const Node* consumer : merge_node->out_nodes()) {
    if (consumer->type_string() == "Switch") {
      switch_consumers += 1;
      switch_node = consumer;
      ASSERT_EQ(consumer->assigned_device_name(), assigned_device_name);
    }
  }
  ASSERT_EQ(switch_consumers, 1);
  // Verify the assigned device of the Exit node.
  int exit_consumers = 0;
  for (const Node* consumer : switch_node->out_nodes()) {
    if (consumer->type_string() == "Exit") {
      exit_consumers += 1;
      ASSERT_EQ(consumer->assigned_device_name(), assigned_device_name);
    }
  }
  ASSERT_EQ(exit_consumers, 1);
}
266
// Verifies that the lowering pass propagates *requested* devices: the
// per-loop-variable primitives (Enter/Merge/NextIteration/Switch/Exit)
// inherit the input placeholder's requested device (gpu:0), while the loop
// control nodes (LoopControlInputs/LoopExecuted) inherit the While op's own
// requested device (gpu:2). gpu:1 only hosts the control input and must not
// leak into any lowered node.
TEST(LowerWhileOpTest, ForwardRequestedInputDevice) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  // Add test functions for cond and body.
  FunctionDefLibrary f_lib_proto;
  *f_lib_proto.add_function() = test::function::XTimesTwo();
  *f_lib_proto.add_function() = test::function::LessThanOrEqualToN(8);

  TF_ASSERT_OK(graph->AddFunctionLibrary(f_lib_proto));
  auto type = DT_FLOAT;
  // We will place the loop var on the gpu:0.
  const string gpu_0_device = "/job:localhost/replica:0/task:0/gpu:0";
  // We will place loop's control input on the gpu:1.
  const string gpu_1_device = "/job:localhost/replica:0/task:0/gpu:1";
  // We will place While op on gpu:2.
  const string gpu_2_device = "/job:localhost/replica:0/task:0/gpu:2";
  Node* gpu_0_ph;
  TF_CHECK_OK(NodeBuilder("placed_node", "Placeholder")
                  .Attr("dtype", type)
                  .Device(gpu_0_device)
                  .Finalize(graph.get(), &gpu_0_ph));
  Node* control_in;
  // Add a control input to the While op to trigger the creation of a
  // LoopExecuted node.
  TF_CHECK_OK(NodeBuilder("control_in", "Placeholder")
                  .Attr("dtype", type)
                  .Device(gpu_1_device)
                  .Finalize(graph.get(), &control_in));
  Node* while_node;
  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(gpu_0_ph)});
  AttrValue cond_func;
  cond_func.mutable_func()->set_name("LessThanOrEqualToN");
  AttrValue body_func;
  body_func.mutable_func()->set_name("XTimesTwo");
  TF_ASSERT_OK(
      NodeBuilder("while", "While", &graph->flib_def())
          .Input(inputs)
          .ControlInput(control_in)
          .Device(gpu_2_device)
          .Attr("T", {type})
          .Attr("cond", cond_func)
          .Attr("body", body_func)
          .Attr("parallel_iterations", 100)
          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
          .Finalize(graph.get(), &while_node));

  // Create an empty Const node with control dep from the While op.
  // This triggers the creation of a LoopExecuted node.
  Node* control_out;
  TensorProto proto;
  proto.set_dtype(DT_FLOAT);
  TensorShape empty_shape({0});
  empty_shape.AsProto(proto.mutable_tensor_shape());
  TF_ASSERT_OK(NodeBuilder("control_out", "Const")
                   .ControlInput(while_node)
                   .Attr("dtype", DT_FLOAT)
                   .Attr("value", proto)
                   .Finalize(graph.get(), &control_out));

  TF_ASSERT_OK(Rewrite(&graph));

  // Re-locate the loop-var placeholder in the rewritten graph.
  const Node* placeholder_node = nullptr;
  for (const auto* op : graph->op_nodes()) {
    if (op->name() == "placed_node") {
      placeholder_node = op;
    }
  }
  ASSERT_NE(placeholder_node, nullptr);
  // Verify the requested device of the Enter node.
  int enter_consumers = 0;
  const Node* enter_node = nullptr;
  for (const Node* consumer : placeholder_node->out_nodes()) {
    if (consumer->type_string() == "Enter") {
      enter_consumers += 1;
      enter_node = consumer;
      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
    }
  }
  // Exactly one Enter; also ensures `enter_node` is valid before use below.
  ASSERT_EQ(enter_consumers, 1);
  // Verify the requested device of the Merge node.
  int merge_consumers = 0;
  const Node* merge_node = nullptr;
  for (const Node* consumer : enter_node->out_nodes()) {
    if (consumer->type_string() == "Merge") {
      merge_consumers += 1;
      merge_node = consumer;
      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(merge_consumers, 1);
  // Verify the requested device of the NextIteration node (a Merge *input*).
  int next_iteration_consumers = 0;
  for (const Node* consumer : merge_node->in_nodes()) {
    if (consumer->type_string() == "NextIteration") {
      next_iteration_consumers += 1;
      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(next_iteration_consumers, 1);
  // Verify the requested device of the Switch node.
  int switch_consumers = 0;
  const Node* switch_node = nullptr;
  for (const Node* consumer : merge_node->out_nodes()) {
    if (consumer->type_string() == "Switch") {
      switch_consumers += 1;
      switch_node = consumer;
      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(switch_consumers, 1);
  // Verify the requested device of the Exit node.
  int exit_consumers = 0;
  for (const Node* consumer : switch_node->out_nodes()) {
    if (consumer->type_string() == "Exit") {
      exit_consumers += 1;
      ASSERT_EQ(consumer->requested_device(), gpu_0_device);
    }
  }
  ASSERT_EQ(exit_consumers, 1);
  // Verify the requested device of LoopControlInputs. Lowered control nodes
  // carry generated names, hence the substring match.
  const Node* loop_control_inputs_node = nullptr;
  for (const auto* op : graph->op_nodes()) {
    if (absl::StrContains(op->name(), "LoopControlInputs")) {
      loop_control_inputs_node = op;
    }
  }
  ASSERT_NE(loop_control_inputs_node, nullptr);
  ASSERT_EQ(loop_control_inputs_node->requested_device(), gpu_2_device);
  // Verify the requested device of LoopExecuted.
  const Node* loop_executed_node = nullptr;
  for (const auto* op : graph->op_nodes()) {
    if (absl::StrContains(op->name(), "LoopExecuted")) {
      loop_executed_node = op;
    }
  }
  ASSERT_NE(loop_executed_node, nullptr);
  ASSERT_EQ(loop_executed_node->requested_device(), gpu_2_device);
}
405
// Lowers a While node with two loop variables (x, y) and checks that each
// loop variable gets its own set of control flow primitives (two of each),
// that the cond/body function calls are inlined away, and that execution of
// the lowered loop produces the expected values. The body computes
// (x + 1, x * y) and the cond runs while x <= 4.
TEST(LowerWhileOpTest, MultipleInputs) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  // Add test functions for cond and body.
  FunctionDefLibrary f_lib_proto;
  *(f_lib_proto.add_function()) = test::function::XPlusOneXTimesY();
  *(f_lib_proto.add_function()) = test::function::XYXLessThanOrEqualToN(4);

  Scope root = Scope::NewRootScope().ExitOnError();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
  auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
  auto b = ops::Placeholder(root.WithOpName("B"), DT_INT32);
  Node* while_node;
  std::vector<NodeBuilder::NodeOut> inputs(
      {NodeBuilder::NodeOut(a.node()), NodeBuilder::NodeOut(b.node())});
  AttrValue cond_func;
  cond_func.mutable_func()->set_name("XYXLessThanOrEqualToN");
  AttrValue body_func;
  body_func.mutable_func()->set_name("XPlusOneXTimesY");
  TF_ASSERT_OK(
      NodeBuilder("while", "While", &root.graph()->flib_def())
          .Input(inputs)
          .Attr("T", {DT_INT32, DT_INT32})
          .Attr("cond", cond_func)
          .Attr("body", body_func)
          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
          .Finalize(root.graph(), &while_node));
  TF_ASSERT_OK(root.DoShapeInference(while_node));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // The input graph has no lower level control flow primitives.
  for (const auto* op : graph->op_nodes()) {
    ASSERT_FALSE(op->IsEnter());
    ASSERT_FALSE(op->IsExit());
    ASSERT_FALSE(op->IsSwitch());
    ASSERT_FALSE(op->IsMerge());
    ASSERT_FALSE(op->IsNextIteration());
    ASSERT_FALSE(op->IsLoopCond());
  }

  TF_ASSERT_OK(Rewrite(&graph));

  int enter_count = 0;
  int exit_count = 0;
  int switch_count = 0;
  int merge_count = 0;
  int next_iteration_count = 0;
  int x_plus_one_x_times_y_count = 0;
  int x_y_x_less_than_equal_to_n_count = 0;

  for (const auto* op : graph->op_nodes()) {
    if (op->IsEnter()) {
      ++enter_count;
    }
    if (op->IsExit()) {
      ++exit_count;
    }
    if (op->IsSwitch()) {
      ++switch_count;
    }
    if (op->IsMerge()) {
      ++merge_count;
    }
    if (op->IsNextIteration()) {
      ++next_iteration_count;
    }
    if (op->type_string() == "XPlusOneXTimesY") {
      x_plus_one_x_times_y_count++;
    }
    if (op->type_string() == "XYXLessThanOrEqualToN") {
      x_y_x_less_than_equal_to_n_count++;
    }
    ASSERT_NE(op->type_string(), "While");
  }
  // Two nodes per loop input.
  ASSERT_EQ(enter_count, 2);
  ASSERT_EQ(exit_count, 2);
  ASSERT_EQ(switch_count, 2);
  ASSERT_EQ(merge_count, 2);
  ASSERT_EQ(next_iteration_count, 2);
  // Cond/body call nodes must have been inlined away.
  ASSERT_EQ(x_plus_one_x_times_y_count, 0);
  ASSERT_EQ(x_y_x_less_than_equal_to_n_count, 0);

  // Verify execution.
  ClientSession session(root, SessionOptionsWithInlining());
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(a.node()), Input::Initializer(1));
    feeds.emplace(Output(b.node()), Input::Initializer(1));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(
        feeds, {Output(while_node, 0), Output(while_node, 1)}, &out_tensors));
    ASSERT_EQ(out_tensors.size(), 2);
    // (1,1) -> (2,1) -> (3,2) -> (4,6) -> (5,24); stops once x > 4.
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 5);
    EXPECT_EQ(out_tensors[1].scalar<int>()(), 24);
  }
  {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(a.node()), Input::Initializer(3));
    feeds.emplace(Output(b.node()), Input::Initializer(5));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(
        feeds, {Output(while_node, 0), Output(while_node, 1)}, &out_tensors));
    ASSERT_EQ(out_tensors.size(), 2);
    // (3,5) -> (4,15) -> (5,60); stops once x > 4.
    EXPECT_EQ(out_tensors[0].scalar<int>()(), 5);
    EXPECT_EQ(out_tensors[1].scalar<int>()(), 60);
  }
}
514
// Checks that lowering a While node does NOT inline cond/body functions that
// are explicitly marked `_noinline`: the lowered graph must keep exactly one
// call node for each of them, and the loop must still execute correctly.
TEST(LowerWhileOpTest, DoNotInlineLoweredFunctions) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  FunctionDef x_times_two = test::function::XTimesTwo();
  FunctionDef less_than_or_eq = test::function::LessThanOrEqualToN(8);

  // While loop `cond` and `body` nodes can't be inlined.
  (*x_times_two.mutable_attr())["_noinline"].set_b(true);
  (*less_than_or_eq.mutable_attr())["_noinline"].set_b(true);

  // Add test functions for cond and body.
  FunctionDefLibrary f_lib_proto;
  *f_lib_proto.add_function() = x_times_two;
  *f_lib_proto.add_function() = less_than_or_eq;

  Scope root = Scope::NewRootScope().ExitOnError();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
  auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
  Node* while_node;
  std::vector<NodeBuilder::NodeOut> loop_inputs(
      {NodeBuilder::NodeOut(a.node())});
  AttrValue cond_func;
  cond_func.mutable_func()->set_name("LessThanOrEqualToN");
  AttrValue body_func;
  body_func.mutable_func()->set_name("XTimesTwo");
  TF_ASSERT_OK(
      NodeBuilder("while", "While", &root.graph()->flib_def())
          .Input(loop_inputs)
          .Attr("T", {DT_INT32})
          .Attr("cond", cond_func)
          .Attr("body", body_func)
          .Attr("parallel_iterations", 100)
          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
          .Finalize(root.graph(), &while_node));
  TF_ASSERT_OK(root.DoShapeInference(while_node));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(Rewrite(&graph));

  // Verify that while node was lowered but functions were not inlined.
  int num_x_times_two = 0;
  int num_less_than_or_eq = 0;
  for (const auto* op : graph->op_nodes()) {
    const string& op_type = op->type_string();
    ASSERT_NE(op_type, "While");
    if (op_type == x_times_two.signature().name()) {
      num_x_times_two++;
    }
    if (op_type == less_than_or_eq.signature().name()) {
      num_less_than_or_eq++;
    }
  }
  ASSERT_EQ(num_x_times_two, 1);
  ASSERT_EQ(num_less_than_or_eq, 1);

  // Verify execution: feed a value and compare the loop output against the
  // expected result of repeated doubling while x <= 8.
  ClientSession session(root, SessionOptionsWithInlining());
  auto run_and_check = [&](int input, int expected) {
    ClientSession::FeedType feeds;
    feeds.emplace(Output(a.node()), Input::Initializer(input));
    std::vector<Tensor> out_tensors;
    TF_ASSERT_OK(session.Run(feeds, {Output(while_node)}, &out_tensors));
    ASSERT_EQ(out_tensors.size(), 1);
    EXPECT_EQ(out_tensors[0].scalar<int>()(), expected);
  };
  run_and_check(1, 16);
  run_and_check(3, 12);
}
589
590 } // namespace
591 } // namespace tensorflow
592