1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
17
18 #include <string>
19
20 #include "tensorflow/cc/framework/ops.h"
21 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
22 #include "tensorflow/cc/ops/function_ops.h"
23 #include "tensorflow/cc/ops/functional_ops.h"
24 #include "tensorflow/cc/ops/resource_variable_ops.h"
25 #include "tensorflow/cc/ops/standard_ops.h"
26 #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
27 #include "tensorflow/compiler/tf2xla/test_util.h"
28 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/core/common_runtime/function.h"
31 #include "tensorflow/core/common_runtime/graph_constructor.h"
32 #include "tensorflow/core/framework/function.h"
33 #include "tensorflow/core/framework/graph_to_functiondef.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/op.h"
36 #include "tensorflow/core/graph/graph_def_builder.h"
37 #include "tensorflow/core/graph/validate.h"
38 #include "tensorflow/core/lib/core/status_test_util.h"
39 #include "tensorflow/core/platform/test.h"
40 #include "tensorflow/core/public/version.h"
41 #include "tensorflow/core/util/dump_graph.h"
42 #include "tensorflow/core/util/equal_graph_def.h"
43
44 namespace tensorflow {
45 namespace {
46
47 // Returns the names of the "then" and "else" functions for the If node in a
48 // graph.
FindIfThenAndElse(const GraphDef & graph,string * op_name,NameAttrList * then_fn,NameAttrList * else_fn)49 Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
50 NameAttrList* then_fn, NameAttrList* else_fn) {
51 for (const NodeDef& node : graph.node()) {
52 if (node.op() == "If") {
53 *op_name = node.name();
54 const NameAttrList* result;
55 TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result));
56 *then_fn = *result;
57 TF_RETURN_IF_ERROR(GetNodeAttr(node, "else_branch", &result));
58 *else_fn = *result;
59 return OkStatus();
60 }
61 }
62 return errors::NotFound("No If node found in graph");
63 }
64
65 // Graph:
66 // x = array_ops.placeholder(dtypes.int32)
67 // y = array_ops.placeholder(dtypes.int32)
68 // z = control_flow_ops.cond(
69 // math_ops.less(y, x), lambda: math_ops.multiply(y, 17),
70 // lambda: math_ops.add(x, 23))
71 //
72 // Tests different node filters and functionalization inside of a function.
73 class ConditionalTestFixture
74 : public ::testing::TestWithParam<std::tuple<bool, bool>> {
75 protected:
SetUp()76 void SetUp() override {
77 restrict_to_tpu_nodes_ = std::get<0>(GetParam());
78 wrap_condition_in_function_ = std::get<1>(GetParam());
79 }
80 void RunTest();
81
82 private:
83 void BuildCondGraph(Graph* cond_graph);
84 void CheckGraphDef(const GraphDef& graph_def,
85 const FunctionLibraryDefinition& library);
86
87 bool restrict_to_tpu_nodes_ = false;
88 bool wrap_condition_in_function_ = false;
89 };
90
// Runs the full build / functionalize / verify cycle for each parameter pair.
TEST_P(ConditionalTestFixture, ConditionalTests) {
  RunTest();
}
92
93 INSTANTIATE_TEST_SUITE_P(
94 FunctionalizeControlFlow, ConditionalTestFixture,
95 ::testing::Combine(::testing::Bool(), ::testing::Bool()),
96 [](const ::testing::TestParamInfo<ConditionalTestFixture::ParamType>&
__anon8d20b2690202(const ::testing::TestParamInfo<ConditionalTestFixture::ParamType>& info) 97 info) {
98 bool restrict_to_tpu_nodes = std::get<0>(info.param);
99 bool wrap_cond_in_function = std::get<1>(info.param);
100 string name =
101 absl::StrCat(restrict_to_tpu_nodes ? "with_filter" : "without_filter",
102 wrap_cond_in_function ? "_in_function" : "_in_graph");
103 return name;
104 });
105
// Builds the Switch/Merge form of the conditional
//   z = (y < x) ? y * 17 : x + 23
// into `cond_graph`, then tags every node with each attribute listed in
// `kAttrsToPropagate` (value "value"). CheckGraphDef later uses those tags to
// verify attribute propagation onto the generated If node.
//
// NOTE: Scope uniquifies repeated names in creation order (the second
// "cond/Identity" becomes "cond/Identity_1", the second "cond" becomes
// "cond_1"). CheckGraphDef's expected graphs rely on those names, so the
// statement order here matters.
void ConditionalTestFixture::BuildCondGraph(Graph* cond_graph) {
  {
    Scope scope = Scope::NewRootScope().ExitOnError();

    auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
    auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
    // Predicate of the conditional: y < x.
    auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
    // Switch the predicate on itself; its true/false outputs gate each
    // branch's constant through control dependencies below.
    auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), less, less);

    // True branch: y * 17.
    auto identity_t =
        ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_true);
    auto seventeen = ops::Const<int32>(
        scope.WithOpName("cond").WithControlDependencies(identity_t), 17);
    auto switch_2 = ops::Switch(scope.WithOpName("cond/Switch"), y, less);
    auto mul = ops::Multiply(scope.WithOpName("cond/Mul"), switch_2.output_true,
                             seventeen);

    // False branch: x + 23.
    auto identity_f =
        ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_false);
    auto twenty_three = ops::Const<int32>(
        scope.WithOpName("cond").WithControlDependencies(identity_f), 23);
    auto switch_3 = ops::Switch(scope.WithOpName("cond/Switch"), x, less);
    auto add = ops::Add(scope.WithOpName("cond/false/add"),
                        switch_3.output_false, twenty_three);

    // Join the two branch results.
    auto merge = ops::Merge(scope.WithOpName("cond/Merge"),
                            std::initializer_list<Input>{add, mul});

    TF_EXPECT_OK(scope.ToGraph(cond_graph));

    // Set all attributes that need propagation for all nodes. This is to test
    // if propagation works. Note that this includes `_tpu_replicate`.
    for (Node* n : cond_graph->nodes()) {
      std::string dummy_value = "value";
      for (absl::string_view attr_name : kAttrsToPropagate) {
        n->AddAttr(std::string(attr_name), dummy_value);
      }
    }
  }
}
146
CheckGraphDef(const GraphDef & graph_def,const FunctionLibraryDefinition & library)147 void ConditionalTestFixture::CheckGraphDef(
148 const GraphDef& graph_def, const FunctionLibraryDefinition& library) {
149 string op_name;
150 NameAttrList then_fn;
151 NameAttrList else_fn;
152 TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn));
153 InstantiationResultForTest else_result;
154 TF_EXPECT_OK(
155 InstantiateFunctionForTest(else_fn.name(), library, &else_result));
156
157 // Outer graph
158 {
159 Scope scope = Scope::NewRootScope().ExitOnError();
160 auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
161 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
162 auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
163 auto if_op =
164 ops::If(scope.WithOpName(op_name), less,
165 std::initializer_list<Input>{less, y, x}, {DT_INT32}, then_fn,
166 else_fn, ops::If::OutputShapes({PartialTensorShape()}));
167 auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]);
168 GraphDef expected;
169 TF_EXPECT_OK(scope.ToGraphDef(&expected));
170 TF_EXPECT_GRAPH_EQ(expected, graph_def);
171 }
172
173 // then body.
174 {
175 Scope scope = Scope::NewRootScope().ExitOnError();
176 auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0);
177 auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
178 auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
179 auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0);
180 auto cond = ops::Const(
181 scope.WithOpName("cond").WithControlDependencies(identity), 17);
182 auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond);
183 auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), mul, 0);
184
185 GraphDef expected;
186 TF_EXPECT_OK(scope.ToGraphDef(&expected));
187
188 InstantiationResultForTest result;
189 TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result));
190
191 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
192 EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
193 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
194 }
195
196 // else body.
197 {
198 Scope scope = Scope::NewRootScope().ExitOnError();
199 auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0);
200 auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
201 auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
202 auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0);
203 auto cond_1 = ops::Const(
204 scope.WithOpName("cond_1").WithControlDependencies(identity), 23);
205 auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1);
206 auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
207
208 GraphDef expected;
209 TF_EXPECT_OK(scope.ToGraphDef(&expected));
210
211 InstantiationResultForTest result;
212 TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result));
213
214 EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
215 EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types);
216 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
217
218 // Check that internal attributes were correctly propagated to `If` node
219 // (such attributes are ignored in the above graph equality check).
220 for (const NodeDef& node : graph_def.node()) {
221 if (node.op() == "If") {
222 for (absl::string_view attr_name : kAttrsToPropagate) {
223 std::string attr_val;
224 TF_EXPECT_OK(GetNodeAttr(node, attr_name, &attr_val));
225 EXPECT_EQ(attr_val, "value");
226 }
227 }
228 }
229 }
230 }
231
RunTest()232 void ConditionalTestFixture::RunTest() {
233 Graph graph(OpRegistry::Global());
234 if (wrap_condition_in_function_) {
235 // Wrap condition in a function which is called from `graph`.
236 Scope scope = Scope::NewRootScope().ExitOnError();
237 auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
238
239 Graph cond_graph(OpRegistry::Global());
240 BuildCondGraph(&cond_graph);
241
242 FunctionDef cond_fdef;
243 TF_ASSERT_OK(GraphToFunctionDef(cond_graph, "cond_fn", &cond_fdef));
244
245 FunctionDefLibrary fdef_lib;
246 *(fdef_lib.add_function()) = cond_fdef;
247 TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
248 NodeDef cond_fn;
249 cond_fn.set_name("cond_node");
250 cond_fn.set_op("cond_fn");
251 *(cond_fn.add_input()) = "source";
252 Status status;
253 scope.graph()->AddNode(cond_fn, &status);
254 TF_ASSERT_OK(status);
255 TF_ASSERT_OK(scope.ToGraph(&graph));
256 } else {
257 // Build condition in `graph`.
258 BuildCondGraph(&graph);
259 }
260 FunctionLibraryDefinition library(graph.flib_def());
261 // If `restrict_to_tpu_nodes_` is true let filter function return true for
262 // `_tpu_replicate` nodes.
263 NodeFilter node_filter =
264 restrict_to_tpu_nodes_
265 ? [](const Node* n) { return n->attrs().Find("_tpu_replicate"); }
266 : NodeFilter{};
267
268 GraphDef optimized_graph_def;
269 graph.ToGraphDef(&optimized_graph_def);
270 TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef(
271 &optimized_graph_def, &library, node_filter,
272 /*include_functions=*/wrap_condition_in_function_));
273 TF_ASSERT_OK(FunctionalizeControlFlow(
274 &graph, &library, node_filter,
275 /*include_functions=*/wrap_condition_in_function_));
276
277 if (wrap_condition_in_function_) {
278 // Check if function body was functionalized.
279 auto pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
280 /*device_mgr=*/nullptr, tensorflow::Env::Default(),
281 /*config=*/nullptr, TF_GRAPH_DEF_VERSION, &library,
282 tensorflow::OptimizerOptions());
283 FunctionLibraryRuntime* flr =
284 pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
285 FunctionLibraryRuntime::Handle handle;
286
287 // Functionalized function name is the type string of `cond_node`.
288 string func_name;
289 for (Node* n : graph.nodes()) {
290 if (n->name() == "cond_node") {
291 func_name = n->type_string();
292 break;
293 }
294 }
295 TF_ASSERT_OK(flr->Instantiate(func_name, AttrSlice(), &handle));
296 const FunctionBody* body = flr->GetFunctionBody(handle);
297 GraphDef graph_def;
298 body->graph->ToGraphDef(&graph_def);
299 CheckGraphDef(graph_def, library);
300 } else {
301 // Check if graphs were functionalized.
302 CheckGraphDef(optimized_graph_def, library);
303 GraphDef converted_graph_def;
304 graph.ToGraphDef(&converted_graph_def);
305 CheckGraphDef(converted_graph_def, library);
306 }
307 }
308
309 // Returns the names of the "cond" and "body" functions for the While node
310 // in a graph.
FindWhileCondAndBody(const GraphDef & graph,NameAttrList * cond,NameAttrList * body)311 Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
312 NameAttrList* body) {
313 for (const NodeDef& node : graph.node()) {
314 if (node.op() == "While") {
315 const NameAttrList* result;
316 TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result));
317 *cond = *result;
318 TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result));
319 *body = *result;
320 return OkStatus();
321 }
322 }
323 return errors::NotFound("No While node found in graph");
324 }
325
326 // Graph:
327 // x = array_ops.placeholder(dtypes.int32)
328 // y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
// Builds a v1 while loop (Enter/Merge/LoopCond/Switch/Exit/NextIteration) with
// a single int32 loop variable. It is functionalized both as a GraphDef
// (FunctionalizeControlFlowForGraphDef) and as a Graph
// (FunctionalizeControlFlow). Each result must be one While node whose cond
// function computes `arg0 < 10` and whose body computes `arg0 + 1`.
TEST(FunctionalizeControlFlow, OneLoopVar) {
  Graph graph(OpRegistry::Global());
  {
    Scope scope = Scope::NewRootScope().ExitOnError();

    // Temporary second Merge input, standing in for the back edge until the
    // NextIteration node exists; removed below.
    auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);

    auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
    auto enter =
        ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
    // Add an unused Enter node. These should be ignored.
    auto enter2 =
        ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop");
    auto merge = ops::Merge(scope.WithOpName("while/Merge"),
                            std::initializer_list<Input>{enter, dummy});
    // Loop condition: merge < 10.
    auto ten = ops::Const<int32>(
        scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
        10);
    auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
    auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
    auto switch_ =
        ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
    auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
                                    switch_.output_false);
    // Loop body: identity + 1, fed back through NextIteration.
    auto identity =
        ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
    auto one = ops::Const<int32>(
        scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
    auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
    auto next_iteration =
        ops::NextIteration(scope.WithOpName("while/NextIteration"), add);

    auto sink = ops::Identity(scope.WithOpName("sink"), exit);

    // Remove the dummy node and add the loop backedge.
    scope.graph()->RemoveNode(dummy.node());
    scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);

    TF_EXPECT_OK(scope.ToGraph(&graph));
  }

  // Regression test: control edges from an Enter node to the graph sink should
  // be ignored.
  for (Node* n : graph.nodes()) {
    if (n->name() == "while/Enter") {
      graph.AddControlEdge(n, graph.sink_node());
    }
  }

  FunctionLibraryDefinition library(OpRegistry::Global(), {});
  GraphDef optimized_graph_def;
  graph.ToGraphDef(&optimized_graph_def);
  TF_ASSERT_OK(
      FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
  TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
  GraphDef converted_graph_def;
  graph.ToGraphDef(&converted_graph_def);

  // Both functionalization paths must produce the same structure.
  for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
    NameAttrList cond_fn, body_fn;
    TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));

    // Outer graph: the loop collapses into one While node that takes the
    // LoopCond's name.
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
      auto while_op =
          ops::While(scope.WithOpName("while/LoopCond"),
                     std::initializer_list<Input>{source}, cond_fn, body_fn);
      auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
      GraphDef expected;
      TF_EXPECT_OK(scope.ToGraphDef(&expected));
      TF_EXPECT_GRAPH_EQ(expected, graph_def);
    }

    // Condition graph
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
      auto ten = ops::Const<int32>(
          scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
      auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
      auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);

      GraphDef expected;
      TF_EXPECT_OK(scope.ToGraphDef(&expected));

      InstantiationResultForTest result;
      TF_EXPECT_OK(
          InstantiateFunctionForTest(cond_fn.name(), library, &result));

      EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
      EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
      TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    }

    // Body graph.
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
      auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
      auto one = ops::Const<int32>(
          scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
      auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
      auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);

      GraphDef expected;
      TF_EXPECT_OK(scope.ToGraphDef(&expected));

      InstantiationResultForTest result;
      TF_EXPECT_OK(
          InstantiateFunctionForTest(body_fn.name(), library, &result));

      EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
      EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
      TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    }
  }
}
448
GetNoinlineFunctionDef()449 FunctionDef GetNoinlineFunctionDef() {
450 FunctionDef fdef = FunctionDefHelper::Create(
451 "increment_fn", {"x:int32"}, {"add:int32"}, {},
452 {
453 {{"add/y"}, "Const", {}, {{"dtype", DT_INT32}}},
454 {{"add_0"}, "Add", {"x", "add/y:output:0"}, {{"T", DT_INT32}}},
455 },
456 {{"add", "add_0:z:0"}});
457 (*fdef.mutable_attr())["_noinline"].set_b(true);
458 return fdef;
459 }
460
461 // @function.Defun(noinline=True)
462 // def increment_fn(x):
463 // return [x + 1]
464 // Define the above function, and add it to the given graph. It's used as the
465 // while loop body in NoinlineLoopBody test.
AddNoinlineFunctionToGraph(const string & node_name,Graph * graph)466 Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) {
467 FunctionDefLibrary fdef_lib;
468 *(fdef_lib.add_function()) = GetNoinlineFunctionDef();
469 TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib));
470 NodeDef increment_fn;
471 increment_fn.set_name(node_name);
472 increment_fn.set_op("increment_fn");
473 *increment_fn.add_input() = "while/Identity";
474 *increment_fn.add_input() = "^while/Identity";
475 Status status;
476 graph->AddNode(increment_fn, &status);
477 return status;
478 }
479
480 // Graph:
481 // x = array_ops.placeholder(dtypes.int32)
482 // y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x])
// Like OneLoopVar, except the loop body is a call node ("while/increment_fn")
// to the `_noinline` function `increment_fn`. Checks that the generated While
// body function still contains that call node, rather than its expansion, and
// that the body function can be instantiated from `library`.
TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
  const string& noinline_node_name = "while/increment_fn";
  Graph graph(OpRegistry::Global());
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    // Temporary second Merge input, replaced by the back edge below.
    auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
    auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
    auto enter = ops::internal::Enter(scope.WithOpName("while/Enter"), source,
                                      "while/while_context");
    auto merge = ops::Merge(scope.WithOpName("while/Merge"),
                            std::initializer_list<Input>{enter, dummy});
    auto ten = ops::Const<int32>(
        scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
        10);
    auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
    auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
    auto switch_ =
        ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
    auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
                                    switch_.output_false);
    auto identity =
        ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);

    // The call node reads "while/Identity" by name, so it must be added after
    // `identity` exists.
    TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));

    NodeDef next_iter;
    next_iter.set_name("while/NextIteration");
    next_iter.set_op("NextIteration");
    *next_iter.add_input() = noinline_node_name;
    (*next_iter.mutable_attr())["T"].set_type(DT_INT32);

    Status status;
    Node* n = scope.graph()->AddNode(next_iter, &status);
    TF_ASSERT_OK(status);

    // Remove the dummy node and add the loop backedge.
    scope.graph()->RemoveNode(dummy.node());
    scope.graph()->AddEdge(n, 0, merge.output.node(), 1);
    TF_ASSERT_OK(scope.ToGraph(&graph));
  }

  FunctionLibraryDefinition library(graph.flib_def());
  GraphDef optimized_graph_def;
  graph.ToGraphDef(&optimized_graph_def);

  // Also append increment_fn to the GraphDef's own library.
  // NOTE(review): graph.ToGraphDef may already have serialized it from
  // graph.flib_def(); confirm whether this extra copy is required.
  *(optimized_graph_def.mutable_library()->add_function()) =
      GetNoinlineFunctionDef();

  TF_ASSERT_OK(
      FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
  TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
  GraphDef converted_graph_def;
  graph.ToGraphDef(&converted_graph_def);

  for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
    NameAttrList cond_fn, body_fn;
    TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));

    // Outer graph
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
      auto while_op =
          ops::While(scope.WithOpName("while/LoopCond"),
                     std::initializer_list<Input>{source}, cond_fn, body_fn);
      GraphDef expected;
      TF_ASSERT_OK(scope.ToGraphDef(&expected));
      TF_EXPECT_GRAPH_EQ(expected, graph_def);
    }

    // Body graph: arg0 -> while/Identity -> increment_fn call -> retval.
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
      TF_ASSERT_OK(
          AddNoinlineFunctionToGraph(noinline_node_name, scope.graph()));
      auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
      NodeDef retval;
      retval.set_name("retval0_RetVal");
      retval.set_op(FunctionLibraryDefinition::kRetOp);
      *retval.add_input() = noinline_node_name;
      (*retval.mutable_attr())["T"].set_type(DT_INT32);
      (*retval.mutable_attr())["index"].set_i(0);
      Status status;
      scope.graph()->AddNode(retval, &status);
      TF_ASSERT_OK(status);

      GraphDef expected;
      TF_ASSERT_OK(scope.ToGraphDef(&expected));

      InstantiationResultForTest result;
      // Verify that increment_fn has been copied to library.
      TF_EXPECT_OK(
          InstantiateFunctionForTest(body_fn.name(), library, &result));

      EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
      EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
      // Ignore the function library when comparing the graphs.
      expected.clear_library();
      TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    }
  }
}
586
// A graph calls `increment_fn`, and the GraphDef handed to
// FunctionalizeControlFlowForGraphDef has had its own function library
// cleared. Functionalization is expected to fail with NOT_FOUND.
TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) {
  const string kCallNodeName = "while/increment_fn";
  Graph graph(OpRegistry::Global());
  {
    Scope root = Scope::NewRootScope().ExitOnError();
    auto source = ops::Placeholder(root.WithOpName("source"), DT_INT32);
    // The increment_fn call node reads "while/Identity" by name.
    auto identity = ops::Identity(root.WithOpName("while/Identity"), source);
    TF_ASSERT_OK(AddNoinlineFunctionToGraph(kCallNodeName, root.graph()));
    TF_ASSERT_OK(root.ToGraph(&graph));
  }

  FunctionLibraryDefinition library(graph.flib_def());
  GraphDef stripped_def;
  graph.ToGraphDef(&stripped_def);
  // Drop the GraphDef's serialized function library.
  stripped_def.clear_library();

  const Status result =
      FunctionalizeControlFlowForGraphDef(&stripped_def, &library);
  EXPECT_EQ(tensorflow::error::NOT_FOUND, result.code());
}
606
607 // Tests functionalizing OneLoopVar where the loop value is not used post the
608 // loop.
609 // Graph:
610 // x = array_ops.placeholder(dtypes.int32)
611 // control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
// Same loop as OneLoopVar, but with no Exit node (the loop result is never
// consumed). Functionalization must still produce a single While node with
// cond `arg0 < 10` and body `arg0 + 1`. The outer graph has nothing
// downstream of the While node.
TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) {
  Graph graph(OpRegistry::Global());
  {
    Scope scope = Scope::NewRootScope().ExitOnError();

    // Temporary second Merge input, replaced by the back edge below.
    auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);

    auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
    auto enter =
        ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
    auto merge = ops::Merge(scope.WithOpName("while/Merge"),
                            std::initializer_list<Input>{enter, dummy});
    // Loop condition: merge < 10.
    auto ten = ops::Const<int32>(
        scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
        10);
    auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
    auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
    auto switch_ =
        ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
    // Only the true output is consumed; there is deliberately no Exit node.
    auto identity =
        ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true);
    auto one = ops::Const<int32>(
        scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
    auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
    auto next_iteration =
        ops::NextIteration(scope.WithOpName("while/NextIteration"), add);

    // Remove the dummy node and add the loop backedge.
    scope.graph()->RemoveNode(dummy.node());
    scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);

    TF_EXPECT_OK(scope.ToGraph(&graph));
  }

  FunctionLibraryDefinition library(OpRegistry::Global(), {});
  GraphDef optimized_graph_def;
  graph.ToGraphDef(&optimized_graph_def);
  TF_ASSERT_OK(
      FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
  TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
  GraphDef converted_graph_def;
  graph.ToGraphDef(&converted_graph_def);

  for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
    NameAttrList cond_fn, body_fn;
    TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));

    // Outer graph: just the While node; nothing consumes its output.
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
      auto while_op =
          ops::While(scope.WithOpName("while/LoopCond"),
                     std::initializer_list<Input>{source}, cond_fn, body_fn);
      GraphDef expected;
      TF_EXPECT_OK(scope.ToGraphDef(&expected));
      TF_EXPECT_GRAPH_EQ(expected, graph_def);
    }

    // Condition graph
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
      auto ten = ops::Const<int32>(
          scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10);
      auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten);
      auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);

      GraphDef expected;
      TF_EXPECT_OK(scope.ToGraphDef(&expected));

      InstantiationResultForTest result;
      TF_EXPECT_OK(
          InstantiateFunctionForTest(cond_fn.name(), library, &result));

      EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
      EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
      TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    }

    // Body graph.
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto arg = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
      auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg);
      auto one = ops::Const<int32>(
          scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
      auto add = ops::Add(scope.WithOpName("while/add"), identity, one);
      auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);

      GraphDef expected;
      TF_EXPECT_OK(scope.ToGraphDef(&expected));

      InstantiationResultForTest result;
      TF_EXPECT_OK(
          InstantiateFunctionForTest(body_fn.name(), library, &result));

      EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types);
      EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types);
      TF_EXPECT_GRAPH_EQ(expected, result.gdef);
    }
  }
}
715
716 // Graph:
717 // x = array_ops.placeholder(dtypes.int32)
718 // y = array_ops.placeholder(dtypes.int32)
719 // cond = lambda (i, j): i + 3 < 10
720 // body = lambda (i, j): (i < 10, j * 2)
721 // z = control_flow_ops.while_loop(cond, body, [x, y])
TEST(FunctionalizeControlFlow,TwoLoopVars)722 TEST(FunctionalizeControlFlow, TwoLoopVars) {
723 Graph graph(OpRegistry::Global());
724 {
725 Scope scope = Scope::NewRootScope().ExitOnError();
726
727 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
728
729 auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
730 auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
731 auto enter_x =
732 ops::internal::Enter(scope.WithOpName("while/Enter/x"), x, "aloop");
733 auto enter_y =
734 ops::internal::Enter(scope.WithOpName("while/Enter/y"), y, "aloop");
735 auto merge_x = ops::Merge(scope.WithOpName("while/Merge/x"),
736 std::initializer_list<Input>{enter_x, dummy});
737 auto merge_y = ops::Merge(scope.WithOpName("while/Merge/y"),
738 std::initializer_list<Input>{enter_y, dummy});
739
740 // Loop condition
741 auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
742 .WithControlDependencies(merge_x.output),
743 3);
744 auto cond_add =
745 ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three);
746 auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
747 .WithControlDependencies(merge_x.output),
748 10);
749 auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
750 auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
751
752 auto switch_x = ops::Switch(scope.WithOpName("while/Switch/x"),
753 merge_x.output, loop_cond);
754 auto switch_y = ops::Switch(scope.WithOpName("while/Switch/y"),
755 merge_y.output, loop_cond);
756
757 auto exit_x = ops::internal::Exit(scope.WithOpName("while/Exit/x"),
758 switch_x.output_false);
759 auto exit_y = ops::internal::Exit(scope.WithOpName("while/Exit/y"),
760 switch_y.output_false);
761
762 auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"),
763 switch_x.output_true);
764 auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"),
765 switch_y.output_true);
766
767 auto one = ops::Const<int32>(
768 scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
769 1);
770 auto two = ops::Const<int32>(
771 scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
772 2);
773
774 auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
775 auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
776 auto next_iteration_x =
777 ops::NextIteration(scope.WithOpName("while/NextIteration/x"), add);
778 auto next_iteration_y =
779 ops::NextIteration(scope.WithOpName("while/NextIteration/y"), mul);
780
781 auto sink_x = ops::Identity(scope.WithOpName("sink_x"), exit_x);
782 auto sink_y = ops::Identity(scope.WithOpName("sink_y"), exit_y);
783
784 // Remove the dummy node and add the loop backedges.
785 scope.graph()->RemoveNode(dummy.node());
786 scope.graph()->AddEdge(next_iteration_x.node(), 0, merge_x.output.node(),
787 1);
788 scope.graph()->AddEdge(next_iteration_y.node(), 0, merge_y.output.node(),
789 1);
790
791 TF_EXPECT_OK(scope.ToGraph(&graph));
792 }
793
794 FunctionLibraryDefinition library(OpRegistry::Global(), {});
795 GraphDef optimized_graph_def;
796 graph.ToGraphDef(&optimized_graph_def);
797 TF_ASSERT_OK(
798 FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library));
799 TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library));
800 GraphDef converted_graph_def;
801 graph.ToGraphDef(&converted_graph_def);
802
803 for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) {
804 NameAttrList cond_fn, body_fn;
805 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn));
806
807 // Outer graph.
808 {
809 Scope scope = Scope::NewRootScope().ExitOnError();
810 auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
811 auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
812 auto while_op =
813 ops::While(scope.WithOpName("while/LoopCond"),
814 std::initializer_list<Input>{x, y}, cond_fn, body_fn);
815 auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]);
816 auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]);
817 GraphDef expected;
818 TF_EXPECT_OK(scope.ToGraphDef(&expected));
819 TF_EXPECT_GRAPH_EQ(expected, graph_def);
820 }
821
822 // Condition graph.
823 {
824 Scope scope = Scope::NewRootScope().ExitOnError();
825 auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
826 auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
827 auto three = ops::Const<int32>(scope.WithOpName("while/cond/three")
828 .WithControlDependencies(arg0.output),
829 3);
830 auto cond_add =
831 ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three);
832 auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten")
833 .WithControlDependencies(arg0.output),
834 10);
835 auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten);
836 auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
837
838 GraphDef expected;
839 TF_EXPECT_OK(scope.ToGraphDef(&expected));
840
841 InstantiationResultForTest result;
842 TF_EXPECT_OK(
843 InstantiateFunctionForTest(cond_fn.name(), library, &result));
844
845 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
846 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
847 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
848 }
849
850 // Body graph.
851 {
852 Scope scope = Scope::NewRootScope().ExitOnError();
853 auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
854 auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
855
856 auto identity_x =
857 ops::Identity(scope.WithOpName("while/Identity/x"), arg0);
858 auto identity_y =
859 ops::Identity(scope.WithOpName("while/Identity/y"), arg1);
860
861 auto one = ops::Const<int32>(
862 scope.WithOpName("while/add/one").WithControlDependencies(identity_x),
863 1);
864 auto two = ops::Const<int32>(
865 scope.WithOpName("while/mul/two").WithControlDependencies(identity_x),
866 2);
867
868 auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one);
869 auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two);
870 auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0);
871 auto retval1 = ops::_Retval(scope.WithOpName("retval1_RetVal"), mul, 1);
872
873 GraphDef expected;
874 TF_EXPECT_OK(scope.ToGraphDef(&expected));
875
876 InstantiationResultForTest result;
877 TF_EXPECT_OK(
878 InstantiateFunctionForTest(body_fn.name(), library, &result));
879
880 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types);
881 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types);
882 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
883 }
884 }
885 }
886
887 // More complex example with nesting, loop-invariant arguments, and resource
888 // variables. Used for multiple tests with different node filters.
889 class ComplexTestFixture
890 : public ::testing::TestWithParam<std::tuple<bool, bool, bool>> {
891 protected:
SetUp()892 void SetUp() override {
893 restrict_to_tpu_nodes_ = std::get<0>(GetParam());
894 mark_inner_loop_tpu_ = std::get<1>(GetParam());
895 mark_outer_loop_tpu_ = std::get<2>(GetParam());
896 }
897 void RunTest();
898
899 private:
900 void CheckOuterNodesFunctionalized(const GraphDef& graph_def,
901 const FunctionLibraryDefinition& library,
902 NameAttrList& inner_cond_fn,
903 NameAttrList& inner_body_fn);
904 void CheckInnerNodesFunctionalized(const GraphDef& graph_def,
905 const FunctionLibraryDefinition& library,
906 const NameAttrList& inner_cond_fn,
907 const NameAttrList& inner_body_fn);
908
909 bool restrict_to_tpu_nodes_ = false;
910 bool mark_inner_loop_tpu_ = false;
911 bool mark_outer_loop_tpu_ = false;
912 };
913
TEST_P(ComplexTestFixture, ComplexTests) {
  // All scenario logic lives in the fixture; the parameters pick the variant.
  RunTest();
}
915
916 INSTANTIATE_TEST_SUITE_P(
917 FunctionalizeControlFlow, ComplexTestFixture,
918 ::testing::Combine(::testing::Bool(), ::testing::Bool(), ::testing::Bool()),
__anon8d20b2690402(const ::testing::TestParamInfo<ComplexTestFixture::ParamType>& info) 919 [](const ::testing::TestParamInfo<ComplexTestFixture::ParamType>& info) {
920 bool restrict_to_tpu_nodes = std::get<0>(info.param);
921 bool mark_inner_loop_tpu = std::get<1>(info.param);
922 bool mark_outer_loop_tpu = std::get<2>(info.param);
923
924 string node_string;
925 if (mark_inner_loop_tpu && mark_outer_loop_tpu)
926 node_string = "both_loops_tpu";
927 else if (!mark_inner_loop_tpu && !mark_outer_loop_tpu)
928 node_string = "no_loop_tpu";
929 else
930 node_string = mark_inner_loop_tpu ? "inner_loop_tpu" : "outer_loop_tpu";
931
932 string name = absl::StrCat(
933 restrict_to_tpu_nodes ? "restricted_" : "unrestricted_", node_string);
934 return name;
935 });
936
RunTest()937 void ComplexTestFixture::RunTest() {
938 // Graph:
939 //
940 // accum = resource_variable_ops.ResourceVariable(1)
941 // x = array_ops.placeholder(2, dtype=dtypes.int32)
942 // y = 3 + x
943 //
944 // def inner_body(j, k):
945 // add = state_ops.assign_add(accum, k * j + x)
946 // with ops.control_dependencies([add]):
947 // return [j + 1, k]
948 //
949 // def body(i):
950 // m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body,
951 // [1, y], name="inner")
952 // with ops.control_dependencies(m):
953 // return [i + 1]
954 //
955 // z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer")
956 Graph graph(OpRegistry::Global());
957 {
958 Scope scope = Scope::NewRootScope().ExitOnError();
959
960 auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32);
961
962 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
963 auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
964 auto y = ops::Add(scope.WithOpName("y"), x, three);
965
966 auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
967 TensorShape({}));
968
969 // Outer loop
970 auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
971 auto enter_i =
972 ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer");
973 auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"),
974 std::initializer_list<Input>{enter_i, dummy});
975 auto ten = ops::Const<int32>(scope.WithOpName("outer/Less/y")
976 .WithControlDependencies(merge_i.output),
977 10);
978 auto less_i =
979 ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten);
980 auto outer_loop_cond =
981 ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_i);
982 auto switch_i = ops::Switch(scope.WithOpName("outer/Switch"),
983 merge_i.output, outer_loop_cond);
984 auto exit_i = ops::internal::Exit(scope.WithOpName("outer/Exit"),
985 switch_i.output_false);
986 auto identity_i =
987 ops::Identity(scope.WithOpName("outer/Identity"), switch_i.output_true);
988
989 auto enter_x_outer =
990 ops::internal::Enter(scope.WithOpName("outer/Enter_x"), x, "outer",
991 ops::internal::Enter::Attrs().IsConstant(true));
992 auto enter_k_outer =
993 ops::internal::Enter(scope.WithOpName("outer/Enter_k"), y, "outer",
994 ops::internal::Enter::Attrs().IsConstant(true));
995 auto enter_var_outer =
996 ops::internal::Enter(scope.WithOpName("outer/Enter_var"), var, "outer",
997 ops::internal::Enter::Attrs().IsConstant(true));
998
999 // Inner loop
1000 auto one_j = ops::Const<int32>(
1001 scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
1002 auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"),
1003 one_j, "inner");
1004 auto enter_k =
1005 ops::internal::Enter(scope.WithOpName("outer/inner/Enter_k")
1006 .WithControlDependencies(identity_i),
1007 enter_k_outer, "inner");
1008 auto enter_x = ops::internal::Enter(
1009 scope.WithOpName("outer/inner/Enter_x"), enter_x_outer, "inner",
1010 ops::internal::Enter::Attrs().IsConstant(true));
1011 auto enter_var = ops::internal::Enter(
1012 scope.WithOpName("outer/inner/Enter_var"), enter_var_outer, "inner",
1013 ops::internal::Enter::Attrs().IsConstant(true));
1014
1015 auto merge_j = ops::Merge(scope.WithOpName("outer/inner/Merge_j"),
1016 std::initializer_list<Input>{enter_j, dummy});
1017 auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"),
1018 std::initializer_list<Input>{enter_k, dummy});
1019
1020 auto five = ops::Const<int32>(scope.WithOpName("outer/inner/Five")
1021 .WithControlDependencies(merge_j.output),
1022 5);
1023 auto less_j =
1024 ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five);
1025 auto loop_cond =
1026 ops::LoopCond(scope.WithOpName("outer/inner/LoopCond"), less_j);
1027
1028 auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"),
1029 merge_j.output, loop_cond);
1030 auto switch_k = ops::Switch(scope.WithOpName("outer/inner/Switch_k"),
1031 merge_k.output, loop_cond);
1032 auto exit_j = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_j"),
1033 switch_j.output_false);
1034 auto exit_k = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_k"),
1035 switch_k.output_false);
1036 auto identity_j = ops::Identity(scope.WithOpName("outer/inner/Identity_j"),
1037 switch_j.output_true);
1038 auto identity_k = ops::Identity(scope.WithOpName("outer/inner/Identity_k"),
1039 switch_k.output_true);
1040
1041 // Variable update
1042 auto mul_jk =
1043 ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
1044 auto add_jkx =
1045 ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, enter_x);
1046 auto assign = ops::AssignAddVariableOp(
1047 scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx);
1048
1049 auto one = ops::Const<int32>(
1050 scope.WithOpName("outer/inner/One")
1051 .WithControlDependencies(
1052 absl::Span<const Operation>{assign.operation}),
1053 1);
1054 auto add_j =
1055 ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
1056
1057 auto next_iteration_j = ops::NextIteration(
1058 scope.WithOpName("outer/inner/NextIteration_j"), add_j);
1059 auto next_iteration_k = ops::NextIteration(
1060 scope.WithOpName("outer/inner/NextIteration_k"), identity_k);
1061
1062 // Body and backedge for outer loop.
1063 auto one_outer = ops::Const<int32>(
1064 scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
1065 auto add_i =
1066 ops::Add(scope.WithOpName("outer/add")
1067 .WithControlDependencies(absl::Span<const Operation>{
1068 exit_j.output.op(), exit_k.output.op()}),
1069 identity_i, one_outer);
1070 auto next_iteration_i =
1071 ops::NextIteration(scope.WithOpName("outer/NextIteration"), add_i);
1072
1073 auto sink = ops::Identity(scope.WithOpName("sink"), exit_i);
1074
1075 // Remove the dummy node and add the loop backedge.
1076 scope.graph()->RemoveNode(dummy.node());
1077 scope.graph()->AddEdge(next_iteration_i.node(), 0, merge_i.output.node(),
1078 1);
1079 scope.graph()->AddEdge(next_iteration_j.node(), 0, merge_j.output.node(),
1080 1);
1081 scope.graph()->AddEdge(next_iteration_k.node(), 0, merge_k.output.node(),
1082 1);
1083
1084 TF_EXPECT_OK(scope.ToGraph(&graph));
1085 }
1086 // Add '_tpu_replicate' attributes as specified.
1087 for (Node* n : graph.nodes()) {
1088 string name = n->name();
1089 bool is_inner_node = name.find("outer/inner/") != string::npos;
1090 bool is_outer_node = !is_inner_node && name.find("outer/") != string::npos;
1091 if ((is_inner_node && mark_inner_loop_tpu_) ||
1092 (is_outer_node && mark_outer_loop_tpu_)) {
1093 n->AddAttr("_tpu_replicate", "cluster");
1094 }
1095 }
1096
1097 FunctionLibraryDefinition library(OpRegistry::Global(), {});
1098 GraphDef orig_graph_def, optimized_graph_def;
1099 graph.ToGraphDef(&orig_graph_def);
1100 optimized_graph_def = orig_graph_def;
1101 // If `restrict_to_tpu_nodes_` is true let filter function return true for
1102 // `_tpu_replicate` nodes, otherwise don't set filter.
1103 NodeFilter node_filter =
1104 restrict_to_tpu_nodes_
1105 ? [](const Node* n) { return n->attrs().Find("_tpu_replicate"); }
1106 : NodeFilter{};
1107
1108 Status status1 = FunctionalizeControlFlowForGraphDef(&optimized_graph_def,
1109 &library, node_filter);
1110 Status status2 = FunctionalizeControlFlow(&graph, &library, node_filter);
1111 ASSERT_EQ(status1, status2);
1112 if (restrict_to_tpu_nodes_ && mark_outer_loop_tpu_ && !mark_inner_loop_tpu_) {
1113 // This case violates the precondition of `FunctionalizeControlFlow`, we
1114 // expect an internal error.
1115 ASSERT_EQ(errors::IsInternal(status1), true);
1116 return;
1117 } else {
1118 // Supported cases, no error expected.
1119 TF_ASSERT_OK(status1);
1120 }
1121
1122 GraphDef optimized_converted_graph_def;
1123 graph.ToGraphDef(&optimized_converted_graph_def);
1124 for (const GraphDef& graph_def :
1125 {optimized_graph_def, optimized_converted_graph_def}) {
1126 NameAttrList inner_cond_fn, inner_body_fn;
1127 if (!restrict_to_tpu_nodes_ ||
1128 (restrict_to_tpu_nodes_ && mark_outer_loop_tpu_ &&
1129 mark_inner_loop_tpu_)) {
1130 // We expect that both inner and outer nodes have been functionalized.
1131 CheckOuterNodesFunctionalized(graph_def, library, inner_cond_fn,
1132 inner_body_fn);
1133 CheckInnerNodesFunctionalized(graph_def, library, inner_cond_fn,
1134 inner_body_fn);
1135 } else /*restrict_to_tpu_nodes_ == true*/ {
1136 if (!mark_outer_loop_tpu_ && !mark_inner_loop_tpu_) {
1137 // Graph has no TPU nodes so we expect no functionalization.
1138 TF_EXPECT_GRAPH_EQ(orig_graph_def, graph_def);
1139 } else if (!mark_outer_loop_tpu_ && mark_inner_loop_tpu_) {
1140 // We expect that only inner nodes have been functionalized.
1141 TF_EXPECT_OK(
1142 FindWhileCondAndBody(graph_def, &inner_cond_fn, &inner_body_fn));
1143 CheckInnerNodesFunctionalized(graph_def, library, inner_cond_fn,
1144 inner_body_fn);
1145 }
1146 }
1147 }
1148 }
1149
CheckOuterNodesFunctionalized(const GraphDef & graph_def,const FunctionLibraryDefinition & library,NameAttrList & inner_cond_fn,NameAttrList & inner_body_fn)1150 void ComplexTestFixture::CheckOuterNodesFunctionalized(
1151 const GraphDef& graph_def, const FunctionLibraryDefinition& library,
1152 NameAttrList& inner_cond_fn, NameAttrList& inner_body_fn) {
1153 NameAttrList outer_cond_fn, outer_body_fn;
1154 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn));
1155
1156 // Outer graph.
1157 {
1158 Scope scope = Scope::NewRootScope().ExitOnError();
1159 auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
1160 auto three = ops::Const<int32>(scope.WithOpName("three"), 3);
1161 auto y = ops::Add(scope.WithOpName("y"), x, three);
1162
1163 auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32,
1164 TensorShape({}));
1165
1166 auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
1167
1168 auto while_op = ops::While(scope.WithOpName("outer/LoopCond"),
1169 std::initializer_list<Input>{zero, y, x, var},
1170 outer_cond_fn, outer_body_fn);
1171 auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
1172 GraphDef expected;
1173 TF_EXPECT_OK(scope.ToGraphDef(&expected));
1174 TF_EXPECT_GRAPH_EQ(expected, graph_def);
1175 }
1176
1177 // Outer condition graph.
1178 {
1179 Scope scope = Scope::NewRootScope().ExitOnError();
1180 auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1181 auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
1182 auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
1183 auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
1184
1185 auto ten = ops::Const<int32>(
1186 scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output),
1187 10);
1188 auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten);
1189 auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less, 0);
1190
1191 GraphDef expected;
1192 TF_EXPECT_OK(scope.ToGraphDef(&expected));
1193
1194 InstantiationResultForTest result;
1195 TF_EXPECT_OK(
1196 InstantiateFunctionForTest(outer_cond_fn.name(), library, &result));
1197
1198 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1199 result.arg_types);
1200 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
1201 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
1202 }
1203
1204 // Outer body graph.
1205 {
1206 InstantiationResultForTest result;
1207 TF_EXPECT_OK(
1208 InstantiateFunctionForTest(outer_body_fn.name(), library, &result));
1209
1210 // Find the inner condition and body names.
1211 TF_EXPECT_OK(
1212 FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn));
1213
1214 Scope scope = Scope::NewRootScope().ExitOnError();
1215 auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1216 auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
1217 auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
1218 auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
1219
1220 auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0);
1221 auto one_j = ops::Const<int32>(
1222 scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
1223 auto while_op =
1224 ops::While(scope.WithOpName("outer/inner/LoopCond"),
1225 std::initializer_list<Input>{one_j, arg1, arg2, arg3},
1226 inner_cond_fn, inner_body_fn);
1227
1228 auto one_outer = ops::Const<int32>(
1229 scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
1230 auto add_i =
1231 ops::Add(scope.WithOpName("outer/add")
1232 .WithControlDependencies(absl::Span<const Operation>{
1233 while_op[0].op(), while_op[1].op()}),
1234 identity_i, one_outer);
1235
1236 auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add_i, 0);
1237 auto retval1 = ops::_Retval(scope.WithOpName("retval1_RetVal"), arg1, 1);
1238 auto retval2 = ops::_Retval(scope.WithOpName("retval2_RetVal"), arg2, 2);
1239 auto retval3 = ops::_Retval(scope.WithOpName("retval3_RetVal"), arg3, 3);
1240
1241 GraphDef expected;
1242 TF_EXPECT_OK(scope.ToGraphDef(&expected));
1243
1244 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1245 result.arg_types);
1246 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1247 result.ret_types);
1248 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
1249 }
1250 }
1251
CheckInnerNodesFunctionalized(const GraphDef & graph_def,const FunctionLibraryDefinition & library,const NameAttrList & inner_cond_fn,const NameAttrList & inner_body_fn)1252 void ComplexTestFixture::CheckInnerNodesFunctionalized(
1253 const GraphDef& graph_def, const FunctionLibraryDefinition& library,
1254 const NameAttrList& inner_cond_fn, const NameAttrList& inner_body_fn) {
1255 // Inner condition graph.
1256 {
1257 Scope scope = Scope::NewRootScope().ExitOnError();
1258 auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1259 auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
1260 auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
1261 auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
1262
1263 auto five = ops::Const<int32>(
1264 scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5);
1265 auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five);
1266 auto retval = ops::_Retval(scope.WithOpName("retval0_RetVal"), less_j, 0);
1267
1268 GraphDef expected;
1269 TF_EXPECT_OK(scope.ToGraphDef(&expected));
1270
1271 InstantiationResultForTest result;
1272 TF_EXPECT_OK(
1273 InstantiateFunctionForTest(inner_cond_fn.name(), library, &result));
1274
1275 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1276 result.arg_types);
1277 EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types);
1278 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
1279 }
1280
1281 // Inner body graph.
1282 {
1283 Scope scope = Scope::NewRootScope().ExitOnError();
1284 auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1285 auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1);
1286 auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2);
1287 auto arg3 = ops::_Arg(scope.WithOpName("arg3"), DT_RESOURCE, 3);
1288
1289 auto identity_j =
1290 ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0);
1291 auto identity_k =
1292 ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1);
1293
1294 auto mul_jk =
1295 ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k);
1296 auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2);
1297 auto assign = ops::AssignAddVariableOp(
1298 scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx);
1299
1300 auto one = ops::Const<int32>(
1301 scope.WithOpName("outer/inner/One")
1302 .WithControlDependencies(
1303 absl::Span<const Operation>{assign.operation}),
1304 1);
1305 auto add_j =
1306 ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one);
1307
1308 auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add_j, 0);
1309 auto retval1 =
1310 ops::_Retval(scope.WithOpName("retval1_RetVal"), identity_k, 1);
1311 auto retval2 = ops::_Retval(scope.WithOpName("retval2_RetVal"), arg2, 2);
1312 auto retval3 = ops::_Retval(scope.WithOpName("retval3_RetVal"), arg3, 3);
1313
1314 GraphDef expected;
1315 TF_EXPECT_OK(scope.ToGraphDef(&expected));
1316
1317 InstantiationResultForTest result;
1318 TF_EXPECT_OK(
1319 InstantiateFunctionForTest(inner_body_fn.name(), library, &result));
1320
1321 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1322 result.arg_types);
1323 EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}),
1324 result.ret_types);
1325 TF_EXPECT_GRAPH_EQ(expected, result.gdef);
1326 }
1327 }
1328
1329 } // namespace
1330 } // namespace tensorflow
1331