• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/memory/memory.h"
20 #include "absl/strings/match.h"
21 #include "tensorflow/cc/framework/ops.h"
22 #include "tensorflow/cc/ops/array_ops.h"
23 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
24 #include "tensorflow/cc/ops/function_ops.h"
25 #include "tensorflow/cc/ops/list_ops.h"
26 #include "tensorflow/cc/ops/resource_variable_ops.h"
27 #include "tensorflow/cc/ops/sendrecv_ops.h"
28 #include "tensorflow/cc/ops/standard_ops.h"
29 #include "tensorflow/compiler/jit/defs.h"
30 #include "tensorflow/compiler/jit/node_matchers.h"
31 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
32 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
33 #include "tensorflow/core/framework/node_def_util.h"
34 #include "tensorflow/core/framework/op.h"
35 #include "tensorflow/core/graph/algorithm.h"
36 #include "tensorflow/core/graph/graph_constructor.h"
37 #include "tensorflow/core/graph/graph_def_builder.h"
38 #include "tensorflow/core/graph/graph_def_builder_util.h"
39 #include "tensorflow/core/lib/core/status_test_util.h"
40 #include "tensorflow/core/platform/test.h"
41 
42 using ::tensorflow::testing::FindNodeByName;
43 
44 namespace tensorflow {
45 namespace {
46 
47 REGISTER_OP("UncompilableNullary").Output("o: float");
48 REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
49 
GetClusters(const Graph & graph)50 std::unordered_map<string, string> GetClusters(const Graph& graph) {
51   std::unordered_map<string, string> ids;
52   for (Node* node : graph.nodes()) {
53     string cluster;
54     if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) {
55       CHECK(!cluster.empty());
56       ids[node->name()] = cluster;
57     }
58   }
59 
60   if (VLOG_IS_ON(2)) {
61     VLOG(2) << "Clusters:";
62     for (const auto& p : ids) {
63       VLOG(2) << " " << p.first << " -> " << p.second;
64     }
65   }
66   return ids;
67 }
68 
GetClusterSets(const Graph & g,std::vector<string> * cluster_names=nullptr)69 absl::flat_hash_map<string, std::vector<string>> GetClusterSets(
70     const Graph& g, std::vector<string>* cluster_names = nullptr) {
71   CHECK(cluster_names == nullptr || cluster_names->empty());
72   absl::flat_hash_map<string, std::vector<string>> cluster_sets;
73   for (const auto& p : GetClusters(g)) {
74     cluster_sets[p.second].push_back(p.first);
75   }
76   for (auto& p : cluster_sets) {
77     if (cluster_names != nullptr) {
78       cluster_names->push_back(p.first);
79     }
80     std::sort(p.second.begin(), p.second.end());
81   }
82   if (cluster_names != nullptr) {
83     std::sort(cluster_names->begin(), cluster_names->end());
84   }
85   return cluster_sets;
86 }
87 
// Chain A -> B -> C -> D -> E -> F in which A and D are uncompilable. The
// compilable runs {B, C} and {E, F} must form two distinct clusters, and the
// uncompilable nodes must stay unclustered.
TEST(XlaCompilationTest, Chains) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    Node* d =
        ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
    Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
    ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_EQ(clusters["E"], clusters["F"]);
  EXPECT_NE(clusters["B"], clusters["E"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}
113 
// A feeds both the uncompilable B and C = MatMul(A, B). Clustering A with C
// would create a cycle through B (A -> B -> C), so nothing may be clustered.
TEST(XlaCompilationTest, UncompilableCycles) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b =
        ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}
134 
// Same shape as UncompilableCycles, but B is a compilable Relu, so A, B and
// C = MatMul(A, B) must all land in a single cluster.
TEST(XlaCompilationTest, CompilableCycles) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(3, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["A"], clusters["C"]);
}
156 
// A graph made only of DT_STRING ops (Const, EncodeBase64, StringSplit). The
// test expects none of them to be clustered.
TEST(XlaCompilationTest, StringUnsupported) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp(
        "Const", builder.opts()
                     .WithName("A")
                     .WithAttr("dtype", DT_STRING)
                     .WithAttr("value", Tensor(DT_STRING, TensorShape())));
    Node* b = ops::UnaryOp("EncodeBase64", a, builder.opts().WithName("B"));
    ops::BinaryOp("StringSplit", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_TRUE(clusters.empty());
}
176 
// A DT_HALF graph (Const, Neg, MatMul). The test only requires that at least
// one node is clustered, i.e. half precision does not block clustering.
TEST(XlaCompilationTest, HalfSupported) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Tensor t(DT_HALF, TensorShape());
    t.scalar<Eigen::half>()() = static_cast<Eigen::half>(0.0f);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_HALF)
                                         .WithAttr("value", t));
    Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_FALSE(clusters.empty());
}
197 
// Calls to library functions. B calls CompilableFn (a single Add) and must be
// clustered with its consumer C. D calls UncompilableFn, and E calls
// NoInlineFn (a copy of CompilableFn marked `_noinline`). Both D and E are
// expected to stay unclustered.
TEST(XlaCompilationTest, FunctionCalls) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
      {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
  FunctionDef uncompilable =
      FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
                                {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
  FunctionDef noinline = compilable;
  noinline.mutable_signature()->set_name("NoInlineFn");
  AddAttr("_noinline", static_cast<bool>(true), noinline.mutable_attr());

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  *flib.add_function() = uncompilable;
  *flib.add_function() = noinline;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
    ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
  EXPECT_TRUE(clusters.find("E") == clusters.cend());
}
239 
240 // Metadata-only operators such as Shape/Rank/Size may not be the root of a
241 // cluster. This is partially to work around b/26800664, and partially because
242 // we should probably prefer to compile metadata operators with their producers
243 // wherever possible, rather than their consumers.
TEST(XlaCompilationTest,MetadataOpsDontStartClusters)244 TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
245   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
246   GraphDef graphdef;
247   {
248     GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
249     Node* a =
250         ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
251     // While all of the following ops are notionally compilable, none is
252     // permitted
253     // to start a cluster. So nothing should be compiled.
254     Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B"));
255     Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C"));
256     Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D"));
257     ops::UnaryOp("Shape", d, builder.opts().WithName("E"));
258     TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
259   }
260   TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
261   auto clusters = GetClusters(*graph);
262   EXPECT_EQ(0, clusters.size());  // Nothing should be compiled.
263 }
264 
GradForUnaryCwise(FunctionDef * g,std::vector<FunctionDefHelper::Node> nodes)265 static Status GradForUnaryCwise(FunctionDef* g,
266                                 std::vector<FunctionDefHelper::Node> nodes) {
267   for (auto& n : nodes) {
268     if (n.attr.empty()) {
269       n.attr = {{"T", DT_FLOAT}};
270     }
271   }
272   *g = FunctionDefHelper::Define(
273       // Arg defs
274       {"x: float", "dy: float"},
275       // Ret val defs
276       {"dx: float"},
277       // Attr defs
278       {},
279       // Nodes
280       nodes);
281   return Status::OK();
282 }
283 
284 // A gradient containing only supported operators
SupportedGrad(const AttrSlice & attrs,FunctionDef * g)285 Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
286   // clang-format off
287   return GradForUnaryCwise(g, {
288       {{"y"}, "Tanh", {"x"}},
289       {{"y2"}, "Square", {"y"}, {}, {"dy"}},
290       FunctionDefHelper::Const("one", 1.0f),
291       {{"a"}, "Sub", {"one", "y2"}},
292       {{"dx"}, "Mul", {"dy", "a"}},
293   });
294   // clang-format on
295 }
296 REGISTER_OP_GRADIENT("Supported", SupportedGrad);
297 
298 // A gradient containing an unsupported operator.
UnsupportedGrad(const AttrSlice & attrs,FunctionDef * g)299 Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
300   // clang-format off
301   return GradForUnaryCwise(g, {
302       {{"y"}, "Tanh", {"x"}},
303       {{"y2"}, "UncompilableUnary", {"y"}, {}, {"dy"}},
304       FunctionDefHelper::Const("one", 1.0f),
305       {{"a"}, "Sub", {"one", "y2"}},
306       {{"dx"}, "Mul", {"dy", "a"}},
307   });
308   // clang-format on
309 }
310 REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad);
311 
// B is a SymbolicGradient of "Supported", whose registered gradient uses only
// compilable ops. D is a SymbolicGradient of "Unsupported", whose gradient
// contains UncompilableUnary. Only B and its consumer C are expected to be
// clustered.
TEST(XlaCompilationTest, SymbolicGradients) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));

    // Builds a Symbolic gradient for Supported
    NodeBuilder b_builder("B", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList b_name_attr;
    b_name_attr.set_name("Supported");
    b_builder.Attr("f", b_name_attr);
    b_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    b_builder.Attr("Tout", {DT_FLOAT});
    b_builder.Input({a, a});
    Node* b = builder.opts().FinalizeBuilder(&b_builder);

    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));

    // Builds a Symbolic gradient for Unsupported
    NodeBuilder d_builder("D", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList d_name_attr;
    d_name_attr.set_name("Unsupported");
    d_builder.Attr("f", d_name_attr);
    d_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    d_builder.Attr("Tout", {DT_FLOAT});
    d_builder.Input({c, c});
    builder.opts().FinalizeBuilder(&d_builder);

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}
356 
// Regression test for b/32350199, where the autoclustering code introduced a
// deadlock in a graph containing a while loop. Here the loop is a bare
// Enter -> NextIteration -> Exit chain that is fed by C and feeds D.
TEST(XlaCompilationTest, Loops) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();

  auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
  auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT);
  auto c = ops::Add(root.WithOpName("C"), a, b);
  auto loop_enter = ops::internal::Enter(root, c, "aframe");
  auto loop_next = ops::NextIteration(root, loop_enter);
  auto loop_exit = ops::internal::Exit(root, loop_next);
  auto d = ops::Add(root.WithOpName("D"), c, loop_exit);
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  // Nothing may be compiled; in particular 'd' and 'c' must stay unclustered.
  EXPECT_EQ(0, clusters.size());
}
379 
TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  FunctionDefLibrary flib;
  FunctionLibraryDefinition flib_def(graph->op_registry(), flib);
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // In this case the global JIT level overrides the scopes: clustering ignores
  // them, and all three nodes end up in one cluster.
  EXPECT_EQ(3, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["A"], clusters["C"]);
}
413 
TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // In this case, we cannot fuse anything, and there are no clusters.
  EXPECT_EQ(0, clusters.size());
}
442 
TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaCompileAttr, true)
                                         .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* b = ops::UnaryOp("Relu", a,
                           builder.opts()
                               .WithName("B")
                               .WithAttr(kXlaCompileAttr, true)
                               .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* c = ops::BinaryOp("MatMul", a, b,
                            builder.opts()
                                .WithName("C")
                                .WithAttr(kXlaCompileAttr, true)
                                .WithAttr(kXlaScopeAttr, "Scope2"));
    ops::BinaryOp("Add", b, c,
                  builder.opts()
                      .WithName("D")
                      .WithAttr(kXlaCompileAttr, true)
                      .WithAttr(kXlaScopeAttr, "Scope2"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
  auto clusters = GetClusters(*graph);

  // The computation is: D = relu(A) + (A @ relu(A))
  // where A and relu(A) are in Scope1, and the @, + ops are in Scope2.
  // In this case, we can fuse the A and relu(A), and we can fuse the
  // second half of the operations; there are two clusters.
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_NE(clusters["A"], clusters["C"]);
  EXPECT_EQ(clusters["C"], clusters["D"]);
}
485 
TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaCompileAttr, true)
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp("Relu", a,
                           builder.opts()
                               .WithName("B")
                               .WithAttr(kXlaCompileAttr, true)
                               .WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA and relu(A) sits in ScopeB, both marked for
  // compilation, while C carries neither a scope nor a compile attribute.
  // A and B are in different scopes, so they must not be fused. C is expected
  // to join B's cluster.
  EXPECT_EQ(3, clusters.size());
  EXPECT_NE(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["B"], clusters["C"]);
}
517 
518 namespace {
MakeRead(const Scope & scope,const string & id,Node ** var_handle_op=nullptr)519 Node* MakeRead(const Scope& scope, const string& id,
520                Node** var_handle_op = nullptr) {
521   Output var_handle =
522       ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
523   Output read =
524       ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
525   if (var_handle_op) {
526     *var_handle_op = var_handle.node();
527   }
528   return read.node();
529 }
530 
MakeWrite(const Scope & scope,const string & id)531 Node* MakeWrite(const Scope& scope, const string& id) {
532   Output var_handle =
533       ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
534   Output value_to_write =
535       ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
536   ops::AssignVariableOp assign_op(scope.WithOpName("Assignment" + id),
537                                   var_handle, value_to_write);
538   return assign_op.operation.node();
539 }
540 
MakeNeutral(const Scope & scope,const string & id)541 Node* MakeNeutral(const Scope& scope, const string& id) {
542   return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
543 }
544 }  // namespace
545 
// ReadR is ordered before AssignmentW by a control edge (read, then write).
// The test expects exactly one cluster, holding the read, the assignment and
// the assigned constant.
TEST(XlaCompilationTest, ResourcesClusteringAllowed) {
  Scope root = Scope::NewRootScope().ExitOnError();
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  Node* read_node = MakeRead(root, "R");
  Node* write_node = MakeWrite(root, "W");
  root.graph()->AddControlEdge(read_node, write_node);
  FixupSourceAndSinkEdges(root.graph());

  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph);
  ASSERT_EQ(cluster_sets.size(), 1);
  const std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR",
                                                        "ValueToAssignW"};
  ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
}
565 
// AssignmentW is ordered before ReadR by a control edge (write, then read).
// The test expects no clusters at all.
TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
  Scope root = Scope::NewRootScope().ExitOnError();
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  Node* read_node = MakeRead(root, "R");
  Node* write_node = MakeWrite(root, "W");
  root.graph()->AddControlEdge(write_node, read_node);
  FixupSourceAndSinkEdges(root.graph());

  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph);
  ASSERT_EQ(cluster_sets.size(), 0);
}
582 
// A control-edge chain W0 -> N0 -> R0 -> W1 -> N1 -> R1 of writes, neutral
// constants and reads. The test expects a single cluster containing exactly
// AssignmentW1, ConstN1, ReadR0 and ValueToAssignW1.
TEST(XlaCompilationTest, ChainOfOps) {
  Scope root = Scope::NewRootScope().ExitOnError();

  // Nodes in creation order (list-initialization evaluates left to right);
  // each entry gets a control edge to the next one.
  const std::vector<Node*> chain = {
      MakeWrite(root, "W0"), MakeNeutral(root, "N0"), MakeRead(root, "R0"),
      MakeWrite(root, "W1"), MakeNeutral(root, "N1"), MakeRead(root, "R1")};
  for (size_t i = 1; i < chain.size(); ++i) {
    root.graph()->AddControlEdge(chain[i - 1], chain[i]);
  }
  FixupSourceAndSinkEdges(root.graph());

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::vector<string> cluster_names;
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph, &cluster_names);

  ASSERT_EQ(cluster_sets.size(), 1);

  const std::vector<string> expected_clustered_nodes_a = {
      "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"};
  ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
}
614 
// Builds a control-edge cycle a -> b -> c -> a out of NoOps. The pass must
// fail, and its error message must name the offending edge and draw the cycle.
TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();

  // Adds a NoOp named `name` directly to `g`, CHECK-failing on any error.
  auto BuildNoopNode = [](absl::string_view name, Graph* g) -> Node* {
    NodeDefBuilder builder(name, "NoOp");
    NodeDef def;
    TF_CHECK_OK(builder.Finalize(&def));

    Status status;
    Node* node = g->AddNode(def, &status);
    TF_CHECK_OK(status);
    return node;
  };

  Node* node_a = BuildNoopNode("a", graph.get());
  Node* node_b = BuildNoopNode("b", graph.get());
  Node* node_c = BuildNoopNode("c", graph.get());
  graph->AddControlEdge(node_a, node_b);
  graph->AddControlEdge(node_b, node_c);
  graph->AddControlEdge(node_c, node_a);

  TF_EXPECT_OK(root.ToGraph(graph.get()));

  Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
  EXPECT_FALSE(status.ok());
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "Edge from c to a would create a cycle.\n"
                                "+-> a\n"
                                "|   b\n"
                                "+-- c\n"));
}
648 
// A Const -> Relu chain that feeds a _Retval node. The test expects nothing to
// be clustered.
TEST(XlaCompilationTest, Retval) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::UnaryOp("_Retval", b,
                 builder.opts()
                     .WithName("R")
                     .WithAttr("T", DT_FLOAT)
                     .WithAttr("index", 0));

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}
673 
// _Arg -> Identity -> Identity -> _Retval. Identity ops alone must not be
// enough to form a cluster, so nothing is expected to be clustered.
TEST(XlaCompilationTest, DontCountIdentityOps) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();

  auto arg = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto first_identity = ops::Identity(root.WithOpName("B"), arg);
  auto second_identity = ops::Identity(root.WithOpName("C"), first_identity);
  ops::_Retval(root.WithOpName("R"), second_identity, 0);

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}
689 
// A lone Const explicitly marked with kXlaCompileAttr. A float constant is
// expected to form one cluster; a string constant must not be clustered.
TEST(XlaCompilationTest, ConstOp) {
  // Supported data type: float scalar.
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    Scope root = Scope::NewRootScope().ExitOnError();
    auto float_const = ops::Const(root.WithOpName("const"), 0.5f);
    float_const.node()->AddAttr(kXlaCompileAttr, true);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
    EXPECT_EQ(1, GetClusters(*graph).size());
  }

  // Unsupported data type: string scalar.
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    Scope root = Scope::NewRootScope().ExitOnError();
    auto string_const = ops::Const(root.WithOpName("const"), string("string"));
    string_const.node()->AddAttr(kXlaCompileAttr, true);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
    EXPECT_TRUE(GetClusters(*graph).empty());
  }
}
713 
// "read" is an Identity whose input comes from a (ref-typed) Variable op. It
// must stay out of the cluster, while the downstream "negate" and "add" are
// expected to share one cluster, and nothing else may be clustered.
TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output var = ops::Variable(root.WithOpName("variable"),
                             PartialTensorShape{}, DT_FLOAT);
  Output var_read = ops::Identity(root.WithOpName("read"), var);
  Output negated = ops::Negate(root.WithOpName("negate"), var_read);
  Output sum = ops::Add(root.WithOpName("add"), negated, negated);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  ASSERT_FALSE(clusters.empty());
  const string cluster_name = clusters.begin()->second;

  const std::unordered_map<string, string> expected_clusters = {
      {"negate", cluster_name}, {"add", cluster_name}};
  EXPECT_EQ(clusters, expected_clusters);
}
735 
TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
  // An Identity whose input is NOT a ref type is an ordinary compilable op and
  // is expected to be clustered together with its producer ("negate") and its
  // consumer ("add").
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  Output read = ops::Identity(root.WithOpName("read"), variable);
  Output neg = ops::Negate(root.WithOpName("negate"), read);
  // This must be an Identity (previously it was mistakenly a Negate): an
  // Identity fed by the non-ref output of "negate" is the case under test.
  Output identity = ops::Identity(root.WithOpName("identity"), neg);
  Output add = ops::Add(root.WithOpName("add"), identity, neg);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  ASSERT_FALSE(clusters.empty());
  string cluster_name = clusters.begin()->second;

  std::unordered_map<string, string> expected_clusters(
      {{"negate", cluster_name},
       {"identity", cluster_name},
       {"add", cluster_name}});
  EXPECT_EQ(clusters, expected_clusters);
}
760 
TEST(XlaCompilationTest, ClusterControlTrigger) {
  Scope scope = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_Recv(scope.WithOpName("recv_a"), DT_BOOL, "tensor_a",
                             "sender", 0, "receiver");
  Output recv_b = ops::_Recv(scope.WithOpName("recv_b"), DT_BOOL, "tensor_b",
                             "sender", 0, "receiver");
  Output const_a = ops::Const(scope.WithOpName("const_a"), 42);

  ops::ControlTrigger trigger_a(scope.WithOpName("ctrl_trigger_a"));
  ops::ControlTrigger trigger_b(scope.WithOpName("ctrl_trigger_b"));

  // ctrl_trigger_a is control-dependent on both receives; ctrl_trigger_b is a
  // control input of const_a.
  Graph* scope_graph = scope.graph();
  for (Node* recv_node : {recv_a.node(), recv_b.node()}) {
    scope_graph->AddControlEdge(recv_node, trigger_a.operation.node());
  }
  scope_graph->AddControlEdge(trigger_b.operation.node(), const_a.node());

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  // TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so
  // it won't be clustered.  ctrl_trigger_b is okay to cluster but we don't
  // cluster it because of b/118970344.
  EXPECT_TRUE(GetClusters(*graph).empty());
}
788 
TEST(XlaCompilationTest, RandomShape) {
  // The shape argument of a Reshape is produced by a random op
  // (RandomUniformInt). That random node ("shape") is expected to stay out of
  // any cluster.
  Scope scope = Scope::NewRootScope().ExitOnError();
  Output shape_of_shape = ops::Const(scope.WithOpName("shape_shape"), {2}, {1});
  Output random_shape =
      ops::RandomUniformInt(scope.WithOpName("shape"), shape_of_shape,
                            ops::Const(scope.WithOpName("minval"), 1),
                            ops::Const(scope.WithOpName("maxval"), 20));
  Output input =
      ops::Placeholder(scope.WithOpName("reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshaped =
      ops::Reshape(scope.WithOpName("reshape"), input, random_shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["shape"], "");
}
810 
// Like RandomShape, but the random shape is produced inside a function marked
// stateful ("Stateful_func") that is invoked by a call node ("fn_call") whose
// output feeds a Reshape. The call node is expected to stay unclustered.
TEST(XlaCompilationTest, RandomShapeWithFunc) {
  // NOTE(review): shape inference is disabled, presumably because the scope's
  // shape refiner cannot infer the output of the "Stateful_func" call node —
  // confirm.
  Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();

  // Build a function whose single int32 output is the result of
  // RandomUniformInt(shape_shape, minval, maxval).
  FunctionDefLibrary flib_def;
  FunctionDef func = FunctionDefHelper::Create(
      /*function_name=*/"Stateful_func", /*in_def=*/{},
      /*out_def=*/{"out: int32"},
      /*attr_def*/
      {}, /*node_def=*/
      // NOTE(review): "shape_shape" is a scalar 2 here, whereas RandomShape
      // uses the 1-D tensor {2}; probably harmless since shape inference is
      // disabled, but worth confirming.
      {FunctionDefHelper::Const("shape_shape", 2),
       FunctionDefHelper::Const("minval", 1),
       FunctionDefHelper::Const("maxval", 20),
       {{"shape"},
        "RandomUniformInt",
        {"shape_shape:output:0", "minval:output:0", "maxval:output:0"},
        {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}},
      /*ret_def=*/{{"out", "shape:output:0"}});

  // Mark the function stateful, then register it with the scope's graph so a
  // node with op "Stateful_func" can be added below.
  func.mutable_signature()->set_is_stateful(true);
  *flib_def.add_function() = std::move(func);
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
  NodeDef call_node;
  call_node.set_name("fn_call");
  call_node.set_op("Stateful_func");
  Status status;
  Node* call = root.graph()->AddNode(call_node, &status);
  TF_ASSERT_OK(status);

  // Output 0 of the call node supplies the shape for the Reshape.
  Output shape = Output(call, 0);
  Output reshape_input =
      ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  // The pass is given the same function library so it can see the function
  // definition behind "fn_call".
  auto fld = absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(),
                                                          flib_def);
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get()));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["fn_call"], "");
}
856 
TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
  // Even with every "test/" node placed on an XLA_GPU device, neither the
  // random op producing the Reshape's shape nor the Reshape itself is expected
  // to be clustered.
  const string xla_gpu_device = "/job:worker/replica:0/task:0/device:XLA_GPU:0";

  Scope scope = Scope::NewRootScope().ExitOnError();
  Output shape_of_shape =
      ops::Const(scope.WithOpName("test/shape_shape"), {2}, {1});
  Output random_shape =
      ops::RandomUniformInt(scope.WithOpName("test/shape_rng"), shape_of_shape,
                            ops::Const(scope.WithOpName("test/minval"), 1),
                            ops::Const(scope.WithOpName("test/maxval"), 20));
  Output input =
      ops::Placeholder(scope.WithOpName("test/reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshaped =
      ops::Reshape(scope.WithOpName("test/reshape"), input, random_shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  for (Node* node : graph->nodes()) {
    if (!absl::StartsWith(node->name(), /*prefix=*/"test/")) continue;
    node->set_assigned_device_name(xla_gpu_device);
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/shape_rng"], "");
  EXPECT_EQ(clusters["test/reshape"], "");
}
888 
TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
  // A shape read out of a TensorArray on an XLA device is expected to be
  // clustered together with the Reshape that consumes it.
  absl::string_view xla_gpu_device =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1,
                                DT_INT32);
  Output zero = ops::Const(root.WithOpName("test/zero"), 0);
  // The written value is int32 to match the array's dtype and the DT_INT32
  // read below (it used to be the float 42.0f, a dtype mismatch).
  ops::TensorArrayWrite tensor_array_write(
      root.WithOpName("test/write"), tensor_array.handle, zero,
      ops::Const(root.WithOpName("test/forty_two"), 42), tensor_array.flow);
  Output tensor_array_read =
      ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle,
                           zero, tensor_array_write.flow_out, DT_INT32);
  Output reshape =
      ops::Reshape(root.WithOpName("test/reshape"),
                   ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT),
                   tensor_array_read);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Only "test/" nodes go to the XLA device; the placeholder stays unplaced.
  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_gpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/read"], "");
  EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]);
}
921 
TEST(XlaCompilationTest, DontClusterMergingNodes) {
  // MatMulCombined_dev1 below consumes data produced on XLA_GPU:0 and
  // XLA_GPU:1 and is itself placed on XLA_GPU:1. It should nevertheless not
  // be clustered with the preceding MatMul on XLA_GPU:1, because that would
  // serialize production of its inputs, which should be done in parallel.
  //
  // The graph is:
  //   A_dev0 (Const) -> tanh_A_dev0 -> both inputs of MatMul0_dev0
  //   B_dev1 (Const) -> tanh_B_dev1 -> both inputs of MatMul1_dev1
  //   (MatMul0_dev0, MatMul1_dev1) -> MatMulCombined_dev1
  //
  // Placement (by node-name suffix, see the loop below):
  //   XLA_GPU:0: [A_dev0, tanh_A_dev0, MatMul0_dev0]
  //   XLA_GPU:1: [B_dev1, tanh_B_dev1, MatMul1_dev1, MatMulCombined_dev1]
  //
  // Checked expectations:
  //   - A_dev0 shares a cluster with MatMul0_dev0.
  //   - B_dev1 shares a cluster with MatMul1_dev1.
  //   - MatMul0_dev0, MatMul1_dev1 and MatMulCombined_dev1 all have pairwise
  //     different cluster assignments.
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  absl::string_view xla_gpu_dev1 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:1";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
                       ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
  Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
                       ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);

  Output combined =
      ops::MatMul(root.WithOpName("MatMulCombined_dev1"), matmul0, matmul1);
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Assign devices from the "dev0" / "dev1" suffix of each node's name.
  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  // No two of the three MatMuls may share a cluster assignment; each
  // constant must sit in the cluster of the MatMul on its own device.
  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul0_dev0"]);
  EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul1_dev1"]);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]);
  EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
}
973 
// TODO(b/117085735): This form of clustering should be prevented.
TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) {
  // MatMulSource_dev0 below produces data for nodes on XLA_GPU:0 and XLA_GPU:1
  // and is placed on XLA_GPU:0. Ideally it should not be clustered with the
  // next node on XLA_GPU:0, because that would prevent the node on XLA_GPU:1
  // from beginning its work as soon as the data has been produced.
  //
  // The graph is:
  //   A_dev0 (Const) -> both inputs of MatMulSource_dev0
  //   MatMulSource_dev0 -> both inputs of MatMul0_dev0 and of MatMul1_dev1
  //
  // Placement (by node-name suffix, see the loop below):
  //   XLA_GPU:0: [A_dev0, MatMulSource_dev0, MatMul0_dev0]
  //   XLA_GPU:1: [MatMul1_dev1]
  //
  // Desired clustering (not what the pass does today):
  //   Cluster0: [A_dev0, MatMulSource_dev0]
  //   Cluster1: [MatMul0_dev0]
  //   Cluster2: [MatMul1_dev1]
  // The last EXPECT_EQ below pins the current, undesired behavior in which
  // MatMulSource_dev0 and MatMul0_dev0 share a cluster.
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  absl::string_view xla_gpu_dev1 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:1";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2});
  Output matmul_source =
      ops::MatMul(root.WithOpName("MatMulSource_dev0"), a, a);

  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), matmul_source,
                               matmul_source);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), matmul_source,
                               matmul_source);

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  // Assign devices from the "dev0" / "dev1" suffix of each node's name.
  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMulSource_dev0"]);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulSource_dev0"], clusters["MatMul1_dev1"]);

  // Improved heuristics should probably prevent this (see the TODO above the
  // test): ideally these two would land in different clusters.
  EXPECT_EQ(clusters["MatMulSource_dev0"], clusters["MatMul0_dev0"]);
}
1024 
TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) {
  // Once placed on an XLA_CPU device, the RandomUniform ops and the Add that
  // combines them are all expected to be clustered.
  const string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";

  Scope scope = Scope::NewRootScope().ExitOnError();
  Output dims = ops::Const(scope.WithOpName("test/shape_shape"), {200, 200});
  Output rand_a = ops::RandomUniform(scope.WithOpName("test/a"), dims, DT_FLOAT);
  Output rand_b = ops::RandomUniform(scope.WithOpName("test/b"), dims, DT_FLOAT);
  Output total = ops::Add(scope.WithOpName("test/c"), rand_a, rand_b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  for (Node* node : graph->nodes()) {
    if (!absl::StartsWith(node->name(), /*prefix=*/"test/")) continue;
    node->set_assigned_device_name(xla_cpu_device);
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/a"], "");
  EXPECT_NE(clusters["test/b"], "");
  EXPECT_NE(clusters["test/c"], "");
}
1050 
TEST(XlaCompilationTest, DontAutoClusterStatefulRandomOp) {
  // Without an explicit XLA device assignment, the RandomUniform ops are
  // expected to be left out of auto-clustering.
  Scope scope = Scope::NewRootScope().ExitOnError();
  Output dims = ops::Const(scope.WithOpName("test/shape_shape"), {200, 200});
  Output rand_a = ops::RandomUniform(scope.WithOpName("test/a"), dims, DT_FLOAT);
  Output rand_b = ops::RandomUniform(scope.WithOpName("test/b"), dims, DT_FLOAT);
  Output total = ops::Add(scope.WithOpName("test/c"), rand_a, rand_b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/a"], "");
  EXPECT_EQ(clusters["test/b"], "");
}
1067 
TEST(XlaCompilationTest, ClusterDummyOpsOnXlaDevice) {
  // On an XLA_CPU device, CheckNumerics, GreaterEqual and Assert are all
  // expected to be clustered.
  const string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";

  Scope scope = Scope::NewRootScope().ExitOnError();
  Output lhs = ops::Placeholder(scope.WithOpName("test/a"), DT_FLOAT);
  Output rhs = ops::Placeholder(scope.WithOpName("test/b"), DT_FLOAT);
  Output checked =
      ops::CheckNumerics(scope.WithOpName("test/check"), lhs, "test/check");
  Output is_ge =
      ops::GreaterEqual(scope.WithOpName("test/greaterequal"), checked, rhs);
  // Named assert_op rather than `assert` to avoid clashing with the macro.
  Operation assert_op =
      ops::Assert(scope.WithOpName("test/assert"), is_ge, {lhs, rhs});

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  for (Node* node : graph->nodes()) {
    if (!absl::StartsWith(node->name(), /*prefix=*/"test/")) continue;
    node->set_assigned_device_name(xla_cpu_device);
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/check"], "");
  EXPECT_NE(clusters["test/greaterequal"], "");
  EXPECT_NE(clusters["test/assert"], "");
}
1095 
TEST(XlaCompilationTest, DontAutoClusterDummyOps) {
  // Without an explicit XLA device, Assert and CheckNumerics are expected to
  // stay out of auto-clustering.
  Scope scope = Scope::NewRootScope().ExitOnError();
  Output lhs = ops::Placeholder(scope.WithOpName("test/a"), DT_FLOAT);
  Output rhs = ops::Placeholder(scope.WithOpName("test/b"), DT_FLOAT);
  Output checked =
      ops::CheckNumerics(scope.WithOpName("test/check"), lhs, "test/check");
  Output is_ge =
      ops::GreaterEqual(scope.WithOpName("test/greaterequal"), checked, rhs);
  // Named assert_op rather than `assert` to avoid clashing with the macro.
  Operation assert_op =
      ops::Assert(scope.WithOpName("test/assert"), is_ge, {lhs, rhs});

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/assert"], "");
  EXPECT_EQ(clusters["test/check"], "");
}
1114 
TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) {
  // TensorListReserve produces a DT_VARIANT output; without an XLA device
  // assignment it is expected to stay out of auto-clustering.
  Scope scope = Scope::NewRootScope().ExitOnError();
  Output elem_count_i64 = ops::Placeholder(scope.WithOpName("test/a"), DT_INT64);
  Output elem_shape_i64 = ops::Placeholder(scope.WithOpName("test/b"), DT_INT64);

  Output elem_count =
      ops::Cast(scope.WithOpName("test/cast_a"), elem_count_i64, DT_INT32);
  Output elem_shape =
      ops::Cast(scope.WithOpName("test/cast_b"), elem_shape_i64, DT_INT32);

  Output reserved =
      ops::TensorListReserve(scope.WithOpName("test/tensor_list_reserve"),
                             elem_count, elem_shape, DT_FLOAT);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/tensor_list_reserve"], "");
}
1134 
TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) {
  // TensorListElementShape consumes a DT_VARIANT input; without an XLA device
  // assignment it is expected to stay out of auto-clustering.
  Scope scope = Scope::NewRootScope().ExitOnError();
  Output int64_in =
      ops::Placeholder(scope.WithOpName("test/dummy_input"), DT_INT64);
  Output variant_in =
      ops::Placeholder(scope.WithOpName("test/variant_input"), DT_VARIANT);

  // An extra compilable node, so that a missing cluster cannot be explained
  // merely by the candidate cluster being trivially small.
  Output casted =
      ops::Cast(scope.WithOpName("test/dummy_cast"), int64_in, DT_INT32);

  Output element_shape = ops::TensorListElementShape(
      scope.WithOpName("test/tensor_list_element_shape"), variant_in,
      DT_INT32);

  scope.graph()->AddControlEdge(casted.node(), element_shape.node());

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/tensor_list_element_shape"], "");
}
1162 
TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) {
  // Same graph as DontAutoClusterOpsProducingVariant, but every "test/" node
  // is placed on XLA_CPU; now TensorListReserve is expected to be clustered.
  Scope scope = Scope::NewRootScope().ExitOnError();
  Output elem_count_i64 = ops::Placeholder(scope.WithOpName("test/a"), DT_INT64);
  Output elem_shape_i64 = ops::Placeholder(scope.WithOpName("test/b"), DT_INT64);

  Output elem_count =
      ops::Cast(scope.WithOpName("test/cast_a"), elem_count_i64, DT_INT32);
  Output elem_shape =
      ops::Cast(scope.WithOpName("test/cast_b"), elem_shape_i64, DT_INT32);

  Output reserved =
      ops::TensorListReserve(scope.WithOpName("test/tensor_list_reserve"),
                             elem_count, elem_shape, DT_FLOAT);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  const string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
  for (Node* node : graph->nodes()) {
    if (!absl::StartsWith(node->name(), /*prefix=*/"test/")) continue;
    node->set_assigned_device_name(xla_cpu_device);
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/tensor_list_reserve"], "");
}
1189 
// Fully qualified device names used to pin nodes in the device-placement tests
// below. Declared as constexpr arrays rather than `const char*` globals so they
// are true compile-time constants whose pointers cannot be reassigned.
constexpr char kCPU0[] = "/job:worker/replica:0/task:0/device:CPU:0";
constexpr char kGPU0[] = "/job:worker/replica:0/task:0/device:GPU:0";
constexpr char kXLA_GPU0[] = "/job:worker/replica:0/task:0/device:XLA_GPU:0";
constexpr char kGPU1[] = "/job:worker/replica:0/task:0/device:GPU:1";
1194 
TEST(XlaCompilationTest, CreateCombinedCpuGpuClusters) {
  // Nodes split between GPU:0 and CPU:0 are expected to be combined into one
  // (non-empty) cluster.
  Scope scope = Scope::NewRootScope().ExitOnError();
  Output lhs = ops::Placeholder(scope.WithOpName("test/a"), DT_FLOAT);
  Output rhs = ops::Placeholder(scope.WithOpName("test/b"), DT_FLOAT);

  Output sum = ops::Add(scope.WithOpName("test/x"), lhs, rhs);
  Output product = ops::MatMul(scope.WithOpName("test/y"), lhs, rhs);
  Output total = ops::Add(scope.WithOpName("test/z"), sum, product);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Node name -> device it is pinned to.
  const std::pair<const char*, const char*> placements[] = {
      {"test/x", kGPU0}, {"test/y", kCPU0}, {"test/z", kGPU0}};
  for (const auto& placement : placements) {
    FindNodeByName(graph.get(), placement.first)
        ->set_assigned_device_name(placement.second);
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/x"], "");

  EXPECT_EQ(clusters["test/x"], clusters["test/y"]);
  EXPECT_EQ(clusters["test/y"], clusters["test/z"]);
}
1220 
TEST(XlaCompilationTest, DontCreateGpu0AndGpu1Clusters) {
  // Two dependent nodes on different GPUs (GPU:0 and GPU:1) are expected not
  // to be clustered at all.
  Scope scope = Scope::NewRootScope().ExitOnError();
  Output lhs = ops::Placeholder(scope.WithOpName("test/a"), DT_FLOAT);
  Output rhs = ops::Placeholder(scope.WithOpName("test/b"), DT_FLOAT);

  Output first = ops::Add(scope.WithOpName("test/x"), lhs, rhs);
  Output second = ops::Add(scope.WithOpName("test/y"), first, first);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  Node* x_node = FindNodeByName(graph.get(), "test/x");
  Node* y_node = FindNodeByName(graph.get(), "test/y");
  x_node->set_assigned_device_name(kGPU0);
  y_node->set_assigned_device_name(kGPU1);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/x"], "");
  EXPECT_EQ(clusters["test/y"], "");
}
1242 
TEST(XlaCompilationTest, DontCreateCombinedCpuUnknownClusters) {
  // One node on CPU:0 feeding one on XLA_GPU:0 is expected to yield no
  // clusters.
  Scope scope = Scope::NewRootScope().ExitOnError();
  Output lhs = ops::Placeholder(scope.WithOpName("test/a"), DT_FLOAT);
  Output rhs = ops::Placeholder(scope.WithOpName("test/b"), DT_FLOAT);

  Output first = ops::Add(scope.WithOpName("test/x"), lhs, rhs);
  Output second = ops::Add(scope.WithOpName("test/y"), first, first);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  Node* x_node = FindNodeByName(graph.get(), "test/x");
  Node* y_node = FindNodeByName(graph.get(), "test/y");
  x_node->set_assigned_device_name(kCPU0);
  y_node->set_assigned_device_name(kXLA_GPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/x"], "");
  EXPECT_EQ(clusters["test/y"], "");
}
1264 
TEST(XlaCompilationTest, ClusterResourceOpsWhenSafe) {
  // The resource read and its variable handle are on GPU:0 while their
  // consumer "test/b" is on CPU:0. The consumer is expected to be clustered
  // together with the read.
  Scope scope = Scope::NewRootScope().ExitOnError();
  Output addend = ops::Placeholder(scope.WithOpName("test/a"), DT_FLOAT);
  Node* handle_node = nullptr;
  Node* read_node = MakeRead(scope, "read", &handle_node);
  Output consumer =
      ops::Add(scope.WithOpName("test/b"), Output(read_node, 0), addend);

  // Remember the names now: read_node/handle_node live in the scope's graph,
  // not in the graph the pass will run on.
  const string read_name = read_node->name();
  const string handle_name = handle_node->name();

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), read_name)->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), handle_name)->set_assigned_device_name(kGPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/b"], "");
  EXPECT_EQ(clusters["test/b"], clusters[read_name]);
}
1290 
TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) {
  // Mirror image of ClusterResourceOpsWhenSafe: the resource read and its
  // variable handle are on CPU:0 and the consumer is on GPU:0. Here neither
  // the consumer nor the read is expected to be clustered.
  Scope scope = Scope::NewRootScope().ExitOnError();
  Output addend = ops::Placeholder(scope.WithOpName("test/a"), DT_FLOAT);
  Node* handle_node = nullptr;
  Node* read_node = MakeRead(scope, "read", &handle_node);
  Output consumer =
      ops::Add(scope.WithOpName("test/b"), Output(read_node, 0), addend);

  // Remember the names now: read_node/handle_node live in the scope's graph,
  // not in the graph the pass will run on.
  const string read_name = read_node->name();
  const string handle_name = handle_node->name();

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), read_name)->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), handle_name)->set_assigned_device_name(kCPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/b"], "");
  EXPECT_EQ(clusters[read_name], "");
}
1316 
1317 }  // namespace
1318 }  // namespace tensorflow
1319