1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
17
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
20 #include "tensorflow/cc/ops/function_ops.h"
21 #include "tensorflow/cc/ops/functional_ops.h"
22 #include "tensorflow/cc/ops/resource_variable_ops.h"
23 #include "tensorflow/cc/ops/standard_ops.h"
24 #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
25 #include "tensorflow/compiler/tf2xla/test_util.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/framework/function.h"
30 #include "tensorflow/core/framework/graph_to_functiondef.h"
31 #include "tensorflow/core/framework/node_def_util.h"
32 #include "tensorflow/core/framework/op.h"
33 #include "tensorflow/core/graph/graph_def_builder.h"
34 #include "tensorflow/core/graph/validate.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/platform/test.h"
37 #include "tensorflow/core/public/version.h"
38 #include "tensorflow/core/util/dump_graph.h"
39 #include "tensorflow/core/util/equal_graph_def.h"
40
41 namespace tensorflow {
42 namespace {
43
44 // Returns the names of the "then" and "else" functions for the If node in a
45 // graph.
FindIfThenAndElse(const GraphDef & graph,string * op_name,NameAttrList * then_fn,NameAttrList * else_fn)46 Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
47 NameAttrList* then_fn, NameAttrList* else_fn) {
48 for (const NodeDef& node : graph.node()) {
49 if (node.op() == "If") {
50 *op_name = node.name();
51 const NameAttrList* result;
52 TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result));
53 *then_fn = *result;
54 TF_RETURN_IF_ERROR(GetNodeAttr(node, "else_branch", &result));
55 *else_fn = *result;
56 return Status::OK();
57 }
58 }
59 return errors::NotFound("No If node found in graph");
60 }
61
62 // Graph:
63 // x = array_ops.placeholder(dtypes.int32)
64 // y = array_ops.placeholder(dtypes.int32)
65 // z = control_flow_ops.cond(
66 // math_ops.less(y, x), lambda: math_ops.multiply(y, 17),
67 // lambda: math_ops.add(x, 23))
68 //
69 // Tests different node filters and functionalization inside of a function.
70 class ConditionalTestFixture
71 : public ::testing::TestWithParam<std::tuple<bool, bool>> {
72 protected:
SetUp()73 void SetUp() override {
74 restrict_to_tpu_nodes_ = std::get<0>(GetParam());
75 wrap_condition_in_function_ = std::get<1>(GetParam());
76 }
77 void RunTest();
78
79 private:
80 void BuildCondGraph(Graph* cond_graph);
81 void CheckGraphDef(const GraphDef& graph_def,
82 const FunctionLibraryDefinition& library);
83
84 bool restrict_to_tpu_nodes_ = false;
85 bool wrap_condition_in_function_ = false;
86 };
87
TEST_P(ConditionalTestFixture, ConditionalTests) { RunTest(); }

// Runs ConditionalTests for all four combinations of the two boolean
// parameters. Each instance is named after the filter and placement it uses,
// e.g. "with_filter_in_function" or "without_filter_in_graph".
INSTANTIATE_TEST_SUITE_P(
    FunctionalizeControlFlow, ConditionalTestFixture,
    ::testing::Combine(::testing::Bool(), ::testing::Bool()),
    [](const ::testing::TestParamInfo<ConditionalTestFixture::ParamType>&
           info) {
      const char* filter_part =
          std::get<0>(info.param) ? "with_filter" : "without_filter";
      const char* placement_part =
          std::get<1>(info.param) ? "_in_function" : "_in_graph";
      return absl::StrCat(filter_part, placement_part);
    });
