/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_def_builder_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

using ::tensorflow::testing::FindNodeByName;

namespace tensorflow {
namespace {

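// Enable XLA devices for these tests; several tests below assign nodes to
// XLA_CPU/XLA_GPU devices.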
static bool Initialized = [] {
  tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
  return true;
}();

REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");

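// Returns a map from node name to XLA cluster name (the value of the
// kXlaClusterAttr attribute) for every node in `graph` that was clustered.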
std::unordered_map<string, string> GetClusters(const Graph& graph) {
  std::unordered_map<string, string> ids;
  for (Node* node : graph.nodes()) {
    string cluster;
    if (TryGetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster)) {
      CHECK(!cluster.empty());
      ids[node->name()] = cluster;
    }
  }

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Clusters:";
    for (const auto& p : ids) {
      VLOG(2) << " " << p.first << " -> " << p.second;
    }
  }
  return ids;
}

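// Groups the clustered nodes of `g` by cluster: returns a map from cluster
// name to the sorted names of the nodes in that cluster. If `cluster_names` is
// non-null, it is filled with the sorted list of cluster names.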
absl::flat_hash_map<string, std::vector<string>> GetClusterSets(
    const Graph& g, std::vector<string>* cluster_names = nullptr) {
  CHECK(cluster_names == nullptr || cluster_names->empty());
  absl::flat_hash_map<string, std::vector<string>> cluster_sets;
  for (const auto& p : GetClusters(g)) {
    cluster_sets[p.second].push_back(p.first);
  }
  for (auto& p : cluster_sets) {
    if (cluster_names != nullptr) {
      cluster_names->push_back(p.first);
    }
    std::sort(p.second.begin(), p.second.end());
  }
  if (cluster_names != nullptr) {
    std::sort(cluster_names->begin(), cluster_names->end());
  }
  return cluster_sets;
}

TEST(XlaCompilationTest, Chains) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    Node* d =
        ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
    Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
    ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_EQ(clusters["E"], clusters["F"]);
  EXPECT_NE(clusters["B"], clusters["E"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

TEST(XlaCompilationTest, UncompilableCycles) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b =
        ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, CompilableCycles) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(3, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["A"], clusters["C"]);
}

TEST(XlaCompilationTest, StringUnsupported) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp(
        "Const", builder.opts()
                     .WithName("A")
                     .WithAttr("dtype", DT_STRING)
                     .WithAttr("value", Tensor(DT_STRING, TensorShape())));
    Node* b = ops::UnaryOp("EncodeBase64", a, builder.opts().WithName("B"));
    ops::BinaryOp("StringSplit", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, HalfSupported) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Tensor t(DT_HALF, TensorShape());
    t.scalar<Eigen::half>()() = static_cast<Eigen::half>(0.0f);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_HALF)
                                         .WithAttr("value", t));
    Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_FALSE(clusters.empty());
}

// Tests that PartitionedCalls are only marked for compilation if every node
// inside the function can be compiled.
TEST(XlaCompilationTest, PartitionedCallUnsupported) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
      {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
  FunctionDef uncompilable =
      FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
                                {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  *flib.add_function() = uncompilable;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);

  NameAttrList b_name_attr;
  b_name_attr.set_name("CompilableFn");
  ops::PartitionedCall b(root.WithOpName("B"), {a, a}, {DT_FLOAT}, b_name_attr);
  NameAttrList c_name_attr;
  c_name_attr.set_name("UncompilableFn");

  ops::PartitionedCall c(root.WithOpName("C"), {a}, {DT_FLOAT}, c_name_attr);
  Output d = ops::Add(root.WithOpName("D"), b.output.front(), c.output.front());

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_TRUE(clusters["C"].empty());
  EXPECT_EQ(clusters["B"], clusters["D"]);
}

TEST(XlaCompilationTest, FunctionCalls) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
      {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
  FunctionDef uncompilable =
      FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
                                {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
  FunctionDef noinline = compilable;
  noinline.mutable_signature()->set_name("NoInlineFn");
  AddAttr("_noinline", static_cast<bool>(true), noinline.mutable_attr());

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  *flib.add_function() = uncompilable;
  *flib.add_function() = noinline;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
    ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["C"].empty());
  EXPECT_EQ(clusters["C"], clusters["E"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("B") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "FnWithResourceOp", {"var:resource", "val:float"}, {"retval:float"}, {},
      {{{"assign_op"},
        "AssignVariableOp",
        {"var", "val"},
        {{"dtype", DT_FLOAT}}},
       {{"retval"}, "Identity", {"val"}, {{"T", DT_FLOAT}}, {"assign_op"}}});

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
    Node* resource =
        ops::SourceOp("VarHandleOp", builder.opts()
                                         .WithName("varhandle")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("shape", TensorShape({})));

    Tensor const_tensor(DT_FLOAT, TensorShape({}));
    const_tensor.scalar<float>()() = 42.0f;
    Node* value = ops::SourceOp("Const", builder.opts()
                                             .WithName("const")
                                             .WithAttr("value", const_tensor)
                                             .WithAttr("dtype", DT_FLOAT));

    Node* call = ops::BinaryOp("FnWithResourceOp", resource, value,
                               builder.opts().WithName("A"));
    Node* tanh0 = ops::UnaryOp("Tanh", call, builder.opts().WithName("tanh0"));
    Node* tanh1 = ops::UnaryOp("Tanh", tanh0, builder.opts().WithName("tanh1"));
    ops::UnaryOp("Tanh", tanh1, builder.opts().WithName("tanh2"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
  testing::FindNodeByName(graph.get(), "A")
      ->set_assigned_device_name(xla_cpu_device);
  testing::FindNodeByName(graph.get(), "tanh0")
      ->set_assigned_device_name(xla_cpu_device);
  testing::FindNodeByName(graph.get(), "tanh1")
      ->set_assigned_device_name(xla_cpu_device);
  testing::FindNodeByName(graph.get(), "tanh2")
      ->set_assigned_device_name(xla_cpu_device);

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_NE(clusters["A"], "");
}

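// Helper for building the gradient FunctionDefs used below: wraps `nodes` in a
// FunctionDef with the unary cwise gradient signature (x, dy) -> dx, and
// defaults each node's "T" attribute to DT_FLOAT.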
static Status GradForUnaryCwise(FunctionDef* g,
                                std::vector<FunctionDefHelper::Node> nodes) {
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", DT_FLOAT}};
    }
  }
  *g = FunctionDefHelper::Define(
      // Arg defs
      {"x: float", "dy: float"},
      // Ret val defs
      {"dx: float"},
      // Attr defs
      {},
      // Nodes
      nodes);
  return Status::OK();
}

// A gradient containing only supported operators
Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      FunctionDefHelper::Const("one", 1.0f),
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Supported", SupportedGrad);

// A gradient containing an unsupported operator.
Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "UncompilableUnary", {"y"}, {}, {"dy"}},
      FunctionDefHelper::Const("one", 1.0f),
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad);

TEST(XlaCompilationTest, SymbolicGradients) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));

    // Builds a Symbolic gradient for Supported
    NodeBuilder b_builder("B", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList b_name_attr;
    b_name_attr.set_name("Supported");
    b_builder.Attr("f", b_name_attr);
    b_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    b_builder.Attr("Tout", {DT_FLOAT});
    b_builder.Input({a, a});
    Node* b = builder.opts().FinalizeBuilder(&b_builder);

    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));

    // Builds a Symbolic gradient for Unsupported
    NodeBuilder d_builder("D", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList d_name_attr;
    d_name_attr.set_name("Unsupported");
    d_builder.Attr("f", d_name_attr);
    d_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    d_builder.Attr("Tout", {DT_FLOAT});
    d_builder.Input({c, c});
    builder.opts().FinalizeBuilder(&d_builder);

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

TEST(XlaCompilationTest, Loops) {
  // Regression test for b/32350199, where the autoclustering code introduced a
  // deadlock in a graph containing a while loop.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
  auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT);
  auto c = ops::Add(root.WithOpName("C"), a, b);
  auto enter = ops::internal::Enter(root, c, "aframe");
  auto next_iter = ops::NextIteration(root, enter);
  auto exit = ops::internal::Exit(root, next_iter);
  auto d = ops::Add(root.WithOpName("D"), c, exit);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  // Nothing should be compiled. In particular, 'd' and 'c' must not be
  // compiled.
  EXPECT_EQ(0, clusters.size());
}

TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  FunctionDefLibrary flib;
  FunctionLibraryDefinition flib_def(graph->op_registry(), flib);
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A),
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // In this case the global jit level overrides the scope annotations, so all
  // three nodes end up in the same cluster.
  EXPECT_EQ(3, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["A"], clusters["C"]);
}

TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A),
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // In this case, we cannot fuse anything, and there are no clusters.
  EXPECT_EQ(0, clusters.size());
}

TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaCompileAttr, true)
                                         .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* b = ops::UnaryOp("Relu", a,
                           builder.opts()
                               .WithName("B")
                               .WithAttr(kXlaCompileAttr, true)
                               .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* c = ops::BinaryOp("MatMul", a, b,
                            builder.opts()
                                .WithName("C")
                                .WithAttr(kXlaCompileAttr, true)
                                .WithAttr(kXlaScopeAttr, "Scope2"));
    ops::BinaryOp("Add", b, c,
                  builder.opts()
                      .WithName("D")
                      .WithAttr(kXlaCompileAttr, true)
                      .WithAttr(kXlaScopeAttr, "Scope2"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
  auto clusters = GetClusters(*graph);

  // The computation is: D = relu(A) + (A @ relu(A)),
  // where A and relu(A) are in Scope1, and the @ and + ops are in Scope2.
  // In this case, we can fuse A with relu(A), and we can fuse the second half
  // of the operations; there are two clusters.
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_NE(clusters["A"], clusters["C"]);
  EXPECT_EQ(clusters["C"], clusters["D"]);
}

TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaCompileAttr, true)
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp("Relu", a,
                           builder.opts()
                               .WithName("B")
                               .WithAttr(kXlaCompileAttr, true)
                               .WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A),
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C is unscoped.
  // In this case, A cannot be fused with B because of the differing scopes,
  // but the unscoped C can join B's cluster.
  EXPECT_EQ(3, clusters.size());
  EXPECT_NE(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["B"], clusters["C"]);
}

TEST(XlaCompilationTest, DontClusterNodesWithMismatchingDeadness) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL);
  Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL);

  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);

  ops::Switch switch_a(root.WithOpName("switch_a"), value, cond_a);
  ops::Switch switch_b(root.WithOpName("switch_b"), value, cond_b);

  Output tanh_a0 = ops::Tanh(root.WithOpName("tan_a0"), switch_a.output_true);
  Output tanh_a1 = ops::Tanh(root.WithOpName("tan_a1"), tanh_a0);

  Output tanh_b0 = ops::Tanh(root.WithOpName("tan_b0"), switch_b.output_true);
  Output tanh_b1 = ops::Tanh(root.WithOpName("tan_b1"), tanh_b0);

  Output add = ops::Add(root.WithOpName("add"), tanh_a1, tanh_b1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph,
      MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
  auto clusters = GetClusters(*graph);

  EXPECT_NE(clusters["tan_a0"], "");
  EXPECT_NE(clusters["tan_a1"], "");
  EXPECT_NE(clusters["tan_b0"], "");
  EXPECT_NE(clusters["tan_b1"], "");

  EXPECT_EQ(clusters["tan_a0"], clusters["tan_a1"]);
  EXPECT_EQ(clusters["tan_b0"], clusters["tan_b1"]);

  EXPECT_NE(clusters["tan_a0"], clusters["tan_b0"]);
}

TEST(XlaCompilationTest, ClusterNodesWithMismatchingInputDeadness) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL);
  Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL);

  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);

  ops::Switch switch_a(root.WithOpName("switch_a"), value, cond_a);
  ops::Switch switch_b(root.WithOpName("switch_b"), value, cond_b);

  Output add_a = ops::Add(root.WithOpName("add_a"), switch_a.output_true,
                          switch_b.output_true);
  Output add_b = ops::Add(root.WithOpName("add_b"), switch_a.output_true,
                          switch_b.output_true);
  Output add = ops::Add(root.WithOpName("add_c"), add_a, add_b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph,
      MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
  auto clusters = GetClusters(*graph);

  EXPECT_NE(clusters["add_a"], "");
  EXPECT_NE(clusters["add_b"], "");
  EXPECT_NE(clusters["add_c"], "");

  EXPECT_EQ(clusters["add_a"], clusters["add_b"]);
  EXPECT_EQ(clusters["add_b"], clusters["add_c"]);
}

namespace {
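// Adds a resource variable "Var<id>" and a "Read<id>" node that reads it;
// returns the read node and, optionally, the VarHandleOp node.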
Node* MakeRead(const Scope& scope, const string& id,
               Node** var_handle_op = nullptr) {
  Output var_handle =
      ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
  Output read =
      ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
  if (var_handle_op) {
    *var_handle_op = var_handle.node();
  }
  return read.node();
}

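// Adds a resource variable "Var<id>" and an "Assignment<id>" node that writes
// a constant to it; returns the assignment node.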
Node* MakeWrite(const Scope& scope, const string& id) {
  Output var_handle =
      ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
  Output value_to_write =
      ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
  ops::AssignVariableOp assign_op(scope.WithOpName("Assignment" + id),
                                  var_handle, value_to_write);
  return assign_op.operation.node();
}

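// Adds a "Const<id>" node that neither reads nor writes a resource variable.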
Node* MakeNeutral(const Scope& scope, const string& id) {
  return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
}
}  // namespace

TEST(XlaCompilationTest, ResourcesClusteringAllowed) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");

  root.graph()->AddControlEdge(read, write);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph);
  ASSERT_EQ(cluster_sets.size(), 1);
  std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR",
                                                  "ValueToAssignW"};
  ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
}

TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");

  root.graph()->AddControlEdge(write, read);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph);
  ASSERT_EQ(cluster_sets.size(), 0);
}

TEST(XlaCompilationTest, ChainOfOps) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* write_0 = MakeWrite(root, "W0");
  Node* neutral_0 = MakeNeutral(root, "N0");
  Node* read_0 = MakeRead(root, "R0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral_1 = MakeNeutral(root, "N1");
  Node* read_1 = MakeRead(root, "R1");

  root.graph()->AddControlEdge(write_0, neutral_0);
  root.graph()->AddControlEdge(neutral_0, read_0);
  root.graph()->AddControlEdge(read_0, write_1);
  root.graph()->AddControlEdge(write_1, neutral_1);
  root.graph()->AddControlEdge(neutral_1, read_1);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::vector<string> cluster_names;
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph, &cluster_names);

  ASSERT_EQ(cluster_sets.size(), 1);

  std::vector<string> expected_clustered_nodes_a = {
      "AssignmentW1", "ConstN0", "ReadR0", "ValueToAssignW1"};
  ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
}

TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();
  {
    auto BuildNoopNode = [](absl::string_view name, Graph* graph) {
      NodeDefBuilder builder(name, "NoOp");
      NodeDef def;
      TF_CHECK_OK(builder.Finalize(&def));

      Status status;
      Node* node = graph->AddNode(def, &status);
      TF_CHECK_OK(status);
      return node;
    };

    Node* a = BuildNoopNode("a", graph.get());
    Node* b = BuildNoopNode("b", graph.get());
    Node* c = BuildNoopNode("c", graph.get());
    graph->AddControlEdge(a, b);
    graph->AddControlEdge(b, c);
    graph->AddControlEdge(c, a);
  }

  TF_EXPECT_OK(root.ToGraph(graph.get()));

  Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
  EXPECT_FALSE(status.ok());
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "Edge from c to a would create a cycle.\n"
                                "+-> a\n"
                                "|   b\n"
                                "+-- c\n"));
}

TEST(XlaCompilationTest, Retval) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::UnaryOp("_Retval", b,
                 builder.opts()
                     .WithName("R")
                     .WithAttr("T", DT_FLOAT)
                     .WithAttr("index", 0));

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, DontCountIdentityOps) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();
  {
    auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
    auto b = ops::Identity(root.WithOpName("B"), a);
    auto c = ops::Identity(root.WithOpName("C"), b);
    auto r = ops::_Retval(root.WithOpName("R"), c, 0);
  }
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, ConstOp) {
  // valid data type
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    Scope root = Scope::NewRootScope().ExitOnError();
    auto c = ops::Const(root.WithOpName("const"), 0.5f);
    c.node()->AddAttr(kXlaCompileAttr, true);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
    EXPECT_EQ(1, GetClusters(*graph).size());
  }

  // invalid data type
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    Scope root = Scope::NewRootScope().ExitOnError();
    auto c = ops::Const(root.WithOpName("const"), string("string"));
    c.node()->AddAttr(kXlaCompileAttr, true);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
    EXPECT_TRUE(GetClusters(*graph).empty());
  }
}

TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  Output read = ops::Identity(root.WithOpName("read"), variable);
  Output neg = ops::Negate(root.WithOpName("negate"), read);
  Output add = ops::Add(root.WithOpName("add"), neg, neg);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  ASSERT_FALSE(clusters.empty());
  string cluster_name = clusters.begin()->second;

  std::unordered_map<string, string> expected_clusters(
      {{"negate", cluster_name}, {"add", cluster_name}});
  EXPECT_EQ(clusters, expected_clusters);
}

TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  Output read = ops::Identity(root.WithOpName("read"), variable);
  Output neg = ops::Negate(root.WithOpName("negate"), read);
  Output identity = ops::Negate(root.WithOpName("identity"), neg);
  Output add = ops::Add(root.WithOpName("add"), identity, neg);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  ASSERT_FALSE(clusters.empty());
  string cluster_name = clusters.begin()->second;

  std::unordered_map<string, string> expected_clusters(
      {{"negate", cluster_name},
       {"identity", cluster_name},
       {"add", cluster_name}});
  EXPECT_EQ(clusters, expected_clusters);
}

TEST(XlaCompilationTest, ClusterControlTrigger) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_BOOL, "tensor_a",
                             "sender", 0, "receiver");
  Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_BOOL, "tensor_b",
                             "sender", 0, "receiver");
  Output const_a = ops::Const(root.WithOpName("const_a"), 42);

  ops::ControlTrigger ctrl_trigger_a(root.WithOpName("ctrl_trigger_a"));
  ops::ControlTrigger ctrl_trigger_b(root.WithOpName("ctrl_trigger_b"));
  root.graph()->AddControlEdge(recv_a.node(), ctrl_trigger_a.operation.node());
  root.graph()->AddControlEdge(recv_b.node(), ctrl_trigger_a.operation.node());
  root.graph()->AddControlEdge(ctrl_trigger_b.operation.node(), const_a.node());

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  // TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so
  // it won't be clustered.  ctrl_trigger_b is okay to cluster but we don't
  // cluster it because of b/118970344.
  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, RandomShape) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape_shape = ops::Const(root.WithOpName("shape_shape"), {2}, {1});
  Output shape =
      ops::RandomUniformInt(root.WithOpName("shape"), shape_shape,
                            ops::Const(root.WithOpName("minval"), 1),
                            ops::Const(root.WithOpName("maxval"), 20));
  Output reshape_input =
      ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["shape"], "");
}

TEST(XlaCompilationTest, RandomShapeWithFunc) {
  Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();

  FunctionDefLibrary flib_def;
  FunctionDef func = FunctionDefHelper::Create(
      /*function_name=*/"Stateful_func", /*in_def=*/{},
      /*out_def=*/{"out: int32"},
      /*attr_def*/
      {}, /*node_def=*/
      {FunctionDefHelper::Const("shape_shape", 2),
       FunctionDefHelper::Const("minval", 1),
       FunctionDefHelper::Const("maxval", 20),
       {{"shape"},
        "RandomUniformInt",
        {"shape_shape:output:0", "minval:output:0", "maxval:output:0"},
        {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}},
      /*ret_def=*/{{"out", "shape:output:0"}});

  func.mutable_signature()->set_is_stateful(true);
  *flib_def.add_function() = std::move(func);
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
  NodeDef call_node;
  call_node.set_name("fn_call");
  call_node.set_op("Stateful_func");
  Status status;
  Node* call = root.graph()->AddNode(call_node, &status);
  TF_ASSERT_OK(status);

  Output shape = Output(call, 0);
  Output reshape_input =
      ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  auto fld = absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(),
                                                          flib_def);
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get()));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["fn_call"], "");
}

TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
  absl::string_view xla_gpu_device =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape_shape =
      ops::Const(root.WithOpName("test/shape_shape"), {2}, {1});
  Output shape =
      ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape,
                            ops::Const(root.WithOpName("test/minval"), 1),
                            ops::Const(root.WithOpName("test/maxval"), 20));
  Output reshape_input =
      ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_gpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/shape_rng"], "");
  EXPECT_EQ(clusters["test/reshape"], "");
}

TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
  absl::string_view xla_gpu_device =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1,
                                DT_INT32);
  Output zero = ops::Const(root.WithOpName("test/zero"), 0);
  ops::TensorArrayWrite tensor_array_write(
      root.WithOpName("test/write"), tensor_array.handle, zero,
      ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow);
  Output tensor_array_read =
      ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle,
                           zero, tensor_array_write.flow_out, DT_INT32);
  Output reshape =
      ops::Reshape(root.WithOpName("test/reshape"),
                   ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT),
                   tensor_array_read);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_gpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/read"], "");
  EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]);
}

TEST(XlaCompilationTest, DontClusterMergingNodes) {
  // MatMulCombined below takes data from nodes on GPU0 and GPU1 and is placed
  // on GPU1. However, it should not be clustered with the previous node on
  // GPU1, because that will serialize production of its inputs that should be
  // done in parallel.
  //
  // This graph is:
  // (Const0, Const0) -> MatMul0
  // (Const1, Const1) -> MatMul1
  // (MatMul0, MatMul1) -> MatMulCombined
  //
  // Device0: [Const0, Const0, MatMul0]
  // Device1: [Const1, Const1, MatMul1, MatMulCombined]
  //
  // Cluster0: [Const0, Const0, MatMul0]
  // Cluster1: [Const1, Const1, MatMul1]
  // Cluster2: [MatMulCombined]
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  absl::string_view xla_gpu_dev1 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:1";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
                       ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
  Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
                       ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);

  Output combined =
      ops::MatMul(root.WithOpName("MatMulCombined_dev1"), matmul0, matmul1);
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  // Each of the MatMuls should be in a separate cluster.
  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul0_dev0"]);
  EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul1_dev1"]);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]);
  EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
}

TEST(XlaCompilationTest, DontClusterMergingNodesOnCPU) {
  // This is similar to the 'DontClusterMergingNodes' above, except
  // MatMulCombined is placed on the CPU.
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 = "/job:worker/replica:0/task:0/device:GPU:0";
  absl::string_view xla_gpu_dev1 = "/job:worker/replica:0/task:0/device:GPU:1";
  absl::string_view xla_cpu_dev0 = "/job:worker/replica:0/task:0/device:CPU:0";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
                       ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
  Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
                       ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);

  Output combined =
      ops::MatMul(root.WithOpName("MatMulCombined_cpu"), matmul0, matmul1);
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"cpu")) {
      n->set_assigned_device_name(string(xla_cpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  // Each of the MatMuls should be in a separate cluster.
  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul0_dev0"]);
  EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul1_dev1"]);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]);
  EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
}

// TODO(b/117085735): This form of clustering should be prevented.
TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) {
  // MatMulSource below creates data for nodes on GPU0 and GPU1 and is placed
  // on GPU0. However, it should not be clustered with the next node on
  // GPU0, because that will prevent the node on GPU1 from beginning its work as
  // soon as the data has been produced.
  //
  // This graph is:
  // (Const0, Const0) -> MatMulSource
  // MatMulSource -> (MatMul0, MatMul1)
  //
  // Device0: [Const0, Const1, MatMulSource, MatMul0]
  // Device1: [MatMul1]
  //
  // Cluster0: [Const0, Const1, MatMulSource]
  // Cluster1: [MatMul0]
  // Cluster2: [MatMul1]
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  absl::string_view xla_gpu_dev1 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:1";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2});
  Output matmul_source =
      ops::MatMul(root.WithOpName("MatMulSource_dev0"), a, a);

  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), matmul_source,
                               matmul_source);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), matmul_source,
                               matmul_source);

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMulSource_dev0"]);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulSource_dev0"], clusters["MatMul1_dev1"]);

  // Improved heuristics should probably prevent this clustering.
1193   EXPECT_EQ(clusters["MatMulSource_dev0"], clusters["MatMul0_dev0"]);
1194 }
1195 
TEST(XlaCompilationTest,ClusterStatefulRandomOpOnXlaDevice)1196 TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) {
1197   absl::string_view xla_cpu_device =
1198       "/job:worker/replica:0/task:0/device:XLA_CPU:0";
1199 
1200   Scope root = Scope::NewRootScope().ExitOnError();
1201   Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
1202   Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
1203   Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
1204   Output c = ops::Add(root.WithOpName("test/c"), a, b);
1205 
1206   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1207   TF_ASSERT_OK(root.ToGraph(graph.get()));
1208 
1209   for (Node* n : graph->nodes()) {
1210     if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
1211       n->set_assigned_device_name(string(xla_cpu_device));
1212     }
1213   }
1214   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
1215 
1216   std::unordered_map<string, string> clusters = GetClusters(*graph);
1217   EXPECT_NE(clusters["test/a"], "");
1218   EXPECT_NE(clusters["test/b"], "");
1219   EXPECT_NE(clusters["test/c"], "");
1220 }
1221 
TEST(XlaCompilationTest, DontAutoClusterStatefulRandomOp) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
  Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
  Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
  Output c = ops::Add(root.WithOpName("test/c"), a, b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/a"], "");
  EXPECT_EQ(clusters["test/b"], "");
}

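// "Dummy" ops like CheckNumerics and Assert are expected to be clustered when
// the nodes are assigned to an XLA device.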
TEST(XlaCompilationTest, ClusterDummyOpsOnXlaDevice) {
  absl::string_view xla_cpu_device =
      "/job:worker/replica:0/task:0/device:XLA_CPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
  Output check =
      ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
  Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
  Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_cpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/check"], "");
  EXPECT_NE(clusters["test/greaterequal"], "");
  EXPECT_NE(clusters["test/assert"], "");
}

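// The same "dummy" ops should not be auto-clustered when no XLA device is
// assigned.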
TEST(XlaCompilationTest, DontAutoClusterDummyOps) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
  Output check =
      ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
  Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
  Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/assert"], "");
  EXPECT_EQ(clusters["test/check"], "");
}

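// Ops producing DT_VARIANT outputs (e.g. TensorListReserve) should not be
// auto-clustered.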
TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);

  Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
  Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);

  Output tensor_list_reserve = ops::TensorListReserve(
      root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/tensor_list_reserve"], "");
}

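// Ops consuming DT_VARIANT inputs (e.g. TensorListElementShape) should not be
// auto-clustered either.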
TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output dummy_input =
      ops::Placeholder(root.WithOpName("test/dummy_input"), DT_INT64);
  Output variant_input =
      ops::Placeholder(root.WithOpName("test/variant_input"), DT_VARIANT);

  // Create one more node so that we don't avoid creating a cluster solely
  // because it would be trivial.
  Output dummy_cast =
      ops::Cast(root.WithOpName("test/dummy_cast"), dummy_input, DT_INT32);

  Output tensor_list_element_shape = ops::TensorListElementShape(
      root.WithOpName("test/tensor_list_element_shape"), variant_input,
      DT_INT32);

  root.graph()->AddControlEdge(dummy_cast.node(),
                               tensor_list_element_shape.node());

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/tensor_list_element_shape"], "");
}

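// Variant-producing ops are expected to be clustered when explicitly placed on
// an XLA device.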
TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);

  Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
  Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);

  Output tensor_list_reserve = ops::TensorListReserve(
      root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(xla_cpu_device);
    }
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/tensor_list_reserve"], "");
}

const char* kCPU0 = "/job:worker/replica:0/task:0/device:CPU:0";
const char* kGPU0 = "/job:worker/replica:0/task:0/device:GPU:0";
const char* kXLA_GPU0 = "/job:worker/replica:0/task:0/device:XLA_GPU:0";
const char* kGPU1 = "/job:worker/replica:0/task:0/device:GPU:1";

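// A CPU-assigned node (test/y) sandwiched between GPU-assigned nodes is
// expected to end up in the same, combined CPU/GPU cluster.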
TEST(XlaCompilationTest, CreateCombinedCpuGpuClusters) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
  Output z = ops::Add(root.WithOpName("test/z"), x, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/x"], "");

  EXPECT_EQ(clusters["test/x"], clusters["test/y"]);
  EXPECT_EQ(clusters["test/y"], clusters["test/z"]);
}

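// Nodes assigned to two different GPU devices are expected to remain
// unclustered.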
TEST(XlaCompilationTest, DontCreateGpu0AndGpu1Clusters) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::Add(root.WithOpName("test/y"), x, x);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU1);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/x"], "");
  EXPECT_EQ(clusters["test/y"], "");
}

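// A CPU-assigned node and an XLA_GPU-assigned node are likewise expected to
// remain unclustered rather than being merged into a combined cluster.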
TEST(XlaCompilationTest, DontCreateCombinedCpuUnknownClusters) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::Add(root.WithOpName("test/y"), x, x);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kXLA_GPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/x"], "");
  EXPECT_EQ(clusters["test/y"], "");
}

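// Resource ops are expected to be clustered when it is safe to do so: here the
// resource read (placed on GPU) and its consumer test/b (placed on CPU) should
// end up in the same cluster.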
TEST(XlaCompilationTest, ClusterResourceOpsWhenSafe) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Node* var_handle;
  Node* resource_read = MakeRead(root, "read", &var_handle);
  Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a);

  string resource_read_name = resource_read->name();
  string var_handle_name = var_handle->name();

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), resource_read_name)
      ->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), var_handle_name)->set_assigned_device_name(kGPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/b"], "");
  EXPECT_EQ(clusters["test/b"], clusters[resource_read_name]);
}

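// The reverse placement (read on CPU, consumer on GPU) is not considered safe,
// so neither node should be clustered.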
TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Node* var_handle;
  Node* resource_read = MakeRead(root, "read", &var_handle);
  Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a);

  string resource_read_name = resource_read->name();
  string var_handle_name = var_handle->name();

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), resource_read_name)
      ->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), var_handle_name)->set_assigned_device_name(kCPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/b"], "");
  EXPECT_EQ(clusters[resource_read_name], "");
}

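// Nodes carrying a _scoped_allocator attribute must be left out of clusters.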
TEST(XlaCompilationTest, DontClusterNodesWithScopedAllocatorAttr) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
  Output z = ops::Add(root.WithOpName("test/z"), x, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);

  std::vector<int> scoped_allocator_value;
  scoped_allocator_value.push_back(0);
  scoped_allocator_value.push_back(155);
  FindNodeByName(graph.get(), "test/z")
      ->AddAttr("_scoped_allocator", scoped_allocator_value);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/z"], "");
}

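// Nodes carrying a _forward_from attribute must also be left out of clusters.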
TEST(XlaCompilationTest, DontClusterNodesWithForwardFromAttr) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
  Output z = ops::Add(root.WithOpName("test/z"), x, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);

  FindNodeByName(graph.get(), "test/z")->AddAttr("_forward_from", 0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/z"], "");
}

// Note: this relies on other implementation details to test the
// specific heuristic we care about here, so other changes might be at fault if
// this test breaks. What we care about is that if a ShapeConsumingOp can be
// connected with a producer or consumer and cannot be clustered with both, it
// should be clustered with the producer.
TEST(XlaCompilationTest, ClusterShapeConsumerWithProducer) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::MatMul(root.WithOpName("test/x"), a, b);
  Output y = ops::Size(root.WithOpName("test/y"), x);
  Output z = ops::Add(root.WithOpName("test/z"), y, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Ensure that the "Size" op can only be clustered with either the producer or
  // consumer by putting them on different devices.
  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU1);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/y"], "");
  EXPECT_EQ(clusters["test/x"], clusters["test/y"]);
  EXPECT_NE(clusters["test/z"], clusters["test/y"]);
}

// Test that ShapeConsuming ops are still fully clustered whenever possible.
TEST(XlaCompilationTest, ClusterShapeConsumerWithProducerAndConsumer) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::MatMul(root.WithOpName("test/x"), a, b);
  Output y = ops::Size(root.WithOpName("test/y"), x);
  Output z = ops::Add(root.WithOpName("test/z"), y, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/y"], "");
  EXPECT_EQ(clusters["test/y"], clusters["test/x"]);
  EXPECT_EQ(clusters["test/y"], clusters["test/z"]);
}

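// Convenience overloads for adding a control edge between two nodes given as
// Operation or Output handles.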
void AddCtrlEdge(const Scope& scope, Operation a, Operation b) {
  scope.graph()->AddControlEdge(a.node(), b.node());
}

void AddCtrlEdge(const Scope& scope, Output a, Operation b) {
  AddCtrlEdge(scope, a.op(), b);
}

void AddCtrlEdge(const Scope& scope, Operation a, Output b) {
  AddCtrlEdge(scope, a, b.op());
}

// Tests that we pick a good clustering for graphs that have an integer
// increment operation control dependent on gradient update operations.
TEST(XlaCompilationTest, IterationIncrementAndGroupDeps) {
  Scope scope = Scope::NewRootScope().ExitOnError();

  Output iter =
      ops::VarHandleOp(scope.WithOpName("iter"), DT_INT64, TensorShape({}));
  Output weights_0 = ops::VarHandleOp(scope.WithOpName("weights_0"), DT_FLOAT,
                                      TensorShape({1000}));
  Output weights_1 = ops::VarHandleOp(scope.WithOpName("weights_1"), DT_FLOAT,
                                      TensorShape({1000}));

  // We update the weights by adding delta to them (to "simulate" a
  // ResourceApplyGradientDescent and similar things).
  Output delta = ops::Placeholder(scope.WithOpName("delta"), DT_FLOAT);

  ops::AssignAddVariableOp increment_op(
      scope.WithOpName("IncrementIteration"), iter,
      ops::Const(scope.WithOpName("one"), static_cast<int64>(1)));

  ops::AssignAddVariableOp weights_0_update_op(
      scope.WithOpName("weights_0_update"), weights_0, delta);
  ops::AssignAddVariableOp weights_1_update_op(
      scope.WithOpName("weights_1_update"), weights_1, delta);

  ops::NoOp group_deps(scope.WithOpName("group_deps"));

  ops::NoOp some_ctrl_input(scope.WithOpName("some_ctrl_input"));

  Output matmul_input =
      ops::Placeholder(scope.WithOpName("matmul_input"), DT_FLOAT);
  Output matmul_0 =
      ops::MatMul(scope.WithOpName("matmul_0"), matmul_input, matmul_input);
  Output matmul_1 =
      ops::MatMul(scope.WithOpName("matmul_1"), matmul_input, matmul_input);

  AddCtrlEdge(scope, increment_op, group_deps);
  AddCtrlEdge(scope, weights_0_update_op, increment_op);
  AddCtrlEdge(scope, weights_1_update_op, increment_op);

  AddCtrlEdge(scope, some_ctrl_input, weights_0_update_op);
  AddCtrlEdge(scope, some_ctrl_input, weights_1_update_op);

  AddCtrlEdge(scope, matmul_0, group_deps);
  AddCtrlEdge(scope, matmul_1, group_deps);

  AddCtrlEdge(scope, weights_0_update_op, matmul_0);
  AddCtrlEdge(scope, weights_1_update_op, matmul_1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["some_ctrl_input"], "");
  EXPECT_EQ(clusters["some_ctrl_input"], clusters["weights_0_update"]);
  EXPECT_EQ(clusters["some_ctrl_input"], clusters["weights_1_update"]);
  EXPECT_EQ(clusters["some_ctrl_input"], clusters["matmul_0"]);
  EXPECT_EQ(clusters["some_ctrl_input"], clusters["matmul_1"]);
}

// Test a pattern where a special Identity node is driving consts in a loop.
// Expect that the Identity node will not go into any clusters.  Note that we
// create an incomplete graph here (e.g., lacking Enter/Exit/NextIteration,
// etc.) just enough to test the pattern, as a complete graph may be too
// cumbersome and unnecessary.
TEST(XlaCompilationTest, DontClusterTheSpecialIdentityDrivingConstsInLoop) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond = ops::Placeholder(root.WithOpName("cond"), DT_BOOL);
  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  Output loop_cond = ops::LoopCond(root.WithOpName("loop_cond"), cond);
  ops::Switch switch_node(root.WithOpName("switch"), value, loop_cond);

  Output identity =
      ops::Identity(root.WithOpName("identity"), switch_node.output_true);
  Output const_node = ops::Const(root.WithOpName("const"), 1.0f);
  root.graph()->AddControlEdge(identity.node(), const_node.node());
  Output tanh0 = ops::Tanh(root.WithOpName("tanh0"), const_node);
  Output tanh1 = ops::Tanh(root.WithOpName("tanh1"), tanh0);
  Output add = ops::Add(root.WithOpName("add"), const_node, tanh1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph,
      MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["identity"], "");
}

TEST(XlaCompilationTest, UnsupportedEnterExitPattern) {
  // Regression test for b/32350199, where the autoclustering code introduced a
  // deadlock in a graph containing a while loop.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
  auto enter_0 = ops::internal::Enter(root.WithOpName("enter_a"), a, "frame");
  auto exit_0 = ops::internal::Exit(root.WithOpName("exit_a"), enter_0);
  auto tanh = ops::Tanh(root.WithOpName("tanh"), exit_0);
  auto enter_1 =
      ops::internal::Enter(root.WithOpName("enter_1"), tanh, "frame");
  auto exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  // Nothing should be compiled.
  EXPECT_EQ(0, clusters.size());
}

namespace {
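// Builds a "Stage" node with the given dtypes that consumes `values`, or
// returns nullptr if the builder options already carry an error.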
Node* MakeStageNode(GraphDefBuilder& builder, string name,
                    std::initializer_list<DataType> dtypes,
                    absl::Span<const ops::NodeOut> values) {
  auto opts = builder.opts()
                  .WithName(std::move(name))
                  .WithAttr("dtypes", std::move(dtypes));
  if (opts.HaveError()) {
    return nullptr;
  }

  NodeBuilder node_builder(name, "Stage", opts.op_registry());
  node_builder.Input(values);
  return opts.FinalizeBuilder(&node_builder);
}
}  // namespace

TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) {
  auto build_staged_graph = [](std::unique_ptr<Graph>* graph) -> Status {
    // Construct a graph as below with two pipeline stages and test that nodes
    // in different stages will not be merged if ClusterScopingPass is on.
    //
    //       b
    //       |
    //       v
    // a -> add0 -> relu0 -> stage
    //
    //             b
    //             |
    //             v
    // unstage -> add1 -> relu1
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("a")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::SourceOp("Const", builder.opts()
                                         .WithName("b")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* unstage = ops::SourceOp(
        "Unstage",
        builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT}));

    Node* add0 = ops::BinaryOp("Add", a, b, builder.opts().WithName("add0"));
    Node* add1 =
        ops::BinaryOp("Add", unstage, b, builder.opts().WithName("add1"));
    Node* relu0 = ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0"));
    ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1"));
    MakeStageNode(builder, "stage", {DT_FLOAT}, {relu0});

    return GraphDefBuilderToGraph(builder, graph->get());
  };

  // All nodes go into the same cluster if ClusterScopingPass is off.
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    TF_ASSERT_OK(build_staged_graph(&graph));

    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
        &graph,
        MarkForCompilationPassTestHelper::Options().WithNoClusterScoping()));

    std::unordered_map<string, string> clusters = GetClusters(*graph);
    EXPECT_EQ(clusters["add0"], clusters["add1"]);
    EXPECT_EQ(clusters["add0"], clusters["relu1"]);
    EXPECT_EQ(clusters["relu0"], clusters["add1"]);
    EXPECT_EQ(clusters["relu0"], clusters["relu1"]);
  }

  // By default, ClusterScopingPass is on and different pipeline stages should
  // not be merged.
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    TF_ASSERT_OK(build_staged_graph(&graph));

    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

    std::unordered_map<string, string> clusters = GetClusters(*graph);
    EXPECT_NE(clusters["add0"], clusters["add1"]);
    EXPECT_NE(clusters["add0"], clusters["relu1"]);
    EXPECT_NE(clusters["relu0"], clusters["add1"]);
    EXPECT_NE(clusters["relu0"], clusters["relu1"]);
  }
}
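
// Consistency check between the XLALite allowlist table, the set of registered
// XLA operations, and the ops known to be intentionally absent from the list.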
TEST(XlaCompilationTest, XLALiteAllowlist) {
  auto* allowlist_table = tensorflow::GetAllowlistTable();
  absl::flat_hash_set<string> hallowlist;
  std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
  absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());

  // Check that all the operations in the table are existing TF operations
  for (auto pair : *allowlist_table) {
    hallowlist.insert(pair.second.begin(), pair.second.end());
    for (auto op : pair.second) {
      ASSERT_TRUE(all_ops.contains(op));
    }
  }

  // Check that all registered XLA operations are in the allowlist
  // table or are known to not be in it.

  absl::flat_hash_set<string> known_not_in_list =
      tensorflow::testing::GetKnownXLAAllowlistOp();
  std::vector<string> unknow_op;
  for (string op : vall_ops) {
    if (!hallowlist.contains(op) && !known_not_in_list.contains(op)) {
      unknow_op.push_back(op);
    }
  }
  EXPECT_TRUE(unknow_op.empty())
      << "Someone added support for new TF operations inside XLA. They must "
         "be included in the XLALite allowlist or denylist:\n"
      << absl::StrJoin(unknow_op, "\n");
}
}  // namespace
}  // namespace tensorflow