/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_def_builder_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

using ::tensorflow::testing::FindNodeByName;

namespace tensorflow {
namespace {

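// Enable the XLA_CPU / XLA_GPU devices that several of the tests below assign
// nodes to.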
static bool Initialized = [] {
  tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
  return true;
}();

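// Ops with no XLA implementation, used below to represent nodes that can
// never be auto-clustered.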
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");

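// Returns a map from node name to cluster name for every node in `graph` that
// was assigned to an XLA cluster (i.e. carries the kXlaClusterAttr attribute).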
std::unordered_map<string, string> GetClusters(const Graph& graph) {
  std::unordered_map<string, string> ids;
  for (Node* node : graph.nodes()) {
    string cluster;
    if (TryGetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster)) {
      CHECK(!cluster.empty());
      ids[node->name()] = cluster;
    }
  }

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Clusters:";
    for (const auto& p : ids) {
      VLOG(2) << " " << p.first << " -> " << p.second;
    }
  }
  return ids;
}

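// Groups the clustered nodes of `g` by cluster. Each value is the sorted list
// of node names in that cluster; if `cluster_names` is non-null it is filled
// with the sorted cluster names.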
absl::flat_hash_map<string, std::vector<string>> GetClusterSets(
    const Graph& g, std::vector<string>* cluster_names = nullptr) {
  CHECK(cluster_names == nullptr || cluster_names->empty());
  absl::flat_hash_map<string, std::vector<string>> cluster_sets;
  for (const auto& p : GetClusters(g)) {
    cluster_sets[p.second].push_back(p.first);
  }
  for (auto& p : cluster_sets) {
    if (cluster_names != nullptr) {
      cluster_names->push_back(p.first);
    }
    std::sort(p.second.begin(), p.second.end());
  }
  if (cluster_names != nullptr) {
    std::sort(cluster_names->begin(), cluster_names->end());
  }
  return cluster_sets;
}

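// An uncompilable op in the middle of a chain of compilable ops splits the
// chain into separate clusters; the uncompilable nodes themselves stay
// unclustered.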
TEST(XlaCompilationTest, Chains) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    Node* d =
        ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
    Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
    ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_EQ(clusters["E"], clusters["F"]);
  EXPECT_NE(clusters["B"], clusters["E"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

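// Clustering A with C would create a cycle through the uncompilable node B,
// so no cluster is formed at all.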
TEST(XlaCompilationTest, UncompilableCycles) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b =
        ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

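// Same topology as above, but with every op compilable: A, B and C all end up
// in a single cluster.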
TEST(XlaCompilationTest, CompilableCycles) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(3, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["A"], clusters["C"]);
}

TEST(XlaCompilationTest, StringUnsupported) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp(
        "Const", builder.opts()
                     .WithName("A")
                     .WithAttr("dtype", DT_STRING)
                     .WithAttr("value", Tensor(DT_STRING, TensorShape())));
    Node* b = ops::UnaryOp("EncodeBase64", a, builder.opts().WithName("B"));
    ops::BinaryOp("StringSplit", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, HalfSupported) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Tensor t(DT_HALF, TensorShape());
    t.scalar<Eigen::half>()() = static_cast<Eigen::half>(0.0f);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_HALF)
                                         .WithAttr("value", t));
    Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_FALSE(clusters.empty());
}

// Tests that PartitionedCalls are only marked for compilation if every node
// inside the function can be compiled.
TEST(XlaCompilationTest, PartitionedCallUnsupported) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
      {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
  FunctionDef uncompilable =
      FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
                                {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  *flib.add_function() = uncompilable;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);

  NameAttrList b_name_attr;
  b_name_attr.set_name("CompilableFn");
  ops::PartitionedCall b(root.WithOpName("B"), {a, a}, {DT_FLOAT}, b_name_attr);
  NameAttrList c_name_attr;
  c_name_attr.set_name("UncompilableFn");

  ops::PartitionedCall c(root.WithOpName("C"), {a}, {DT_FLOAT}, c_name_attr);
  Output d =
      ops::Add(root.WithOpName("D"), b.output.front(), c.output.front());

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_TRUE(clusters["C"].empty());
  EXPECT_EQ(clusters["B"], clusters["D"]);
}

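// The _noinline function call E is clustered together with C, while the other
// function-call nodes (B, D) and the uncompilable source A stay unclustered.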
TEST(XlaCompilationTest, FunctionCalls) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
      {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
  FunctionDef uncompilable =
      FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
                                {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
  FunctionDef noinline = compilable;
  noinline.mutable_signature()->set_name("NoInlineFn");
  AddAttr("_noinline", static_cast<bool>(true), noinline.mutable_attr());

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  *flib.add_function() = uncompilable;
  *flib.add_function() = noinline;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
    ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["C"].empty());
  EXPECT_EQ(clusters["C"], clusters["E"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("B") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "FnWithResourceOp", {"var:resource", "val:float"}, {"retval:float"}, {},
      {{{"assign_op"},
        "AssignVariableOp",
        {"var", "val"},
        {{"dtype", DT_FLOAT}}},
       {{"retval"}, "Identity", {"val"}, {{"T", DT_FLOAT}}, {"assign_op"}}});

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
    Node* resource =
        ops::SourceOp("VarHandleOp", builder.opts()
                                         .WithName("varhandle")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("shape", TensorShape({})));

    Tensor const_tensor(DT_FLOAT, TensorShape({}));
    const_tensor.scalar<float>()() = 42.0f;
    Node* value = ops::SourceOp("Const", builder.opts()
                                             .WithName("const")
                                             .WithAttr("value", const_tensor)
                                             .WithAttr("dtype", DT_FLOAT));

    Node* call = ops::BinaryOp("FnWithResourceOp", resource, value,
                               builder.opts().WithName("A"));
    Node* tanh0 = ops::UnaryOp("Tanh", call, builder.opts().WithName("tanh0"));
    Node* tanh1 = ops::UnaryOp("Tanh", tanh0, builder.opts().WithName("tanh1"));
    ops::UnaryOp("Tanh", tanh1, builder.opts().WithName("tanh2"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
  testing::FindNodeByName(graph.get(), "A")
      ->set_assigned_device_name(xla_cpu_device);
  testing::FindNodeByName(graph.get(), "tanh0")
      ->set_assigned_device_name(xla_cpu_device);
  testing::FindNodeByName(graph.get(), "tanh1")
      ->set_assigned_device_name(xla_cpu_device);
  testing::FindNodeByName(graph.get(), "tanh2")
      ->set_assigned_device_name(xla_cpu_device);

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_NE(clusters["A"], "");
}

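// Builds the gradient FunctionDef for a unary cwise op out of `nodes`, filling
// in a default {"T", DT_FLOAT} attr for nodes that do not set one.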
static Status GradForUnaryCwise(FunctionDef* g,
                                std::vector<FunctionDefHelper::Node> nodes) {
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", DT_FLOAT}};
    }
  }
  *g = FunctionDefHelper::Define(
      // Arg defs
      {"x: float", "dy: float"},
      // Ret val defs
      {"dx: float"},
      // Attr defs
      {},
      // Nodes
      nodes);
  return Status::OK();
}

// A gradient containing only supported operators
Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      FunctionDefHelper::Const("one", 1.0f),
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Supported", SupportedGrad);

// A gradient containing an unsupported operator.
Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "UncompilableUnary", {"y"}, {}, {"dy"}},
      FunctionDefHelper::Const("one", 1.0f),
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad);

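// A SymbolicGradient node is clustered only when the gradient function it
// references is compilable: B (gradient of "Supported") joins a cluster with
// C, while D (gradient of "Unsupported") stays unclustered.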
TEST(XlaCompilationTest, SymbolicGradients) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));

    // Builds a Symbolic gradient for Supported
    NodeBuilder b_builder("B", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList b_name_attr;
    b_name_attr.set_name("Supported");
    b_builder.Attr("f", b_name_attr);
    b_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    b_builder.Attr("Tout", {DT_FLOAT});
    b_builder.Input({a, a});
    Node* b = builder.opts().FinalizeBuilder(&b_builder);

    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));

    // Builds a Symbolic gradient for Unsupported
    NodeBuilder d_builder("D", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList d_name_attr;
    d_name_attr.set_name("Unsupported");
    d_builder.Attr("f", d_name_attr);
    d_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    d_builder.Attr("Tout", {DT_FLOAT});
    d_builder.Input({c, c});
    builder.opts().FinalizeBuilder(&d_builder);

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

TEST(XlaCompilationTest, Loops) {
  // Regression test for b/32350199, where the autoclustering code introduced a
  // deadlock in a graph containing a while loop.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
  auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT);
  auto c = ops::Add(root.WithOpName("C"), a, b);
  auto enter = ops::internal::Enter(root, c, "aframe");
  auto next_iter = ops::NextIteration(root, enter);
  auto exit = ops::internal::Exit(root, next_iter);
  auto d = ops::Add(root.WithOpName("D"), c, exit);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  // Nothing should be compiled. In particular, 'd' and 'c' must not be
  // compiled.
  EXPECT_EQ(0, clusters.size());
}

TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  FunctionDefLibrary flib;
  FunctionLibraryDefinition flib_def(graph->op_registry(), flib);
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // Because a global jit level is set, the differing scopes are ignored and
  // all three nodes end up in one cluster.
  EXPECT_EQ(3, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["A"], clusters["C"]);
}

TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // Without a global jit level, the differing scopes prevent any fusion, so
  // there are no clusters.
  EXPECT_EQ(0, clusters.size());
}

TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaCompileAttr, true)
                                         .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* b = ops::UnaryOp("Relu", a,
                           builder.opts()
                               .WithName("B")
                               .WithAttr(kXlaCompileAttr, true)
                               .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* c = ops::BinaryOp("MatMul", a, b,
                            builder.opts()
                                .WithName("C")
                                .WithAttr(kXlaCompileAttr, true)
                                .WithAttr(kXlaScopeAttr, "Scope2"));
    ops::BinaryOp("Add", b, c,
                  builder.opts()
                      .WithName("D")
                      .WithAttr(kXlaCompileAttr, true)
                      .WithAttr(kXlaScopeAttr, "Scope2"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
  auto clusters = GetClusters(*graph);

  // The computation is: D = relu(A) + (A @ relu(A))
  // where A and relu(A) are in Scope1, and the @ and + ops are in Scope2.
  // In this case we can fuse A with relu(A), and we can fuse the two Scope2
  // operations, giving two clusters.
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_NE(clusters["A"], clusters["C"]);
  EXPECT_EQ(clusters["C"], clusters["D"]);
}

TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaCompileAttr, true)
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp("Relu", a,
                           builder.opts()
                               .WithName("B")
                               .WithAttr(kXlaCompileAttr, true)
                               .WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph, MarkForCompilationPassTestHelper::Options().WithNoGlobalJit()));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C has no scope.
  // A and B cannot be fused because their scopes differ, but B and C end up
  // in the same cluster.
  EXPECT_EQ(3, clusters.size());
  EXPECT_NE(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["B"], clusters["C"]);
}

TEST(XlaCompilationTest, DontClusterNodesWithMismatchingDeadness) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL);
  Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL);

  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);

  ops::Switch switch_a(root.WithOpName("switch_a"), value, cond_a);
  ops::Switch switch_b(root.WithOpName("switch_b"), value, cond_b);

  Output tanh_a0 = ops::Tanh(root.WithOpName("tan_a0"), switch_a.output_true);
  Output tanh_a1 = ops::Tanh(root.WithOpName("tan_a1"), tanh_a0);

  Output tanh_b0 = ops::Tanh(root.WithOpName("tan_b0"), switch_b.output_true);
  Output tanh_b1 = ops::Tanh(root.WithOpName("tan_b1"), tanh_b0);

  Output add = ops::Add(root.WithOpName("add"), tanh_a1, tanh_b1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph,
      MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
  auto clusters = GetClusters(*graph);

  EXPECT_NE(clusters["tan_a0"], "");
  EXPECT_NE(clusters["tan_a1"], "");
  EXPECT_NE(clusters["tan_b0"], "");
  EXPECT_NE(clusters["tan_b1"], "");

  EXPECT_EQ(clusters["tan_a0"], clusters["tan_a1"]);
  EXPECT_EQ(clusters["tan_b0"], clusters["tan_b1"]);

  EXPECT_NE(clusters["tan_a0"], clusters["tan_b0"]);
}

TEST(XlaCompilationTest, ClusterNodesWithMismatchingInputDeadness) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond_a = ops::Placeholder(root.WithOpName("cond_a"), DT_BOOL);
  Output cond_b = ops::Placeholder(root.WithOpName("cond_b"), DT_BOOL);

  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);

  ops::Switch switch_a(root.WithOpName("switch_a"), value, cond_a);
  ops::Switch switch_b(root.WithOpName("switch_b"), value, cond_b);

  Output add_a = ops::Add(root.WithOpName("add_a"), switch_a.output_true,
                          switch_b.output_true);
  Output add_b = ops::Add(root.WithOpName("add_b"), switch_a.output_true,
                          switch_b.output_true);
  Output add = ops::Add(root.WithOpName("add_c"), add_a, add_b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph,
      MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
  auto clusters = GetClusters(*graph);

  EXPECT_NE(clusters["add_a"], "");
  EXPECT_NE(clusters["add_b"], "");
  EXPECT_NE(clusters["add_c"], "");

  EXPECT_EQ(clusters["add_a"], clusters["add_b"]);
  EXPECT_EQ(clusters["add_b"], clusters["add_c"]);
}

namespace {
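// Creates a fresh resource variable "Var<id>" and returns the ReadVariableOp
// node that reads it; the VarHandleOp node is returned through
// `var_handle_op` if requested.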
Node* MakeRead(const Scope& scope, const string& id,
               Node** var_handle_op = nullptr) {
  Output var_handle =
      ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
  Output read =
      ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
  if (var_handle_op) {
    *var_handle_op = var_handle.node();
  }
  return read.node();
}

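// Creates a fresh resource variable "Var<id>" and returns the AssignVariableOp
// node that writes a constant into it.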
Node* MakeWrite(const Scope& scope, const string& id) {
  Output var_handle =
      ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
  Output value_to_write =
      ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
  ops::AssignVariableOp assign_op(scope.WithOpName("Assignment" + id),
                                  var_handle, value_to_write);
  return assign_op.operation.node();
}

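// Returns a Const node that neither reads nor writes a resource variable.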
Node* MakeNeutral(const Scope& scope, const string& id) {
  return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
}
}  // namespace

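// A variable read ordered before a write (enforced here with a control edge)
// may be placed in the same cluster as the write.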
TEST(XlaCompilationTest, ResourcesClusteringAllowed) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");

  root.graph()->AddControlEdge(read, write);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph);
  ASSERT_EQ(cluster_sets.size(), 1);
  std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR",
                                                  "ValueToAssignW"};
  ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
}

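// The reverse ordering, a write followed by a read, must not be clustered.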
TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");

  root.graph()->AddControlEdge(write, read);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph);
  ASSERT_EQ(cluster_sets.size(), 0);
}

TEST(XlaCompilationTest, ChainOfOps) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* write_0 = MakeWrite(root, "W0");
  Node* neutral_0 = MakeNeutral(root, "N0");
  Node* read_0 = MakeRead(root, "R0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral_1 = MakeNeutral(root, "N1");
  Node* read_1 = MakeRead(root, "R1");

  root.graph()->AddControlEdge(write_0, neutral_0);
  root.graph()->AddControlEdge(neutral_0, read_0);
  root.graph()->AddControlEdge(read_0, write_1);
  root.graph()->AddControlEdge(write_1, neutral_1);
  root.graph()->AddControlEdge(neutral_1, read_1);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::vector<string> cluster_names;
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph, &cluster_names);

  ASSERT_EQ(cluster_sets.size(), 1);

  std::vector<string> expected_clustered_nodes_a = {
      "AssignmentW1", "ConstN0", "ReadR0", "ValueToAssignW1"};
  ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
}

TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();
  {
    auto BuildNoopNode = [](absl::string_view name, Graph* graph) {
      NodeDefBuilder builder(name, "NoOp");
      NodeDef def;
      TF_CHECK_OK(builder.Finalize(&def));

      Status status;
      Node* node = graph->AddNode(def, &status);
      TF_CHECK_OK(status);
      return node;
    };

    Node* a = BuildNoopNode("a", graph.get());
    Node* b = BuildNoopNode("b", graph.get());
    Node* c = BuildNoopNode("c", graph.get());
    graph->AddControlEdge(a, b);
    graph->AddControlEdge(b, c);
    graph->AddControlEdge(c, a);
  }

  TF_EXPECT_OK(root.ToGraph(graph.get()));

  Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
  EXPECT_FALSE(status.ok());
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "Edge from c to a would create a cycle.\n"
                                "+-> a\n"
                                "| b\n"
                                "+-- c\n"));
}

TEST(XlaCompilationTest, Retval) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::UnaryOp("_Retval", b,
                 builder.opts()
                     .WithName("R")
                     .WithAttr("T", DT_FLOAT)
                     .WithAttr("index", 0));

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, DontCountIdentityOps) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();
  {
    auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
    auto b = ops::Identity(root.WithOpName("B"), a);
    auto c = ops::Identity(root.WithOpName("C"), b);
    auto r = ops::_Retval(root.WithOpName("R"), c, 0);
  }
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, ConstOp) {
  // valid data type
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    Scope root = Scope::NewRootScope().ExitOnError();
    auto c = ops::Const(root.WithOpName("const"), 0.5f);
    c.node()->AddAttr(kXlaCompileAttr, true);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
    EXPECT_EQ(1, GetClusters(*graph).size());
  }

  // invalid data type
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    Scope root = Scope::NewRootScope().ExitOnError();
    auto c = ops::Const(root.WithOpName("const"), string("string"));
    c.node()->AddAttr(kXlaCompileAttr, true);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
    EXPECT_TRUE(GetClusters(*graph).empty());
  }
}

TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  Output read = ops::Identity(root.WithOpName("read"), variable);
  Output neg = ops::Negate(root.WithOpName("negate"), read);
  Output add = ops::Add(root.WithOpName("add"), neg, neg);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  ASSERT_FALSE(clusters.empty());
  string cluster_name = clusters.begin()->second;

  std::unordered_map<string, string> expected_clusters(
      {{"negate", cluster_name}, {"add", cluster_name}});
  EXPECT_EQ(clusters, expected_clusters);
}

TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  Output read = ops::Identity(root.WithOpName("read"), variable);
  Output neg = ops::Negate(root.WithOpName("negate"), read);
  Output identity = ops::Identity(root.WithOpName("identity"), neg);
  Output add = ops::Add(root.WithOpName("add"), identity, neg);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  ASSERT_FALSE(clusters.empty());
  string cluster_name = clusters.begin()->second;

  std::unordered_map<string, string> expected_clusters(
      {{"negate", cluster_name},
       {"identity", cluster_name},
       {"add", cluster_name}});
  EXPECT_EQ(clusters, expected_clusters);
}

TEST(XlaCompilationTest, ClusterControlTrigger) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_BOOL, "tensor_a",
                             "sender", 0, "receiver");
  Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_BOOL, "tensor_b",
                             "sender", 0, "receiver");
  Output const_a = ops::Const(root.WithOpName("const_a"), 42);

  ops::ControlTrigger ctrl_trigger_a(root.WithOpName("ctrl_trigger_a"));
  ops::ControlTrigger ctrl_trigger_b(root.WithOpName("ctrl_trigger_b"));
  root.graph()->AddControlEdge(recv_a.node(), ctrl_trigger_a.operation.node());
  root.graph()->AddControlEdge(recv_b.node(), ctrl_trigger_a.operation.node());
  root.graph()->AddControlEdge(ctrl_trigger_b.operation.node(), const_a.node());

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  // TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so
  // it won't be clustered. ctrl_trigger_b is okay to cluster but we don't
  // cluster it because of b/118970344.
  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, RandomShape) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape_shape = ops::Const(root.WithOpName("shape_shape"), {2}, {1});
  Output shape =
      ops::RandomUniformInt(root.WithOpName("shape"), shape_shape,
                            ops::Const(root.WithOpName("minval"), 1),
                            ops::Const(root.WithOpName("maxval"), 20));
  Output reshape_input =
      ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["shape"], "");
}

TEST(XlaCompilationTest, RandomShapeWithFunc) {
  Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();

  FunctionDefLibrary flib_def;
  FunctionDef func = FunctionDefHelper::Create(
      /*function_name=*/"Stateful_func", /*in_def=*/{},
      /*out_def=*/{"out: int32"},
      /*attr_def*/
      {}, /*node_def=*/
      {FunctionDefHelper::Const("shape_shape", 2),
       FunctionDefHelper::Const("minval", 1),
       FunctionDefHelper::Const("maxval", 20),
       {{"shape"},
        "RandomUniformInt",
        {"shape_shape:output:0", "minval:output:0", "maxval:output:0"},
        {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}},
      /*ret_def=*/{{"out", "shape:output:0"}});

  func.mutable_signature()->set_is_stateful(true);
  *flib_def.add_function() = std::move(func);
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
  NodeDef call_node;
  call_node.set_name("fn_call");
  call_node.set_op("Stateful_func");
  Status status;
  Node* call = root.graph()->AddNode(call_node, &status);
  TF_ASSERT_OK(status);

  Output shape = Output(call, 0);
  Output reshape_input =
      ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  auto fld = absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(),
                                                          flib_def);
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get()));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["fn_call"], "");
}

TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
  absl::string_view xla_gpu_device =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape_shape =
      ops::Const(root.WithOpName("test/shape_shape"), {2}, {1});
  Output shape =
      ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape,
                            ops::Const(root.WithOpName("test/minval"), 1),
                            ops::Const(root.WithOpName("test/maxval"), 20));
  Output reshape_input =
      ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_gpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/shape_rng"], "");
  EXPECT_EQ(clusters["test/reshape"], "");
}

TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
  absl::string_view xla_gpu_device =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1,
                                DT_INT32);
  Output zero = ops::Const(root.WithOpName("test/zero"), 0);
  ops::TensorArrayWrite tensor_array_write(
      root.WithOpName("test/write"), tensor_array.handle, zero,
      ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow);
  Output tensor_array_read =
      ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle,
                           zero, tensor_array_write.flow_out, DT_INT32);
  Output reshape =
      ops::Reshape(root.WithOpName("test/reshape"),
                   ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT),
                   tensor_array_read);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_gpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/read"], "");
  EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]);
}

TEST(XlaCompilationTest, DontClusterMergingNodes) {
  // MatMulCombined below takes data from nodes on GPU0 and GPU1 and is placed
  // on GPU1. However, it should not be clustered with the previous node on
  // GPU1, because that will serialize production of its inputs that should be
  // done in parallel.
  //
  // This graph is:
  // (Const0, Const0) -> MatMul0
  // (Const1, Const1) -> MatMul1
  // (MatMul0, MatMul1) -> MatMulCombined
  //
  // Device0: [Const0, Const0, MatMul0]
  // Device1: [Const1, Const1, MatMul1, MatMulCombined]
  //
  // Cluster0: [Const0, Const0, MatMul0]
  // Cluster1: [Const1, Const1, MatMul1]
  // Cluster2: [MatMulCombined]
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  absl::string_view xla_gpu_dev1 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:1";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
                       ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
  Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
                       ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);

  Output combined =
      ops::MatMul(root.WithOpName("MatMulCombined_dev1"), matmul0, matmul1);
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  // Each of the MatMuls should be in a separate cluster.
  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul0_dev0"]);
  EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul1_dev1"]);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]);
  EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
}

TEST(XlaCompilationTest, DontClusterMergingNodesOnCPU) {
  // This is similar to the 'DontClusterMergingNodes' above, except
  // MatMulCombined is placed on the CPU.
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 = "/job:worker/replica:0/task:0/device:GPU:0";
  absl::string_view xla_gpu_dev1 = "/job:worker/replica:0/task:0/device:GPU:1";
  absl::string_view xla_cpu_dev0 = "/job:worker/replica:0/task:0/device:CPU:0";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
                       ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
  Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
                       ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);

  Output combined =
      ops::MatMul(root.WithOpName("MatMulCombined_cpu"), matmul0, matmul1);
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"cpu")) {
      n->set_assigned_device_name(string(xla_cpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  // Each of the MatMuls should be in a separate cluster.
  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul0_dev0"]);
  EXPECT_NE(clusters["MatMulCombined_cpu"], clusters["MatMul1_dev1"]);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]);
  EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
}

// TODO(b/117085735): This form of clustering should be prevented.
TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) {
  // MatMulSource below creates data for nodes on GPU0 and GPU1 and is placed
  // on GPU0. However, it should not be clustered with the next node on GPU0,
  // because that will prevent the node on GPU1 from beginning its work as
  // soon as the data has been produced.
  //
  // This graph is:
  // (Const0, Const0) -> MatMulSource
  // MatMulSource -> (MatMul0, MatMul1)
  //
  // Device0: [Const0, Const1, MatMulSource, MatMul0]
  // Device1: [MatMul1]
  //
  // Cluster0: [Const0, Const1, MatMulSource]
  // Cluster1: [MatMul0]
  // Cluster2: [MatMul1]
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  absl::string_view xla_gpu_dev1 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:1";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2});
  Output matmul_source =
      ops::MatMul(root.WithOpName("MatMulSource_dev0"), a, a);

  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), matmul_source,
                               matmul_source);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), matmul_source,
                               matmul_source);

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMulSource_dev0"]);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulSource_dev0"], clusters["MatMul1_dev1"]);

  // Improved heuristics should probably prevent this clustering.
  EXPECT_EQ(clusters["MatMulSource_dev0"], clusters["MatMul0_dev0"]);
}

TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) {
  absl::string_view xla_cpu_device =
      "/job:worker/replica:0/task:0/device:XLA_CPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
  Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
  Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
  Output c = ops::Add(root.WithOpName("test/c"), a, b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_cpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/a"], "");
  EXPECT_NE(clusters["test/b"], "");
  EXPECT_NE(clusters["test/c"], "");
}

TEST(XlaCompilationTest, DontAutoClusterStatefulRandomOp) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
  Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
  Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
  Output c = ops::Add(root.WithOpName("test/c"), a, b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/a"], "");
  EXPECT_EQ(clusters["test/b"], "");
}

TEST(XlaCompilationTest, ClusterDummyOpsOnXlaDevice) {
  absl::string_view xla_cpu_device =
      "/job:worker/replica:0/task:0/device:XLA_CPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
  Output check =
      ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
  Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
  Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_cpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/check"], "");
  EXPECT_NE(clusters["test/greaterequal"], "");
  EXPECT_NE(clusters["test/assert"], "");
}

TEST(XlaCompilationTest, DontAutoClusterDummyOps) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
  Output check =
      ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
  Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
  Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/assert"], "");
  EXPECT_EQ(clusters["test/check"], "");
}

TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);

  Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
  Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);

  Output tensor_list_reserve = ops::TensorListReserve(
      root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/tensor_list_reserve"], "");
}

TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output dummy_input =
      ops::Placeholder(root.WithOpName("test/dummy_input"), DT_INT64);
  Output variant_input =
      ops::Placeholder(root.WithOpName("test/variant_input"), DT_VARIANT);

  // Create one more node so that we don't avoid creating a cluster solely
  // because it would be trivial.
  Output dummy_cast =
      ops::Cast(root.WithOpName("test/dummy_cast"), dummy_input, DT_INT32);

  Output tensor_list_element_shape = ops::TensorListElementShape(
      root.WithOpName("test/tensor_list_element_shape"), variant_input,
      DT_INT32);

  root.graph()->AddControlEdge(dummy_cast.node(),
                               tensor_list_element_shape.node());

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/tensor_list_element_shape"], "");
}

TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);

  Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
  Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);

  Output tensor_list_reserve = ops::TensorListReserve(
      root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(xla_cpu_device);
    }
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/tensor_list_reserve"], "");
}

1361 const char* kCPU0 = "/job:worker/replica:0/task:0/device:CPU:0";
1362 const char* kGPU0 = "/job:worker/replica:0/task:0/device:GPU:0";
1363 const char* kXLA_GPU0 = "/job:worker/replica:0/task:0/device:XLA_GPU:0";
1364 const char* kGPU1 = "/job:worker/replica:0/task:0/device:GPU:1";
1365
TEST(XlaCompilationTest,CreateCombinedCpuGpuClusters)1366 TEST(XlaCompilationTest, CreateCombinedCpuGpuClusters) {
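  // A CPU-assigned node sandwiched between GPU-assigned nodes should be pulled
  // into the same cluster as its GPU neighbors.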
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
  Output z = ops::Add(root.WithOpName("test/z"), x, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/x"], "");

  EXPECT_EQ(clusters["test/x"], clusters["test/y"]);
  EXPECT_EQ(clusters["test/y"], clusters["test/z"]);
}

TEST(XlaCompilationTest, DontCreateGpu0AndGpu1Clusters) {
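  // Nodes assigned to two different GPUs must not share a cluster; here
  // neither node ends up clustered at all.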
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::Add(root.WithOpName("test/y"), x, x);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU1);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/x"], "");
  EXPECT_EQ(clusters["test/y"], "");
}

TEST(XlaCompilationTest, DontCreateCombinedCpuUnknownClusters) {
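  // A CPU-assigned node must not be combined with a node assigned to an
  // XLA_GPU device; neither node ends up clustered.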
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::Add(root.WithOpName("test/y"), x, x);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kXLA_GPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/x"], "");
  EXPECT_EQ(clusters["test/y"], "");
}

TEST(XlaCompilationTest, ClusterResourceOpsWhenSafe) {
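  // The variable and its read live on the GPU, so clustering the read together
  // with its CPU-assigned consumer is safe; both should share a cluster.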
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Node* var_handle;
  Node* resource_read = MakeRead(root, "read", &var_handle);
  Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a);

  string resource_read_name = resource_read->name();
  string var_handle_name = var_handle->name();

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), resource_read_name)
      ->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), var_handle_name)->set_assigned_device_name(kGPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/b"], "");
  EXPECT_EQ(clusters["test/b"], clusters[resource_read_name]);
}

TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) {
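  // With the variable and its read on the CPU and the consumer on the GPU,
  // clustering the resource op would be unsafe; neither node gets clustered.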
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Node* var_handle;
  Node* resource_read = MakeRead(root, "read", &var_handle);
  Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a);

  string resource_read_name = resource_read->name();
  string var_handle_name = var_handle->name();

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), resource_read_name)
      ->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), var_handle_name)->set_assigned_device_name(kCPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/b"], "");
  EXPECT_EQ(clusters[resource_read_name], "");
}

TEST(XlaCompilationTest, DontClusterNodesWithScopedAllocatorAttr) {
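  // Nodes carrying a _scoped_allocator attribute must be left out of clusters.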
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
  Output z = ops::Add(root.WithOpName("test/z"), x, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);

  std::vector<int> scoped_allocator_value;
  scoped_allocator_value.push_back(0);
  scoped_allocator_value.push_back(155);
  FindNodeByName(graph.get(), "test/z")
      ->AddAttr("_scoped_allocator", scoped_allocator_value);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/z"], "");
}

TEST(XlaCompilationTest, DontClusterNodesWithForwardFromAttr) {
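  // Nodes carrying a _forward_from attribute must be left out of clusters.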
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
  Output z = ops::Add(root.WithOpName("test/z"), x, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);

  FindNodeByName(graph.get(), "test/z")->AddAttr("_forward_from", 0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/z"], "");
}

// Note: this test relies on other implementation details to exercise the
// specific heuristic we care about, so unrelated changes may be at fault if it
// breaks. What we care about is that when a shape-consuming op can be
// clustered with either its producer or its consumer but not with both, it
// ends up clustered with the producer.
TEST(XlaCompilationTest, ClusterShapeConsumerWithProducer) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::MatMul(root.WithOpName("test/x"), a, b);
  Output y = ops::Size(root.WithOpName("test/y"), x);
  Output z = ops::Add(root.WithOpName("test/z"), y, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Ensure that the "Size" op can only be clustered with either the producer
  // or consumer by putting them on different devices.
  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU1);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/y"], "");
  EXPECT_EQ(clusters["test/x"], clusters["test/y"]);
  EXPECT_NE(clusters["test/z"], clusters["test/y"]);
}

// Test that ShapeConsuming ops are still fully clustered whenever possible.
TEST(XlaCompilationTest, ClusterShapeConsumerWithProducerAndConsumer) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::MatMul(root.WithOpName("test/x"), a, b);
  Output y = ops::Size(root.WithOpName("test/y"), x);
  Output z = ops::Add(root.WithOpName("test/z"), y, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/y"], "");
  EXPECT_EQ(clusters["test/y"], clusters["test/x"]);
  EXPECT_EQ(clusters["test/y"], clusters["test/z"]);
}

void AddCtrlEdge(const Scope& scope, Operation a, Operation b) {
  scope.graph()->AddControlEdge(a.node(), b.node());
}

void AddCtrlEdge(const Scope& scope, Output a, Operation b) {
  AddCtrlEdge(scope, a.op(), b);
}

void AddCtrlEdge(const Scope& scope, Operation a, Output b) {
  AddCtrlEdge(scope, a, b.op());
}

// Tests that we pick a good clustering for graphs that have an integer
// increment operation control dependent on gradient update operations.
TEST(XlaCompilationTest, IterationIncrementAndGroupDeps) {
  Scope scope = Scope::NewRootScope().ExitOnError();

  Output iter =
      ops::VarHandleOp(scope.WithOpName("iter"), DT_INT64, TensorShape({}));
  Output weights_0 = ops::VarHandleOp(scope.WithOpName("weights_0"), DT_FLOAT,
                                      TensorShape({1000}));
  Output weights_1 = ops::VarHandleOp(scope.WithOpName("weights_1"), DT_FLOAT,
                                      TensorShape({1000}));

  // We update the weights by adding delta to them (to "simulate" a
  // ResourceApplyGradientDescent and similar things).
  Output delta = ops::Placeholder(scope.WithOpName("delta"), DT_FLOAT);

  ops::AssignAddVariableOp increment_op(
      scope.WithOpName("IncrementIteration"), iter,
      ops::Const(scope.WithOpName("one"), static_cast<int64>(1)));

  ops::AssignAddVariableOp weights_0_update_op(
      scope.WithOpName("weights_0_update"), weights_0, delta);
  ops::AssignAddVariableOp weights_1_update_op(
      scope.WithOpName("weights_1_update"), weights_1, delta);

  ops::NoOp group_deps(scope.WithOpName("group_deps"));

  ops::NoOp some_ctrl_input(scope.WithOpName("some_ctrl_input"));

  Output matmul_input =
      ops::Placeholder(scope.WithOpName("matmul_input"), DT_FLOAT);
  Output matmul_0 =
      ops::MatMul(scope.WithOpName("matmul_0"), matmul_input, matmul_input);
  Output matmul_1 =
      ops::MatMul(scope.WithOpName("matmul_1"), matmul_input, matmul_input);

  AddCtrlEdge(scope, increment_op, group_deps);
  AddCtrlEdge(scope, weights_0_update_op, increment_op);
  AddCtrlEdge(scope, weights_1_update_op, increment_op);

  AddCtrlEdge(scope, some_ctrl_input, weights_0_update_op);
  AddCtrlEdge(scope, some_ctrl_input, weights_1_update_op);

  AddCtrlEdge(scope, matmul_0, group_deps);
  AddCtrlEdge(scope, matmul_1, group_deps);

  AddCtrlEdge(scope, weights_0_update_op, matmul_0);
  AddCtrlEdge(scope, weights_1_update_op, matmul_1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["some_ctrl_input"], "");
  EXPECT_EQ(clusters["some_ctrl_input"], clusters["weights_0_update"]);
  EXPECT_EQ(clusters["some_ctrl_input"], clusters["weights_1_update"]);
  EXPECT_EQ(clusters["some_ctrl_input"], clusters["matmul_0"]);
  EXPECT_EQ(clusters["some_ctrl_input"], clusters["matmul_1"]);
}

// Test a pattern where a special Identity node is driving consts in a loop.
// Expect that the Identity node will not go into any clusters. Note that we
// create an incomplete graph here (e.g., lacking Enter/Exit/NextIteration,
// etc.) just enough to test the pattern, as a complete graph may be too
// cumbersome and unnecessary.
TEST(XlaCompilationTest, DontClusterTheSpecialIdentityDrivingConstsInLoop) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output cond = ops::Placeholder(root.WithOpName("cond"), DT_BOOL);
  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  Output loop_cond = ops::LoopCond(root.WithOpName("loop_cond"), cond);
  ops::Switch switch_node(root.WithOpName("switch"), value, loop_cond);

  Output identity =
      ops::Identity(root.WithOpName("identity"), switch_node.output_true);
  Output const_node = ops::Const(root.WithOpName("const"), 1.0f);
  root.graph()->AddControlEdge(identity.node(), const_node.node());
  Output tanh0 = ops::Tanh(root.WithOpName("tanh0"), const_node);
  Output tanh1 = ops::Tanh(root.WithOpName("tanh1"), tanh0);
  Output add = ops::Add(root.WithOpName("add"), const_node, tanh1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &graph,
      MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["identity"], "");
}

TEST(XlaCompilationTest, UnsupportedEnterExitPattern) {
  // Regression test for b/32350199, where the autoclustering code introduced a
  // deadlock in a graph containing a while loop.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
  auto enter_0 = ops::internal::Enter(root.WithOpName("enter_a"), a, "frame");
  auto exit_0 = ops::internal::Exit(root.WithOpName("exit_a"), enter_0);
  auto tanh = ops::Tanh(root.WithOpName("tanh"), exit_0);
  auto enter_1 =
      ops::internal::Enter(root.WithOpName("enter_1"), tanh, "frame");
  auto exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  // Nothing should be compiled.
  EXPECT_EQ(0, clusters.size());
}

namespace {
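// Builds a "Stage" node named `name` with the given "dtypes" attr and input
// values. Returns nullptr if the node options already carry an error.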
Node* MakeStageNode(GraphDefBuilder& builder, string name,
                    std::initializer_list<DataType> dtypes,
                    absl::Span<const ops::NodeOut> values) {
  // Note: `name` is deliberately not moved into the options; it is reused
  // below when constructing the NodeBuilder.
  auto opts = builder.opts()
                  .WithName(name)
                  .WithAttr("dtypes", std::move(dtypes));
  if (opts.HaveError()) {
    return nullptr;
  }

  NodeBuilder node_builder(name, "Stage", opts.op_registry());
  node_builder.Input(values);
  return opts.FinalizeBuilder(&node_builder);
}
}  // namespace

TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) {
  auto build_staged_graph = [](std::unique_ptr<Graph>* graph) -> Status {
    // Construct a graph as below with two pipeline stages and test that nodes
    // in different stages will not be merged if ClusterScopingPass is on.
    //
    //      b
    //      |
    //      v
    // a -> add0 -> relu0 -> stage
    //
    //            b
    //            |
    //            v
    // unstage -> add1 -> relu1
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("a")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::SourceOp("Const", builder.opts()
                                         .WithName("b")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* unstage = ops::SourceOp(
        "Unstage",
        builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT}));

    Node* add0 = ops::BinaryOp("Add", a, b, builder.opts().WithName("add0"));
    Node* add1 =
        ops::BinaryOp("Add", unstage, b, builder.opts().WithName("add1"));
    Node* relu0 = ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0"));
    ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1"));
    MakeStageNode(builder, "stage", {DT_FLOAT}, {relu0});

    return GraphDefBuilderToGraph(builder, graph->get());
  };

  // All nodes go into the same cluster if ClusterScopingPass is off.
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    TF_ASSERT_OK(build_staged_graph(&graph));

    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
        &graph,
        MarkForCompilationPassTestHelper::Options().WithNoClusterScoping()));

    std::unordered_map<string, string> clusters = GetClusters(*graph);
    EXPECT_EQ(clusters["add0"], clusters["add1"]);
    EXPECT_EQ(clusters["add0"], clusters["relu1"]);
    EXPECT_EQ(clusters["relu0"], clusters["add1"]);
    EXPECT_EQ(clusters["relu0"], clusters["relu1"]);
  }

  // By default, ClusterScopingPass is on and different pipeline stages should
  // not be merged.
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    TF_ASSERT_OK(build_staged_graph(&graph));

    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

    std::unordered_map<string, string> clusters = GetClusters(*graph);
    EXPECT_NE(clusters["add0"], clusters["add1"]);
    EXPECT_NE(clusters["add0"], clusters["relu1"]);
    EXPECT_NE(clusters["relu0"], clusters["add1"]);
    EXPECT_NE(clusters["relu0"], clusters["relu1"]);
  }
}

TEST(XlaCompilationTest, XLALiteAllowlist) {
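  // Every op in the allowlist table must be a registered TF op, and every
  // registered XLA op must either appear in the table or be explicitly listed
  // as known not to be in it.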
  auto* allowlist_table = tensorflow::GetAllowlistTable();
  absl::flat_hash_set<string> allowlist;
  std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
  absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());

  // Check that all the operations in the table are existing TF operations.
  for (const auto& pair : *allowlist_table) {
    allowlist.insert(pair.second.begin(), pair.second.end());
    for (const auto& op : pair.second) {
      ASSERT_TRUE(all_ops.contains(op));
    }
  }

  // Check that all registered XLA operations are in the allowlist table or are
  // known not to be in it.
  absl::flat_hash_set<string> known_not_in_list =
      tensorflow::testing::GetKnownXLAAllowlistOp();
  std::vector<string> unknown_ops;
  for (const string& op : vall_ops) {
    if (!allowlist.contains(op) && !known_not_in_list.contains(op)) {
      unknown_ops.push_back(op);
    }
  }
  EXPECT_TRUE(unknown_ops.empty())
      << "Someone added support for new TF operations inside XLA. They must "
         "be included in the XLALite allowlist or denylist:\n"
      << absl::StrJoin(unknown_ops, "\n");
}
}  // namespace
}  // namespace tensorflow