102
BuildCondGraph(Graph * cond_graph)103 void ConditionalTestFixture::BuildCondGraph(Graph* cond_graph) {
104 {
105 Scope scope = Scope::NewRootScope().ExitOnError();
106
107 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
108 auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
109 auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
110 auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), less, less);
111
112 auto identity_t =
113 ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_true);
114 auto seventeen = ops::Const<int32>(
115 scope.WithOpName("cond").WithControlDependencies(identity_t), 17);
116 auto switch_2 = ops::Switch(scope.WithOpName("cond/Switch"), y, less);
117 auto mul = ops::Multiply(scope.WithOpName("cond/Mul"), switch_2.output_true,
118 seventeen);
119
120 auto identity_f =
121 ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_false);
122 auto twenty_three = ops::Const<int32>(
123 scope.WithOpName("cond").WithControlDependencies(identity_f), 23);
124 auto switch_3 = ops::Switch(scope.WithOpName("cond/Switch"), x, less);
125 auto add = ops::Add(scope.WithOpName("cond/false/add"),
126 switch_3.output_false, twenty_three);
127
128 auto merge = ops::Merge(scope.WithOpName("cond/Merge"),
129 std::initializer_list<Input>{add, mul});
130
131 TF_EXPECT_OK(scope.ToGraph(cond_graph));
132
133 // Set `_tpu_replicate` attribute for all nodes.
134 for (Node* n : cond_graph->nodes()) {
135 n->AddAttr("_tpu_replicate", "cluster");
136 }
137 }
138 }
139
CheckGraphDef(const GraphDef & graph_def,const FunctionLibraryDefinition & library)140 void ConditionalTestFixture::CheckGraphDef(
141 const GraphDef& graph_def, const FunctionLibraryDefinition& library) {
142 string op_name;
143 NameAttrList then_fn;
144 NameAttrList else_fn;
145 TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn));
146 InstantiationResultForTest else_result;
147 TF_EXPECT_OK(
148 InstantiateFunctionForTest(else_fn.name(), library, &else_result));
149
150 // Outer graph
151 {
152 Scope scope = Scope::NewRootScope().ExitOnError();
153 auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
154 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
155 auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
156 auto if_op =
157 ops::If(scope.WithOpName(op_name), less,
158 std::initializer_list<Input>{less, y, x}, {DT_INT32}, then_fn,
159 else_fn, ops::If::OutputShapes({PartialTensorShape()}));
160 auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]);
161 GraphDef expected;
162 TF_EXPECT_OK(scope.ToGraphDef(&expected));
163 TF_EXPECT_GRAPH_EQ(expected, graph_def);
164 }
165
166 // then body.
167 {
168 Scope scope = Scope::NewRootScope().ExitOnError();
169 auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0);
170 auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
171 auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
172 auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0);
173 auto cond = ops::Const(
174 scope.WithOpName("cond").WithControlDependencies(identity), 17);
175 auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond);
176 auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), mul, 0);
177
178 GraphDef expected;
179 TF_EXPECT_OK(scope.ToGraphDef(&expected));
180
181 InstantiationResultForTest result;
182 TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result));
183
184 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
185 EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
186 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
187 }
188
189 // else body.
190 {
191 Scope scope = Scope::NewRootScope().ExitOnError();
192 auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0);
193 auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
194 auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
195 auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0);
196 auto cond_1 = ops::Const(
197 scope.WithOpName("cond_1").WithControlDependencies(identity), 23);
198 auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1);
199 auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
200
201 GraphDef expected;
202 TF_EXPECT_OK(scope.ToGraphDef(&expected));
203
204 InstantiationResultForTest result;
205 TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result));
206
207 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
208 EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
209 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
210 }
211 }
212
RunTest()213 void ConditionalTestFixture::RunTest() {
214 Graph graph(OpRegistry::Global());
215 if (wrap_condition_in_function_) {
216 // Wrap condition in a function which is called from `graph`.
217 Scope scope = Scope::NewRootScope().ExitOnError();
218 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
219
220 Graph cond_graph(OpRegistry::Global());
221 BuildCondGraph(&cond_graph);
222
223 FunctionDef cond_fdef;
224 TF_ASSERT_OK(GraphToFunctionDef(cond_graph, "cond_fn", &cond_fdef));
225
226 FunctionDefLibrary fdef_lib;
227 *(fdef_lib.add_function()) = cond_fdef;
228 TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
229 NodeDef cond_fn;
230 cond_fn.set_name("cond_node");
231 cond_fn.set_op("cond_fn");
232 *(cond_fn.add_input()) = "source";
233 Status status;
234 scope.graph()->AddNode(cond_fn, &status);
235 TF_ASSERT_OK(status);
236 TF_ASSERT_OK(scope.ToGraph(&graph));
237 } else {
238 // Build condition in `graph`.
239 BuildCondGraph(&graph);
240 }
241 FunctionLibraryDefinition library(graph.flib_def());
242 // If `restrict_to_tpu_nodes_` is true let filter function return true for
243 // `_tpu_replicate` nodes.
244 NodeFilter node_filter =
245 restrict_to_tpu_nodes_
246 ? [](const Node* n) { return n->attrs().Find("_tpu_replicate"); }
247 : NodeFilter{};
248
249 GraphDef optimized_graph_def;
250 graph.ToGraphDef(&optimized_graph_def);
251 TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef(
252 &optimized_graph_def, &library, node_filter,
253 /*include_functions=*/wrap_condition_in_function_));
254 TF_ASSERT_OK(FunctionalizeControlFlow(
255 &graph, &library, node_filter,
256 /*include_functions=*/wrap_condition_in_function_));
257
258 if (wrap_condition_in_function_) {
259 // Check if function body was functionalized.
260 auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
261 /*device_mgr=*/nullptr, tensorflow::Env::Default(),
262 /*config=*/nullptr, TF_GRAPH_DEF_VERSION, &library,
263 tensorflow::OptimizerOptions());
264 FunctionLibraryRuntime* flr =
265 pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
266 FunctionLibraryRuntime::Handle handle;
267
268 // Functionalized function name is the type string of `cond_node`.
269 string func_name;
270 for (Node* n : graph.nodes()) {
271 if (n->name() == "cond_node") {
272 func_name = n->type_string();
273 break;
274 }
275 }
276 TF_ASSERT_OK(flr->Instantiate(func_name, AttrSlice(), &handle));
277 const FunctionBody* body = flr->GetFunctionBody(handle);
278 GraphDef graph_def;
279 body->graph->ToGraphDef(&graph_def);
280 CheckGraphDef(graph_def, library);
281 } else {
282 // Check if graphs were functionalized.
283 CheckGraphDef(optimized_graph_def, library);
284 GraphDef converted_graph_def;
285 graph.ToGraphDef(&converted_graph_def);
286 CheckGraphDef(converted_graph_def, library);
287 }
288 }
289
290 // Returns the names of the "cond" and "body" functions for the While node
291 // in a graph.
FindWhileCondAndBody(const GraphDef & graph,NameAttrList * cond,NameAttrList * body)292 Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
293 NameAttrList* body) {
294 for (const NodeDef& node : graph.node()) {
295 if (node.op() == "While") {
296 const NameAttrList* result;
297 TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result));
298 *cond = *result;
299 TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result));
300 *body = *result;
301 return Status::OK();
302 }
303 }
304 return errors::NotFound("No While node found in graph");
305 }
306
307 // Graph:
308 // x = array_ops.placeholder(dtypes.int32)
309 // y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
TEST(FunctionalizeControlFlow,OneLoopVar)310 TEST(FunctionalizeControlFlow, OneLoopVar) {
311 Graph graph(OpRegistry::Global());
312 {
313 Scope scope = Scope::NewRootScope().ExitOnError();
314
315 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
316
317 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
318 auto enter =
319 ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
320 // Add an unused Enter node. These should be ignored.
321 auto enter2 =
322 ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop");
323 auto merge = ops::Merge(scope.WithOpName("while/Merge"),
324 std::initializer_list<Input>{enter, dummy});
325 auto ten = ops::Const<int32>(
326 scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
327 10);
328 auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
329 auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
330 auto switch_ =
331 ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
332 auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
333 switch_.output_false);
334 auto identity =
335 ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
336 auto one = ops::Const<int32>(
337 scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
338 auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
339 auto next_iteration =
340 ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
341
342 auto sink = ops::Identity(scope.WithOpName("sink"), exit);
343
344 // Remove the dummy node and add the loop backedge.
345 scope.graph()->RemoveNode(dummy.node());
346 scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
347
348 TF_EXPECT_OK(scope.ToGraph(&graph));
349 }
350
351 // Regression test: control edges from an Enter node to the graph sink should
352 // be ignored.
353 for (Node* n : graph.nodes()) {
354 if (n->name() == "while/Enter") {
355 graph.AddControlEdge(n, graph.sink_node());
356 }
357 }
358
359 FunctionLibraryDefinition library(OpRegistry::Global(), {});
360 GraphDef optimized_graph_def;
361 graph.ToGraphDef(&optimized_graph_def);
362 TF_ASSERT_OK(
363 FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
364 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
365 GraphDef converted_graph_def;
366 graph.ToGraphDef(&converted_graph_def);
367
368 for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
369 NameAttrList cond_fn, body_fn;
370 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
371
372 // Outer graph
373 {
374 Scope scope = Scope::NewRootScope().ExitOnError();
375 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
376 auto while_op =
377 ops::While(scope.WithOpName("while/LoopCond"),
378 std::initializer_list<Input>{source}, cond_fn, body_fn);
379 auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
380 GraphDef expected;
381 TF_EXPECT_OK(scope.ToGraphDef(&expected));
382 TF_EXPECT_GRAPH_EQ(expected, graph_def);
383 }
384
385 // Condition graph
386 {
387 Scope scope = Scope::NewRootScope().ExitOnError();
388 auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
389 auto ten = ops::Const<int32>(
390 scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
391 auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
392 auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
393
394 GraphDef expected;
395 TF_EXPECT_OK(scope.ToGraphDef(&expected));
396
397 InstantiationResultForTest result;
398 TF_EXPECT_OK(
399 InstantiateFunctionForTest(cond_fn.name(), library, &result));
400
401 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
402 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
403 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
404 }
405
406 // Body graph.
407 {
408 Scope scope = Scope::NewRootScope().ExitOnError();
409 auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
410 auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
411 auto one = ops::Const<int32>(
412 scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
413 auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
414 auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
415
416 GraphDef expected;
417 TF_EXPECT_OK(scope.ToGraphDef(&expected));
418
419 InstantiationResultForTest result;
420 TF_EXPECT_OK(
421 InstantiateFunctionForTest(body_fn.name(), library, &result));
422
423 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
424 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
425 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
426 }
427 }
428 }
429
GetNoinlineFunctionDef()430 FunctionDef GetNoinlineFunctionDef() {
431 FunctionDef fdef = FunctionDefHelper::Create(
432 "increment_fn", {"x:int32"}, {"add:int32"}, {},
433 {
434 {{"add/y"}, "Const", {}, {{"dtype", DT_INT32}}},
435 {{"add_0"}, "Add", {"x", "add/y:output:0"}, {{"T", DT_INT32}}},
436 },
437 {{"add", "add_0:z:0"}});
438 (*fdef.mutable_attr())["_noinline"].set_b(true);
439 return fdef;
440 }
441
442 // @function.Defun(noinline=True)
443 // def increment_fn(x):
444 // return [x + 1]
445 // Define the above function, and add it to the given graph. It's used as the
446 // while loop body in NoinlineLoopBody test.
AddNoinlineFunctionToGraph(const string & node_name,Graph * graph)447 Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) {
448 FunctionDefLibrary fdef_lib;
449 *(fdef_lib.add_function()) = GetNoinlineFunctionDef();
450 TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib));
451 NodeDef increment_fn;
452 increment_fn.set_name(node_name);
453 increment_fn.set_op("increment_fn");
454 *increment_fn.add_input() = "while/Identity";
455 *increment_fn.add_input() = "^while/Identity";
456 Status status;
457 graph->AddNode(increment_fn, &status);
458 return status;
459 }
460
461 // Graph:
462 // x = array_ops.placeholder(dtypes.int32)
463 // y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x])
TEST(FunctionalizeControlFlow,NoinlineLoopBody)464 TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
465 const string& noinline_node_name = "while/increment_fn";
466 Graph graph(OpRegistry::Global());
467 {
468 Scope scope = Scope::NewRootScope().ExitOnError();
469 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
470 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
471 auto enter = ops::internal::Enter(scope.WithOpName("while/Enter"), source,
472 "while/while_context");
473 auto merge = ops::Merge(scope.WithOpName("while/Merge"),
474 std::initializer_list<Input>{enter, dummy});
475 auto ten = ops::Const<int32>(
476 scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
477 10);
478 auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
479 auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
480 auto switch_ =
481 ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
482 auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
483 switch_.output_false);
484 auto identity =
485 ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
486
487 TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
488
489 NodeDef next_iter;
490 next_iter.set_name("while/NextIteration");
491 next_iter.set_op("NextIteration");
492 *next_iter.add_input() = noinline_node_name;
493 (*next_iter.mutable_attr())["T"].set_type(DT_INT32);
494
495 Status status;
496 Node* n = scope.graph()->AddNode(next_iter, &status);
497 TF_ASSERT_OK(status);
498
499 // Remove the dummy node and add the loop backedge.
500 scope.graph()->RemoveNode(dummy.node());
501 scope.graph()->AddEdge(n, 0, merge.output.node(), 1);
502 TF_ASSERT_OK(scope.ToGraph(&graph));
503 }
504
505 FunctionLibraryDefinition library(graph.flib_def());
506 GraphDef optimized_graph_def;
507 graph.ToGraphDef(&optimized_graph_def);
508
509 *(optimized_graph_def.mutable_library()->add_function()) =
510 GetNoinlineFunctionDef();
511
512 TF_ASSERT_OK(
513 FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
514 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
515 GraphDef converted_graph_def;
516 graph.ToGraphDef(&converted_graph_def);
517
518 for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
519 NameAttrList cond_fn, body_fn;
520 TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
521
522 // Outer graph
523 {
524 Scope scope = Scope::NewRootScope().ExitOnError();
525 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
526 auto while_op =
527 ops::While(scope.WithOpName("while/LoopCond"),
528 std::initializer_list<Input>{source}, cond_fn, body_fn);
529 GraphDef expected;
530 TF_ASSERT_OK(scope.ToGraphDef(&expected));
531 TF_EXPECT_GRAPH_EQ(expected, graph_def);
532 }
533
534 // Body graph.
535 {
536 Scope scope = Scope::NewRootScope().ExitOnError();
537 auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
538 TF_ASSERT_OK(
539 AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
540 auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
541 NodeDef retval;
542 retval.set_name("retval0_RetVal");
543 retval.set_op(FunctionLibraryDefinition::kRetOp);
544 *retval.add_input() = noinline_node_name;
545 (*retval.mutable_attr())["T"].set_type(DT_INT32);
546 (*retval.mutable_attr())["index"].set_i(0);
547 Status status;
548 scope.graph()->AddNode(retval, &status);
549 TF_ASSERT_OK(status);
550
551 GraphDef expected;
552 TF_ASSERT_OK(scope.ToGraphDef(&expected));
553
554 InstantiationResultForTest result;
555 // Verify that increment_fn has been copied to library.
556 TF_EXPECT_OK(
557 InstantiateFunctionForTest(body_fn.name(), library, &result));
558
559 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
560 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
561 // Ignore the function library when comparing the graphs.
562 expected.clear_library();
563 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
564 }
565 }
566 }
567
// The graph calls increment_fn, but the GraphDef's own function library is
// cleared before functionalization. FunctionalizeControlFlowForGraphDef is
// expected to report NOT_FOUND in that case.
TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) {
  // Held by value rather than as a reference bound to a temporary.
  const string noinline_node_name = "while/increment_fn";
  Graph graph(OpRegistry::Global());
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
    auto identity = ops::Identity(scope.WithOpName("while/Identity"), source);
    TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
    TF_ASSERT_OK(scope.ToGraph(&graph));
  }

  FunctionLibraryDefinition library(graph.flib_def());
  GraphDef graph_def;
  graph.ToGraphDef(&graph_def);
  // Drop the GraphDef's library; `library` (built from graph.flib_def()) is
  // left untouched.
  graph_def.clear_library();

  Status status = FunctionalizeControlFlowForGraphDef(&graph_def, &library);
  EXPECT_EQ(tensorflow::error::NOT_FOUND, status.code());
}
587
588 // Tests functionalizing OneLoopVar where the loop value is not used post the
589 // loop.
590 // Graph:
591 // x = array_ops.placeholder(dtypes.int32)
592 // control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
TEST(FunctionalizeControlFlow,OneLoopVarWithoutExit)593 TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) {
594 Graph graph(OpRegistry::Global());
595 {
596 Scope scope = Scope::NewRootScope().ExitOnError();
597
598 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
599
600 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
601 auto enter =
602 ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
603 auto merge = ops::Merge(scope.WithOpName("while/Merge"),
604 std::initializer_list<Input>{enter, dummy});
605 auto ten = ops::Const<int32>(
606 scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
607 10);
608 auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
609 auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
610 auto switch_ =
611 ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
612 auto identity =
613 ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
614 auto one = ops::Const<int32>(
615 scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
616 auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
617 auto next_iteration =
618 ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
619
620 // Remove the dummy node and add the loop backedge.
621 scope.graph()->RemoveNode(dummy.node());
622 scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
623
624 TF_EXPECT_OK(scope.ToGraph(&graph));
625 }
626
627 FunctionLibraryDefinition library(OpRegistry::Global(), {});
628 GraphDef optimized_graph_def;
629 graph.ToGraphDef(&optimized_graph_def);
630 TF_ASSERT_OK(
631 FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
632 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
633 GraphDef converted_graph_def;
634 graph.ToGraphDef(&converted_graph_def);
635
636 for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
637 NameAttrList cond_fn, body_fn;
638 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
639
640 // Outer graph
641 {
642 Scope scope = Scope::NewRootScope().ExitOnError();
643 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
644 auto while_op =
645 ops::While(scope.WithOpName("while/LoopCond"),
646 std::initializer_list<Input>{source}, cond_fn, body_fn);
647 GraphDef expected;
648 TF_EXPECT_OK(scope.ToGraphDef(&expected));
649 TF_EXPECT_GRAPH_EQ(expected, graph_def);
650 }
651
652 // Condition graph
653 {
654 Scope scope = Scope::NewRootScope().ExitOnError();
655 auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
656 auto ten = ops::Const<int32>(
657 scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
658 auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
659 auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
660
661 GraphDef expected;
662 TF_EXPECT_OK(scope.ToGraphDef(&expected));
663
664 InstantiationResultForTest result;
665 TF_EXPECT_OK(
666 InstantiateFunctionForTest(cond_fn.name(), library, &result));
667
668 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
669 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
670 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
671 }
672
673 // Body graph.
674 {
675 Scope scope = Scope::NewRootScope().ExitOnError();
676 auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
677 auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
678 auto one = ops::Const<int32>(
679 scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
680 auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
681 auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
682
683 GraphDef expected;
684 TF_EXPECT_OK(scope.ToGraphDef(&expected));
685
686 InstantiationResultForTest result;
687 TF_EXPECT_OK(
688 InstantiateFunctionForTest(body_fn.name(), library, &result));
689
690 EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
691 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
692 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
693 }
694 }
695 }
696
697 // Graph:
698 // x = array_ops.placeholder(dtypes.int32)
699 // y = array_ops.placeholder(dtypes.int32)
700 // cond = lambda (i, j): i + 3 < 10
701 // body = lambda (i, j): (i < 10, j * 2)
702 // z = control_flow_ops.while_loop(cond, body, [x, y])
TEST(FunctionalizeControlFlow,TwoLoopVars)703 TEST(FunctionalizeControlFlow, TwoLoopVars) {
704 Graph graph(OpRegistry::Global());
705 {
706 Scope scope = Scope::NewRootScope().ExitOnError();
707
708 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
709
710 auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
711 auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
712 auto enter_x =
713 ops::internal::Enter(scope.WithOpName("while/Enter/x"), x, "aloop");
714 auto enter_y =
715 ops::internal::Enter(scope.WithOpName("while/Enter/y"), y, "aloop");
716 auto merge_x = ops::Merge(scope.WithOpName("while/Merge/x"),
717 std::initializer_list<Input>{enter_x, dummy});
718 auto merge_y = ops::Merge(scope.WithOpName("while/Merge/y"),
719 std::initializer_list<Input>{enter_y, dummy});
720
721 // Loop condition
722 auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
723 .WithControlDependencies(merge_x.output),
724 3);
725 auto cond_add =
726 ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three);
727 auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
728 .WithControlDependencies(merge_x.output),
729 10);
730 auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
731 auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
732
733 auto switch_x = ops::Switch(scope.WithOpName("while/Switch/x"),
734 merge_x.output, loop_cond);
735 auto switch_y = ops::Switch(scope.WithOpName("while/Switch/y"),
736 merge_y.output, loop_cond);
737
738 auto exit_x = ops::internal::Exit(scope.WithOpName("while/Exit/x"),
739 switch_x.output_false);
740 auto exit_y = ops::internal::Exit(scope.WithOpName("while/Exit/y"),
741 switch_y.output_false);
742
743 auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"),
744 switch_x.output_true);
745 auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"),
746 switch_y.output_true);
747
748 auto one = ops::Const<int32>(
749 scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
750 1);
751 auto two = ops::Const<int32>(
752 scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
753 2);
754
755 auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
756 auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
757 auto next_iteration_x =
758 ops::NextIteration(scope.WithOpName("while/NextIteration/x"), add);
759 auto next_iteration_y =
760 ops::NextIteration(scope.WithOpName("while/NextIteration/y"), mul);
761
762 auto sink_x = ops::Identity(scope.WithOpName("sink_x"), exit_x);
763 auto sink_y = ops::Identity(scope.WithOpName("sink_y"), exit_y);
764
765 // Remove the dummy node and add the loop backedges.
766 scope.graph()->RemoveNode(dummy.node());
767 scope.graph()->AddEdge(next_iteration_x.node(), 0, merge_x.output.node(),
768 1);
769 scope.graph()->AddEdge(next_iteration_y.node(), 0, merge_y.output.node(),
770 1);
771
772 TF_EXPECT_OK(scope.ToGraph(&graph));
773 }
774
775 FunctionLibraryDefinition library(OpRegistry::Global(), {});
776 GraphDef optimized_graph_def;
777 graph.ToGraphDef(&optimized_graph_def);
778 TF_ASSERT_OK(
779 FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
780 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
781 GraphDef converted_graph_def;
782 graph.ToGraphDef(&converted_graph_def);
783
784 for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
785 NameAttrList cond_fn, body_fn;
786 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
787
788 // Outer graph.
789 {
790 Scope scope = Scope::NewRootScope().ExitOnError();
791 auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
792 auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
793 auto while_op =
794 ops::While(scope.WithOpName("while/LoopCond"),
795 std::initializer_list<Input>{x, y}, cond_fn, body_fn);
796 auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]);
797 auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]);
798 GraphDef expected;
799 TF_EXPECT_OK(scope.ToGraphDef(&expected));
800 TF_EXPECT_GRAPH_EQ(expected, graph_def);
801 }
802
803 // Condition graph.
804 {
805 Scope scope = Scope::NewRootScope().ExitOnError();
806 auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
807 auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
808 auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
809 .WithControlDependencies(arg0.output),
810 3);
811 auto cond_add =
812 ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three);
813 auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
814 .WithControlDependencies(arg0.output),
815 10);
816 auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
817 auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
818
819 GraphDef expected;
820 TF_EXPECT_OK(scope.ToGraphDef(&expected));
821
822 InstantiationResultForTest result;
823 TF_EXPECT_OK(
824 InstantiateFunctionForTest(cond_fn.name(), library, &result));
825
826 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
827 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
828 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
829 }
830
831 // Body graph.
832 {
833 Scope scope = Scope::NewRootScope().ExitOnError();
834 auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
835 auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
836
837 auto identity_x =
838 ops::Identity(scope.WithOpName("while/Identity/x"), arg0);
839 auto identity_y =
840 ops::Identity(scope.WithOpName("while/Identity/y"), arg1);
841
842 auto one = ops::Const<int32>(
843 scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
844 1);
845 auto two = ops::Const<int32>(
846 scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
847 2);
848
849 auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
850 auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
851 auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
852 auto retval1 = ops::_Retval(scope.WithOpName("retval1_RetVal"), mul, 1);
853
854 GraphDef expected;
855 TF_EXPECT_OK(scope.ToGraphDef(&expected));
856
857 InstantiationResultForTest result;
858 TF_EXPECT_OK(
859 InstantiateFunctionForTest(body_fn.name(), library, &result));
860
861 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
862 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types);
863 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
864 }
865 }
866 }
867
868 // More complex example with nesting, loop-invariant arguments, and resource
869 // variables. Used for multiple tests with different node filters.
870 class ComplexTestFixture
871 : public ::testing::TestWithParam<std::tuple<bool, bool, bool>> {
872 protected:
SetUp()873 void SetUp() override {
874 restrict_to_tpu_nodes_ = std::get<0>(GetParam());
875 mark_inner_loop_tpu_ = std::get<1>(GetParam());
876 mark_outer_loop_tpu_ = std::get<2>(GetParam());
877 }
878 void RunTest();
879
880 private:
881 void CheckOuterNodesFunctionalized(const GraphDef& graph_def,
882 const FunctionLibraryDefinition& library,
883 NameAttrList& inner_cond_fn,
884 NameAttrList& inner_body_fn);
885 void CheckInnerNodesFunctionalized(const GraphDef& graph_def,
886 const FunctionLibraryDefinition& library,
887 const NameAttrList& inner_cond_fn,
888 const NameAttrList& inner_body_fn);
889
890 bool restrict_to_tpu_nodes_ = false;
891 bool mark_inner_loop_tpu_ = false;
892 bool mark_outer_loop_tpu_ = false;
893 };
894
// Single parameterized entry point; the parameter combinations are supplied
// by INSTANTIATE_TEST_SUITE_P for ComplexTestFixture.
TEST_P(ComplexTestFixture, ComplexTests) {
  RunTest();
}
896
897 INSTANTIATE_TEST_SUITE_P(
898 FunctionalizeControlFlow, ComplexTestFixture,
899 ::testing::Combine(::testing::Bool(), ::testing::Bool(), ::testing::Bool()),
__anon7e28f50b0402(const ::testing::TestParamInfo<ComplexTestFixture::ParamType>& info) 900 [](const ::testing::TestParamInfo<ComplexTestFixture::ParamType>& info) {
901 bool restrict_to_tpu_nodes = std::get<0>(info.param);
902 bool mark_inner_loop_tpu = std::get<1>(info.param);
903 bool mark_outer_loop_tpu = std::get<2>(info.param);
904
905 string node_string;
906 if (mark_inner_loop_tpu && mark_outer_loop_tpu)
907 node_string = "both_loops_tpu";
908 else if (!mark_inner_loop_tpu && !mark_outer_loop_tpu)
909 node_string = "no_loop_tpu";
910 else
911 node_string = mark_inner_loop_tpu ? "inner_loop_tpu" : "outer_loop_tpu";
912
913 string name = absl::StrCat(
914 restrict_to_tpu_nodes ? "restricted_" : "unrestricted_", node_string);
915 return name;
916 });
917
RunTest()918 void ComplexTestFixture::RunTest() {
919 // Graph:
920 //
921 // accum = resource_variable_ops.ResourceVariable(1)
922 // x = array_ops.placeholder(2, dtype=dtypes.int32)
923 // y = 3 + x
924 //
925 // def inner_body(j, k):
926 // add = state_ops.assign_add(accum, k * j + x)
927 // with ops.control_dependencies([add]):
928 // return [j + 1, k]
929 //
930 // def body(i):
931 // m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body,
932 // [1, y], name="inner")
933 // with ops.control_dependencies(m):
934 // return [i + 1]
935 //
936 // z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer")
937 Graph graph(OpRegistry::Global());
938 {
939 Scope scope = Scope::NewRootScope().ExitOnError();
940
941 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
942
943 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
944 auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
945 auto y = ops::Add(scope.WithOpName("y"), x, three);
946
947 auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
948 TensorShape({}));
949
950 // Outer loop
951 auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
952 auto enter_i =
953 ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer");
954 auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"),
955 std::initializer_list<Input>{enter_i, dummy});
956 auto ten = ops::Const<int32>(scope.WithOpName("outer/Less/y")
957 .WithControlDependencies(merge_i.output),
958 10);
959 auto less_i =
960 ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten);
961 auto outer_loop_cond =
962 ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_i);
963 auto switch_i = ops::Switch(scope.WithOpName("outer/Switch"),
964 merge_i.output, outer_loop_cond);
965 auto exit_i = ops::internal::Exit(scope.WithOpName("outer/Exit"),
966 switch_i.output_false);
967 auto identity_i =
968 ops::Identity(scope.WithOpName("outer/Identity"), switch_i.output_true);
969
970 auto enter_x_outer =
971 ops::internal::Enter(scope.WithOpName("outer/Enter_x"), x, "outer",
972 ops::internal::Enter::Attrs().IsConstant(true));
973 auto enter_k_outer =
974 ops::internal::Enter(scope.WithOpName("outer/Enter_k"), y, "outer",
975 ops::internal::Enter::Attrs().IsConstant(true));
976 auto enter_var_outer =
977 ops::internal::Enter(scope.WithOpName("outer/Enter_var"), var, "outer",
978 ops::internal::Enter::Attrs().IsConstant(true));
979
980 // Inner loop
981 auto one_j = ops::Const<int32>(
982 scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
983 auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"),
984 one_j, "inner");
985 auto enter_k =
986 ops::internal::Enter(scope.WithOpName("outer/inner/Enter_k")
987 .WithControlDependencies(identity_i),
988 enter_k_outer, "inner");
989 auto enter_x = ops::internal::Enter(
990 scope.WithOpName("outer/inner/Enter_x"), enter_x_outer, "inner",
991 ops::internal::Enter::Attrs().IsConstant(true));
992 auto enter_var = ops::internal::Enter(
993 scope.WithOpName("outer/inner/Enter_var"), enter_var_outer, "inner",
994 ops::internal::Enter::Attrs().IsConstant(true));
995
996 auto merge_j = ops::Merge(scope.WithOpName("outer/inner/Merge_j"),
997 std::initializer_list<Input>{enter_j, dummy});
998 auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"),
999 std::initializer_list<Input>{enter_k, dummy});
1000
1001 auto five = ops::Const<int32>(scope.WithOpName("outer/inner/Five")
1002 .WithControlDependencies(merge_j.output),
1003 5);
1004 auto less_j =
1005 ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five);
1006 auto loop_cond =
1007 ops::LoopCond(scope.WithOpName("outer/inner/LoopCond"), less_j);
1008
1009 auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"),
1010 merge_j.output, loop_cond);
1011 auto switch_k = ops::Switch(scope.WithOpName("outer/inner/Switch_k"),
1012 merge_k.output, loop_cond);
1013 auto exit_j = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_j"),
1014 switch_j.output_false);
1015 auto exit_k = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_k"),
1016 switch_k.output_false);
1017 auto identity_j = ops::Identity(scope.WithOpName("outer/inner/Identity_j"),
1018 switch_j.output_true);
1019 auto identity_k = ops::Identity(scope.WithOpName("outer/inner/Identity_k"),
1020 switch_k.output_true);
1021
1022 // Variable update
1023 auto mul_jk =
1024 ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
1025 auto add_jkx =
1026 ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, enter_x);
1027 auto assign = ops::AssignAddVariableOp(
1028 scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx);
1029
1030 auto one = ops::Const<int32>(
1031 scope.WithOpName("outer/inner/One")
1032 .WithControlDependencies(
1033 absl::Span<const Operation>{assign.operation}),
1034 1);
1035 auto add_j =
1036 ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
1037
1038 auto next_iteration_j = ops::NextIteration(
1039 scope.WithOpName("outer/inner/NextIteration_j"), add_j);
1040 auto next_iteration_k = ops::NextIteration(
1041 scope.WithOpName("outer/inner/NextIteration_k"), identity_k);
1042
1043 // Body and backedge for outer loop.
1044 auto one_outer = ops::Const<int32>(
1045 scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
1046 auto add_i =
1047 ops::Add(scope.WithOpName("outer/add")
1048 .WithControlDependencies(absl::Span<const Operation>{
1049 exit_j.output.op(), exit_k.output.op()}),
1050 identity_i, one_outer);
1051 auto next_iteration_i =
1052 ops::NextIteration(scope.WithOpName("outer/NextIteration"), add_i);
1053
1054 auto sink = ops::Identity(scope.WithOpName("sink"), exit_i);
1055
1056 // Remove the dummy node and add the loop backedge.
1057 scope.graph()->RemoveNode(dummy.node());
1058 scope.graph()->AddEdge(next_iteration_i.node(), 0, merge_i.output.node(),
1059 1);
1060 scope.graph()->AddEdge(next_iteration_j.node(), 0, merge_j.output.node(),
1061 1);
1062 scope.graph()->AddEdge(next_iteration_k.node(), 0, merge_k.output.node(),
1063 1);
1064
1065 TF_EXPECT_OK(scope.ToGraph(&graph));
1066 }
1067 // Add '_tpu_replicate' attributes as specified.
1068 for (Node* n : graph.nodes()) {
1069 string name = n->name();
1070 bool is_inner_node = name.find("outer/inner/") != string::npos;
1071 bool is_outer_node = !is_inner_node && name.find("outer/") != string::npos;
1072 if ((is_inner_node && mark_inner_loop_tpu_) ||
1073 (is_outer_node && mark_outer_loop_tpu_)) {
1074 n->AddAttr("_tpu_replicate", "cluster");
1075 }
1076 }
1077
1078 FunctionLibraryDefinition library(OpRegistry::Global(), {});
1079 GraphDef orig_graph_def, optimized_graph_def;
1080 graph.ToGraphDef(&orig_graph_def);
1081 optimized_graph_def = orig_graph_def;
1082 // If `restrict_to_tpu_nodes_` is true let filter function return true for
1083 // `_tpu_replicate` nodes, otherwise don't set filter.
1084 NodeFilter node_filter =
1085 restrict_to_tpu_nodes_
1086 ? [](const Node* n) { return n->attrs().Find("_tpu_replicate"); }
1087 : NodeFilter{};
1088
1089 Status status1 = FunctionalizeControlFlowForGraphDef(&optimized_graph_def,
1090 &library, node_filter);
1091 Status status2 = FunctionalizeControlFlow(&graph, &library, node_filter);
1092 ASSERT_EQ(status1, status2);
1093 if (restrict_to_tpu_nodes_ && mark_outer_loop_tpu_ && !mark_inner_loop_tpu_) {
1094 // This case violates the precondition of `FunctionalizeControlFlow`, we
1095 // expect an internal error.
1096 ASSERT_EQ(errors::IsInternal(status1), true);
1097 return;
1098 } else {
1099 // Supported cases, no error expected.
1100 TF_ASSERT_OK(status1);
1101 }
1102
1103 GraphDef optimized_converted_graph_def;
1104 graph.ToGraphDef(&optimized_converted_graph_def);
1105 for (const GraphDef& graph_def :
1106 {optimized_graph_def, optimized_converted_graph_def}) {
1107 NameAttrList inner_cond_fn, inner_body_fn;
1108 if (!restrict_to_tpu_nodes_ ||
1109 (restrict_to_tpu_nodes_ && mark_outer_loop_tpu_ &&
1110 mark_inner_loop_tpu_)) {
1111 // We expect that both inner and outer nodes have been functionalized.
1112 CheckOuterNodesFunctionalized(graph_def, library, inner_cond_fn,
1113 inner_body_fn);
1114 CheckInnerNodesFunctionalized(graph_def, library, inner_cond_fn,
1115 inner_body_fn);
1116 } else /*restrict_to_tpu_nodes_ == true*/ {
1117 if (!mark_outer_loop_tpu_ && !mark_inner_loop_tpu_) {
1118 // Graph has no TPU nodes so we expect no functionalization.
1119 TF_EXPECT_GRAPH_EQ(orig_graph_def, graph_def);
1120 } else if (!mark_outer_loop_tpu_ && mark_inner_loop_tpu_) {
1121 // We expect that only inner nodes have been functionalized.
1122 TF_EXPECT_OK(
1123 FindWhileCondAndBody(graph_def, &inner_cond_fn, &inner_body_fn));
1124 CheckInnerNodesFunctionalized(graph_def, library, inner_cond_fn,
1125 inner_body_fn);
1126 }
1127 }
1128 }
1129 }
1130
CheckOuterNodesFunctionalized(const GraphDef & graph_def,const FunctionLibraryDefinition & library,NameAttrList & inner_cond_fn,NameAttrList & inner_body_fn)1131 void ComplexTestFixture::CheckOuterNodesFunctionalized(
1132 const GraphDef& graph_def, const FunctionLibraryDefinition& library,
1133 NameAttrList& inner_cond_fn, NameAttrList& inner_body_fn) {
1134 NameAttrList outer_cond_fn, outer_body_fn;
1135 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn));
1136
1137 // Outer graph.
1138 {
1139 Scope scope = Scope::NewRootScope().ExitOnError();
1140 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
1141 auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
1142 auto y = ops::Add(scope.WithOpName("y"), x, three);
1143
1144 auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
1145 TensorShape({}));
1146
1147 auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
1148
1149 auto while_op = ops::While(scope.WithOpName("outer/LoopCond"),
1150 std::initializer_list<Input>{zero, y, x, var},
1151 outer_cond_fn, outer_body_fn);
1152 auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
1153 GraphDef expected;
1154 TF_EXPECT_OK(scope.ToGraphDef(&expected));
1155 TF_EXPECT_GRAPH_EQ(expected, graph_def);
1156 }
1157
1158 // Outer condition graph.
1159 {
1160 Scope scope = Scope::NewRootScope().ExitOnError();
1161 auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1162 auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
1163 auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
1164 auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
1165
1166 auto ten = ops::Const<int32>(
1167 scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output),
1168 10);
1169 auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten);
1170 auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
1171
1172 GraphDef expected;
1173 TF_EXPECT_OK(scope.ToGraphDef(&expected));
1174
1175 InstantiationResultForTest result;
1176 TF_EXPECT_OK(
1177 InstantiateFunctionForTest(outer_cond_fn.name(), library, &result));
1178
1179 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1180 result.arg_types);
1181 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
1182 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
1183 }
1184
1185 // Outer body graph.
1186 {
1187 InstantiationResultForTest result;
1188 TF_EXPECT_OK(
1189 InstantiateFunctionForTest(outer_body_fn.name(), library, &result));
1190
1191 // Find the inner condition and body names.
1192 TF_EXPECT_OK(
1193 FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn));
1194
1195 Scope scope = Scope::NewRootScope().ExitOnError();
1196 auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1197 auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
1198 auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
1199 auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
1200
1201 auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0);
1202 auto one_j = ops::Const<int32>(
1203 scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
1204 auto while_op =
1205 ops::While(scope.WithOpName("outer/inner/LoopCond"),
1206 std::initializer_list<Input>{one_j, arg1, arg2, arg3},
1207 inner_cond_fn, inner_body_fn);
1208
1209 auto one_outer = ops::Const<int32>(
1210 scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
1211 auto add_i =
1212 ops::Add(scope.WithOpName("outer/add")
1213 .WithControlDependencies(absl::Span<const Operation>{
1214 while_op[0].op(), while_op[1].op()}),
1215 identity_i, one_outer);
1216
1217 auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add_i, 0);
1218 auto retval1 = ops::_Retval(scope.WithOpName("retval1_RetVal"), arg1, 1);
1219 auto retval2 = ops::_Retval(scope.WithOpName("retval2_RetVal"), arg2, 2);
1220 auto retval3 = ops::_Retval(scope.WithOpName("retval3_RetVal"), arg3, 3);
1221
1222 GraphDef expected;
1223 TF_EXPECT_OK(scope.ToGraphDef(&expected));
1224
1225 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1226 result.arg_types);
1227 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1228 result.ret_types);
1229 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
1230 }
1231 }
1232
CheckInnerNodesFunctionalized(const GraphDef & graph_def,const FunctionLibraryDefinition & library,const NameAttrList & inner_cond_fn,const NameAttrList & inner_body_fn)1233 void ComplexTestFixture::CheckInnerNodesFunctionalized(
1234 const GraphDef& graph_def, const FunctionLibraryDefinition& library,
1235 const NameAttrList& inner_cond_fn, const NameAttrList& inner_body_fn) {
1236 // Inner condition graph.
1237 {
1238 Scope scope = Scope::NewRootScope().ExitOnError();
1239 auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1240 auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
1241 auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
1242 auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
1243
1244 auto five = ops::Const<int32>(
1245 scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5);
1246 auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five);
1247 auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less_j, 0);
1248
1249 GraphDef expected;
1250 TF_EXPECT_OK(scope.ToGraphDef(&expected));
1251
1252 InstantiationResultForTest result;
1253 TF_EXPECT_OK(
1254 InstantiateFunctionForTest(inner_cond_fn.name(), library, &result));
1255
1256 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1257 result.arg_types);
1258 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
1259 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
1260 }
1261
1262 // Inner body graph.
1263 {
1264 Scope scope = Scope::NewRootScope().ExitOnError();
1265 auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1266 auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
1267 auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
1268 auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
1269
1270 auto identity_j =
1271 ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0);
1272 auto identity_k =
1273 ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1);
1274
1275 auto mul_jk =
1276 ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
1277 auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2);
1278 auto assign = ops::AssignAddVariableOp(
1279 scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);
1280
1281 auto one = ops::Const<int32>(
1282 scope.WithOpName("outer/inner/One")
1283 .WithControlDependencies(
1284 absl::Span<const Operation>{assign.operation}),
1285 1);
1286 auto add_j =
1287 ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
1288
1289 auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add_j, 0);
1290 auto retval1 =
1291 ops::_Retval(scope.WithOpName("retval1_RetVal"), identity_k, 1);
1292 auto retval2 = ops::_Retval(scope.WithOpName("retval2_RetVal"), arg2, 2);
1293 auto retval3 = ops::_Retval(scope.WithOpName("retval3_RetVal"), arg3, 3);
1294
1295 GraphDef expected;
1296 TF_EXPECT_OK(scope.ToGraphDef(&expected));
1297
1298 InstantiationResultForTest result;
1299 TF_EXPECT_OK(
1300 InstantiateFunctionForTest(inner_body_fn.name(), library, &result));
1301
1302 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1303 result.arg_types);
1304 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1305 result.ret_types);
1306 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
1307 }
1308 }
1309
1310 } // namespace
1311 } // namespace tensorflow
1312