/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

using ::tensorflow::testing::FindNodeByName;

namespace tensorflow {
namespace {

REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");

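// Returns a map from node name to the name of the cluster that the
// mark-for-compilation pass assigned to it (via the kXlaClusterAttr
// attribute). Nodes that were not placed in any cluster are omitted.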
std::unordered_map<string, string> GetClusters(const Graph& graph) {
  std::unordered_map<string, string> ids;
  for (Node* node : graph.nodes()) {
    string cluster;
    if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) {
      CHECK(!cluster.empty());
      ids[node->name()] = cluster;
    }
  }

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Clusters:";
    for (const auto& p : ids) {
      VLOG(2) << " " << p.first << " -> " << p.second;
    }
  }
  return ids;
}

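// Groups the clustered nodes by cluster: returns a map from cluster name to
// the sorted list of node names in that cluster. If `cluster_names` is
// non-null, it is filled with the sorted list of cluster names.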
absl::flat_hash_map<string, std::vector<string>> GetClusterSets(
    const Graph& g, std::vector<string>* cluster_names = nullptr) {
  CHECK(cluster_names == nullptr || cluster_names->empty());
  absl::flat_hash_map<string, std::vector<string>> cluster_sets;
  for (const auto& p : GetClusters(g)) {
    cluster_sets[p.second].push_back(p.first);
  }
  for (auto& p : cluster_sets) {
    if (cluster_names != nullptr) {
      cluster_names->push_back(p.first);
    }
    std::sort(p.second.begin(), p.second.end());
  }
  if (cluster_names != nullptr) {
    std::sort(cluster_names->begin(), cluster_names->end());
  }
  return cluster_sets;
}

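// A chain of unary ops that mixes compilable (Relu) and uncompilable ops:
// the compilable runs on either side of an uncompilable node should end up
// in separate clusters, and the uncompilable nodes in none.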
TEST(XlaCompilationTest, Chains) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    Node* d =
        ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
    Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
    ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_EQ(clusters["E"], clusters["F"]);
  EXPECT_NE(clusters["B"], clusters["E"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

TEST(XlaCompilationTest, UncompilableCycles) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b =
        ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, CompilableCycles) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(3, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["A"], clusters["C"]);
}

TEST(XlaCompilationTest, StringUnsupported) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp(
        "Const", builder.opts()
                     .WithName("A")
                     .WithAttr("dtype", DT_STRING)
                     .WithAttr("value", Tensor(DT_STRING, TensorShape())));
    Node* b = ops::UnaryOp("EncodeBase64", a, builder.opts().WithName("B"));
    ops::BinaryOp("StringSplit", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, HalfSupported) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Tensor t(DT_HALF, TensorShape());
    t.scalar<Eigen::half>()() = static_cast<Eigen::half>(0.0f);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_HALF)
                                         .WithAttr("value", t));
    Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_FALSE(clusters.empty());
}

TEST(XlaCompilationTest, FunctionCalls) {
  FunctionDef compilable = FunctionDefHelper::Define(
      "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
      {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
  FunctionDef uncompilable =
      FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
                                {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
  FunctionDef noinline = compilable;
  noinline.mutable_signature()->set_name("NoInlineFn");
  AddAttr("_noinline", static_cast<bool>(true), noinline.mutable_attr());

  FunctionDefLibrary flib;
  *flib.add_function() = compilable;
  *flib.add_function() = uncompilable;
  *flib.add_function() = noinline;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
    ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
    ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
  EXPECT_TRUE(clusters.find("E") == clusters.cend());
}

// Metadata-only operators such as Shape/Rank/Size may not be the root of a
// cluster. This is partially to work around b/26800664, and partially because
// we should probably prefer to compile metadata operators with their producers
// wherever possible, rather than their consumers.
TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
    // While all of the following ops are notionally compilable, none is
    // permitted to start a cluster. So nothing should be compiled.
    Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B"));
    Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C"));
    Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D"));
    ops::UnaryOp("Shape", d, builder.opts().WithName("E"));
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);
  EXPECT_EQ(0, clusters.size());  // Nothing should be compiled.
}

static Status GradForUnaryCwise(FunctionDef* g,
                                std::vector<FunctionDefHelper::Node> nodes) {
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", DT_FLOAT}};
    }
  }
  *g = FunctionDefHelper::Define(
      // Arg defs
      {"x: float", "dy: float"},
      // Ret val defs
      {"dx: float"},
      // Attr defs
      {},
      // Nodes
      nodes);
  return Status::OK();
}

// A gradient containing only supported operators
Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      FunctionDefHelper::Const("one", 1.0f),
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Supported", SupportedGrad);

// A gradient containing an unsupported operator.
Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "UncompilableUnary", {"y"}, {}, {"dy"}},
      FunctionDefHelper::Const("one", 1.0f),
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad);

TEST(XlaCompilationTest, SymbolicGradients) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a =
        ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));

    // Builds a Symbolic gradient for Supported
    NodeBuilder b_builder("B", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList b_name_attr;
    b_name_attr.set_name("Supported");
    b_builder.Attr("f", b_name_attr);
    b_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    b_builder.Attr("Tout", {DT_FLOAT});
    b_builder.Input({a, a});
    Node* b = builder.opts().FinalizeBuilder(&b_builder);

    Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));

    // Builds a Symbolic gradient for Unsupported
    NodeBuilder d_builder("D", "SymbolicGradient",
                          builder.opts().op_registry());
    NameAttrList d_name_attr;
    d_name_attr.set_name("Unsupported");
    d_builder.Attr("f", d_name_attr);
    d_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
    d_builder.Attr("Tout", {DT_FLOAT});
    d_builder.Input({c, c});
    builder.opts().FinalizeBuilder(&d_builder);

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_EQ(2, clusters.size());
  EXPECT_FALSE(clusters["B"].empty());
  EXPECT_EQ(clusters["B"], clusters["C"]);
  EXPECT_TRUE(clusters.find("A") == clusters.cend());
  EXPECT_TRUE(clusters.find("D") == clusters.cend());
}

TEST(XlaCompilationTest, Loops) {
  // Regression test for b/32350199, where the autoclustering code introduced a
  // deadlock in a graph containing a while loop.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
  auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT);
  auto c = ops::Add(root.WithOpName("C"), a, b);
  auto enter = ops::internal::Enter(root, c, "aframe");
  auto next_iter = ops::NextIteration(root, enter);
  auto exit = ops::internal::Exit(root, next_iter);
  auto d = ops::Add(root.WithOpName("D"), c, exit);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  // Nothing should be compiled. In particular, 'd' and 'c' must not be
  // compiled.
  EXPECT_EQ(0, clusters.size());
}

TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  FunctionDefLibrary flib;
  FunctionLibraryDefinition flib_def(graph->op_registry(), flib);
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // In this case, the GlobalJitLevel setting overrides the user-specified
  // scopes, so all three nodes end up in the same cluster.
  EXPECT_EQ(3, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["A"], clusters["C"]);
}

TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp(
        "Relu", a,
        builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp(
        "MatMul", a, b,
        builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC.
  // In this case, we cannot fuse anything, and there are no clusters.
  EXPECT_EQ(0, clusters.size());
}

TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaCompileAttr, true)
                                         .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* b = ops::UnaryOp("Relu", a,
                           builder.opts()
                               .WithName("B")
                               .WithAttr(kXlaCompileAttr, true)
                               .WithAttr(kXlaScopeAttr, "Scope1"));
    Node* c = ops::BinaryOp("MatMul", a, b,
                            builder.opts()
                                .WithName("C")
                                .WithAttr(kXlaCompileAttr, true)
                                .WithAttr(kXlaScopeAttr, "Scope2"));
    ops::BinaryOp("Add", b, c,
                  builder.opts()
                      .WithName("D")
                      .WithAttr(kXlaCompileAttr, true)
                      .WithAttr(kXlaScopeAttr, "Scope2"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
  auto clusters = GetClusters(*graph);

  // The computation is: D = relu(A) + (A @ relu(A))
  // where A and relu(A) are in Scope1, and the @, + ops are in Scope2.
  // In this case, we can fuse the A and relu(A), and we can fuse the
  // second half of the operations; there are two clusters.
  EXPECT_EQ(4, clusters.size());
  EXPECT_EQ(clusters["A"], clusters["B"]);
  EXPECT_NE(clusters["A"], clusters["C"]);
  EXPECT_EQ(clusters["C"], clusters["D"]);
}

TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor())
                                         .WithAttr(kXlaCompileAttr, true)
                                         .WithAttr(kXlaScopeAttr, "ScopeA"));
    Node* b = ops::UnaryOp("Relu", a,
                           builder.opts()
                               .WithName("B")
                               .WithAttr(kXlaCompileAttr, true)
                               .WithAttr(kXlaScopeAttr, "ScopeB"));
    ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
    TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, false));
  auto clusters = GetClusters(*graph);

  // The computation is: C = A @ relu(A)
  // where A sits in ScopeA, relu(A) sits in ScopeB, and C has no scope.
  // In this case, A cannot be fused with B, but B and C can be clustered
  // together.
  EXPECT_EQ(3, clusters.size());
  EXPECT_NE(clusters["A"], clusters["B"]);
  EXPECT_EQ(clusters["B"], clusters["C"]);
}

namespace {
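// Helpers for building small resource-variable subgraphs: MakeRead creates a
// VarHandleOp feeding a ReadVariableOp, MakeWrite creates a VarHandleOp fed
// into an AssignVariableOp, and MakeNeutral creates a Const node that touches
// no resources.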
Node* MakeRead(const Scope& scope, const string& id,
               Node** var_handle_op = nullptr) {
  Output var_handle =
      ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
  Output read =
      ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT);
  if (var_handle_op) {
    *var_handle_op = var_handle.node();
  }
  return read.node();
}

Node* MakeWrite(const Scope& scope, const string& id) {
  Output var_handle =
      ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
  Output value_to_write =
      ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
  ops::AssignVariableOp assign_op(scope.WithOpName("Assignment" + id),
                                  var_handle, value_to_write);
  return assign_op.operation.node();
}

Node* MakeNeutral(const Scope& scope, const string& id) {
  return ops::Const(scope.WithOpName("Const" + id), 42.0f).node();
}
}  // namespace

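// A resource read ordered before a write (enforced here by a control edge)
// is allowed to end up in the same cluster as the write.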
TEST(XlaCompilationTest, ResourcesClusteringAllowed) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");

  root.graph()->AddControlEdge(read, write);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph);
  ASSERT_EQ(cluster_sets.size(), 1);
  std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR",
                                                  "ValueToAssignW"};
  ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes);
}

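// Conversely, a resource write ordered before a read must not be clustered
// with it, so no cluster should be formed here.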
TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* read = MakeRead(root, "R");
  Node* write = MakeWrite(root, "W");

  root.graph()->AddControlEdge(write, read);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph);
  ASSERT_EQ(cluster_sets.size(), 0);
}

TEST(XlaCompilationTest, ChainOfOps) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Node* write_0 = MakeWrite(root, "W0");
  Node* neutral_0 = MakeNeutral(root, "N0");
  Node* read_0 = MakeRead(root, "R0");
  Node* write_1 = MakeWrite(root, "W1");
  Node* neutral_1 = MakeNeutral(root, "N1");
  Node* read_1 = MakeRead(root, "R1");

  root.graph()->AddControlEdge(write_0, neutral_0);
  root.graph()->AddControlEdge(neutral_0, read_0);
  root.graph()->AddControlEdge(read_0, write_1);
  root.graph()->AddControlEdge(write_1, neutral_1);
  root.graph()->AddControlEdge(neutral_1, read_1);

  FixupSourceAndSinkEdges(root.graph());
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::vector<string> cluster_names;
  absl::flat_hash_map<string, std::vector<string>> cluster_sets =
      GetClusterSets(*graph, &cluster_names);

  ASSERT_EQ(cluster_sets.size(), 1);

  std::vector<string> expected_clustered_nodes_a = {
      "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"};
  ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a);
}

TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();
  {
    auto BuildNoopNode = [](absl::string_view name, Graph* graph) {
      NodeDefBuilder builder(name, "NoOp");
      NodeDef def;
      TF_CHECK_OK(builder.Finalize(&def));

      Status status;
      Node* node = graph->AddNode(def, &status);
      TF_CHECK_OK(status);
      return node;
    };

    Node* a = BuildNoopNode("a", graph.get());
    Node* b = BuildNoopNode("b", graph.get());
    Node* c = BuildNoopNode("c", graph.get());
    graph->AddControlEdge(a, b);
    graph->AddControlEdge(b, c);
    graph->AddControlEdge(c, a);
  }

  TF_EXPECT_OK(root.ToGraph(graph.get()));

  Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
  EXPECT_FALSE(status.ok());
  EXPECT_TRUE(absl::StrContains(status.ToString(),
                                "Edge from c to a would create a cycle.\n"
                                "+-> a\n"
                                "| b\n"
                                "+-- c\n"));
}

TEST(XlaCompilationTest, Retval) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  GraphDef graphdef;
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* a = ops::SourceOp("Const", builder.opts()
                                         .WithName("A")
                                         .WithAttr("dtype", DT_FLOAT)
                                         .WithAttr("value", Tensor()));
    Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
    ops::UnaryOp("_Retval", b,
                 builder.opts()
                     .WithName("R")
                     .WithAttr("T", DT_FLOAT)
                     .WithAttr("index", 0));

    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, DontCountIdentityOps) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Scope root = Scope::NewRootScope().ExitOnError();
  {
    auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
    auto b = ops::Identity(root.WithOpName("B"), a);
    auto c = ops::Identity(root.WithOpName("C"), b);
    auto r = ops::_Retval(root.WithOpName("R"), c, 0);
  }
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
  auto clusters = GetClusters(*graph);

  EXPECT_TRUE(clusters.empty());
}

TEST(XlaCompilationTest, ConstOp) {
  // valid data type
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    Scope root = Scope::NewRootScope().ExitOnError();
    auto c = ops::Const(root.WithOpName("const"), 0.5f);
    c.node()->AddAttr(kXlaCompileAttr, true);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
    EXPECT_EQ(1, GetClusters(*graph).size());
  }

  // invalid data type
  {
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    Scope root = Scope::NewRootScope().ExitOnError();
    auto c = ops::Const(root.WithOpName("const"), string("string"));
    c.node()->AddAttr(kXlaCompileAttr, true);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
    EXPECT_TRUE(GetClusters(*graph).empty());
  }
}

TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  Output read = ops::Identity(root.WithOpName("read"), variable);
  Output neg = ops::Negate(root.WithOpName("negate"), read);
  Output add = ops::Add(root.WithOpName("add"), neg, neg);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  ASSERT_FALSE(clusters.empty());
  string cluster_name = clusters.begin()->second;

  std::unordered_map<string, string> expected_clusters(
      {{"negate", cluster_name}, {"add", cluster_name}});
  EXPECT_EQ(clusters, expected_clusters);
}

TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  Output read = ops::Identity(root.WithOpName("read"), variable);
  Output neg = ops::Negate(root.WithOpName("negate"), read);
  Output identity = ops::Identity(root.WithOpName("identity"), neg);
  Output add = ops::Add(root.WithOpName("add"), identity, neg);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  ASSERT_FALSE(clusters.empty());
  string cluster_name = clusters.begin()->second;

  std::unordered_map<string, string> expected_clusters(
      {{"negate", cluster_name},
       {"identity", cluster_name},
       {"add", cluster_name}});
  EXPECT_EQ(clusters, expected_clusters);
}

TEST(XlaCompilationTest, ClusterControlTrigger) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_BOOL, "tensor_a",
                             "sender", 0, "receiver");
  Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_BOOL, "tensor_b",
                             "sender", 0, "receiver");
  Output const_a = ops::Const(root.WithOpName("const_a"), 42);

  ops::ControlTrigger ctrl_trigger_a(root.WithOpName("ctrl_trigger_a"));
  ops::ControlTrigger ctrl_trigger_b(root.WithOpName("ctrl_trigger_b"));
  root.graph()->AddControlEdge(recv_a.node(), ctrl_trigger_a.operation.node());
  root.graph()->AddControlEdge(recv_b.node(), ctrl_trigger_a.operation.node());
  root.graph()->AddControlEdge(ctrl_trigger_b.operation.node(), const_a.node());

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  // TODO(b/118970344): ctrl_trigger_a has inputs with mismatching deadness so
  // it won't be clustered. ctrl_trigger_b is okay to cluster but we don't
  // cluster it because of b/118970344.
  EXPECT_TRUE(clusters.empty());
}

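// An op whose output is computed at run time (RandomUniformInt here) and is
// used as the shape input of a Reshape should not be auto-clustered.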
TEST(XlaCompilationTest, RandomShape) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape_shape = ops::Const(root.WithOpName("shape_shape"), {2}, {1});
  Output shape =
      ops::RandomUniformInt(root.WithOpName("shape"), shape_shape,
                            ops::Const(root.WithOpName("minval"), 1),
                            ops::Const(root.WithOpName("maxval"), 20));
  Output reshape_input =
      ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["shape"], "");
}

TEST(XlaCompilationTest, RandomShapeWithFunc) {
  Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();

  FunctionDefLibrary flib_def;
  FunctionDef func = FunctionDefHelper::Create(
      /*function_name=*/"Stateful_func", /*in_def=*/{},
      /*out_def=*/{"out: int32"},
      /*attr_def*/
      {}, /*node_def=*/
      {FunctionDefHelper::Const("shape_shape", 2),
       FunctionDefHelper::Const("minval", 1),
       FunctionDefHelper::Const("maxval", 20),
       {{"shape"},
        "RandomUniformInt",
        {"shape_shape:output:0", "minval:output:0", "maxval:output:0"},
        {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}},
      /*ret_def=*/{{"out", "shape:output:0"}});

  func.mutable_signature()->set_is_stateful(true);
  *flib_def.add_function() = std::move(func);
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
  NodeDef call_node;
  call_node.set_name("fn_call");
  call_node.set_op("Stateful_func");
  Status status;
  Node* call = root.graph()->AddNode(call_node, &status);
  TF_ASSERT_OK(status);

  Output shape = Output(call, 0);
  Output reshape_input =
      ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  auto fld = absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(),
                                                          flib_def);
  TF_ASSERT_OK(
      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get()));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["fn_call"], "");
}

TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
  absl::string_view xla_gpu_device =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape_shape =
      ops::Const(root.WithOpName("test/shape_shape"), {2}, {1});
  Output shape =
      ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape,
                            ops::Const(root.WithOpName("test/minval"), 1),
                            ops::Const(root.WithOpName("test/maxval"), 20));
  Output reshape_input =
      ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({500, 500})));
  Output reshape =
      ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_gpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/shape_rng"], "");
  EXPECT_EQ(clusters["test/reshape"], "");
}

TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
  absl::string_view xla_gpu_device =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1,
                                DT_INT32);
  Output zero = ops::Const(root.WithOpName("test/zero"), 0);
  ops::TensorArrayWrite tensor_array_write(
      root.WithOpName("test/write"), tensor_array.handle, zero,
      ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow);
  Output tensor_array_read =
      ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle,
                           zero, tensor_array_write.flow_out, DT_INT32);
  Output reshape =
      ops::Reshape(root.WithOpName("test/reshape"),
                   ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT),
                   tensor_array_read);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_gpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/read"], "");
  EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]);
}

TEST(XlaCompilationTest, DontClusterMergingNodes) {
  // MatMulCombined below takes data from nodes on GPU0 and GPU1 and is placed
  // on GPU1. However, it should not be clustered with the previous node on
  // GPU1, because that will serialize production of its inputs that should be
  // done in parallel.
  //
  // This graph is:
  // (Const0, Const0) -> MatMul0
  // (Const1, Const1) -> MatMul1
  // (MatMul0, MatMul1) -> MatMulCombined
  //
  // Device0: [Const0, Const0, MatMul0]
  // Device1: [Const1, Const1, MatMul1, MatMulCombined]
  //
  // Cluster0: [Const0, Const0, MatMul0]
  // Cluster1: [Const1, Const1, MatMul1]
  // Cluster2: [MatMulCombined]
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  absl::string_view xla_gpu_dev1 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:1";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Tanh(root.WithOpName("tanh_A_dev0"),
                       ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2}));
  Output b = ops::Tanh(root.WithOpName("tanh_B_dev1"),
                       ops::Const(root.WithOpName("B_dev1"), 1.0f, {2, 2}));
  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), a, a);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), b, b);

  Output combined =
      ops::MatMul(root.WithOpName("MatMulCombined_dev1"), matmul0, matmul1);
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  // Each of the MatMuls should be in a separate cluster.
  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul0_dev0"]);
  EXPECT_NE(clusters["MatMulCombined_dev1"], clusters["MatMul1_dev1"]);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMul0_dev0"]);
  EXPECT_EQ(clusters["B_dev1"], clusters["MatMul1_dev1"]);
}

// TODO(b/117085735): This form of clustering should be prevented.
TEST(XlaCompilationTest, NOT_DontClusterSpreadingNodes) {
  // MatMulSource below creates data for nodes on GPU0 and GPU1 and is placed
  // on GPU0. However, it should not be clustered with the next node on
  // GPU0, because that will prevent the node on GPU1 from beginning its work
  // as soon as the data has been produced.
  //
  // This graph is:
  // (Const0, Const0) -> MatMulSource
  // MatMulSource -> (MatMul0, MatMul1)
  //
  // Device0: [Const0, Const1, MatMulSource, MatMul0]
  // Device1: [MatMul1]
  //
  // Cluster0: [Const0, Const1, MatMulSource]
  // Cluster1: [MatMul0]
  // Cluster2: [MatMul1]
  Scope root = Scope::NewRootScope().ExitOnError();
  absl::string_view xla_gpu_dev0 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:0";
  absl::string_view xla_gpu_dev1 =
      "/job:worker/replica:0/task:0/device:XLA_GPU:1";
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Output a = ops::Const(root.WithOpName("A_dev0"), 1.0f, {2, 2});
  Output matmul_source =
      ops::MatMul(root.WithOpName("MatMulSource_dev0"), a, a);

  Output matmul0 = ops::MatMul(root.WithOpName("MatMul0_dev0"), matmul_source,
                               matmul_source);
  Output matmul1 = ops::MatMul(root.WithOpName("MatMul1_dev1"), matmul_source,
                               matmul_source);

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  for (Node* n : graph->nodes()) {
    if (absl::EndsWith(n->name(), /*suffix=*/"dev0")) {
      n->set_assigned_device_name(string(xla_gpu_dev0));
    } else if (absl::EndsWith(n->name(), /*suffix=*/"dev1")) {
      n->set_assigned_device_name(string(xla_gpu_dev1));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["A_dev0"], clusters["MatMulSource_dev0"]);
  EXPECT_NE(clusters["MatMul0_dev0"], clusters["MatMul1_dev1"]);
  EXPECT_NE(clusters["MatMulSource_dev0"], clusters["MatMul1_dev1"]);

  // Improved heuristics should probably prevent this.
  EXPECT_EQ(clusters["MatMulSource_dev0"], clusters["MatMul0_dev0"]);
}

TEST(XlaCompilationTest, ClusterStatefulRandomOpOnXlaDevice) {
  absl::string_view xla_cpu_device =
      "/job:worker/replica:0/task:0/device:XLA_CPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
  Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
  Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
  Output c = ops::Add(root.WithOpName("test/c"), a, b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_cpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/a"], "");
  EXPECT_NE(clusters["test/b"], "");
  EXPECT_NE(clusters["test/c"], "");
}

TEST(XlaCompilationTest, DontAutoClusterStatefulRandomOp) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output shape = ops::Const(root.WithOpName("test/shape_shape"), {200, 200});
  Output a = ops::RandomUniform(root.WithOpName("test/a"), shape, DT_FLOAT);
  Output b = ops::RandomUniform(root.WithOpName("test/b"), shape, DT_FLOAT);
  Output c = ops::Add(root.WithOpName("test/c"), a, b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/a"], "");
  EXPECT_EQ(clusters["test/b"], "");
}

TEST(XlaCompilationTest, ClusterDummyOpsOnXlaDevice) {
  absl::string_view xla_cpu_device =
      "/job:worker/replica:0/task:0/device:XLA_CPU:0";

  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
  Output check =
      ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
  Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
  Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(string(xla_cpu_device));
    }
  }
  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/check"], "");
  EXPECT_NE(clusters["test/greaterequal"], "");
  EXPECT_NE(clusters["test/assert"], "");
}

TEST(XlaCompilationTest, DontAutoClusterDummyOps) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);
  Output check =
      ops::CheckNumerics(root.WithOpName("test/check"), a, "test/check");
  Output ge = ops::GreaterEqual(root.WithOpName("test/greaterequal"), check, b);
  Operation assert = ops::Assert(root.WithOpName("test/assert"), ge, {a, b});

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/assert"], "");
  EXPECT_EQ(clusters["test/check"], "");
}

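// Ops that produce or consume DT_VARIANT values (the TensorList ops below)
// are not auto-clustered; they are clustered only when already placed on an
// XLA device, as in the last of the three tests that follow.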
TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);

  Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
  Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);

  Output tensor_list_reserve = ops::TensorListReserve(
      root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/tensor_list_reserve"], "");
}

TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output dummy_input =
      ops::Placeholder(root.WithOpName("test/dummy_input"), DT_INT64);
  Output variant_input =
      ops::Placeholder(root.WithOpName("test/variant_input"), DT_VARIANT);

  // Create one more node so that we don't avoid creating a cluster solely
  // because it would be trivial.
  Output dummy_cast =
      ops::Cast(root.WithOpName("test/dummy_cast"), dummy_input, DT_INT32);

  Output tensor_list_element_shape = ops::TensorListElementShape(
      root.WithOpName("test/tensor_list_element_shape"), variant_input,
      DT_INT32);

  root.graph()->AddControlEdge(dummy_cast.node(),
                               tensor_list_element_shape.node());

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_EQ(clusters["test/tensor_list_element_shape"], "");
}

TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);

  Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
  Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);

  Output tensor_list_reserve = ops::TensorListReserve(
      root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
  for (Node* n : graph->nodes()) {
    if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
      n->set_assigned_device_name(xla_cpu_device);
    }
  }

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);
  EXPECT_NE(clusters["test/tensor_list_reserve"], "");
}

const char* kCPU0 = "/job:worker/replica:0/task:0/device:CPU:0";
const char* kGPU0 = "/job:worker/replica:0/task:0/device:GPU:0";
const char* kXLA_GPU0 = "/job:worker/replica:0/task:0/device:XLA_GPU:0";
const char* kGPU1 = "/job:worker/replica:0/task:0/device:GPU:1";

TEST(XlaCompilationTest, CreateCombinedCpuGpuClusters) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::MatMul(root.WithOpName("test/y"), a, b);
  Output z = ops::Add(root.WithOpName("test/z"), x, y);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), "test/z")->set_assigned_device_name(kGPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/x"], "");

  EXPECT_EQ(clusters["test/x"], clusters["test/y"]);
  EXPECT_EQ(clusters["test/y"], clusters["test/z"]);
}

TEST(XlaCompilationTest, DontCreateGpu0AndGpu1Clusters) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::Add(root.WithOpName("test/y"), x, x);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kGPU1);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/x"], "");
  EXPECT_EQ(clusters["test/y"], "");
}

TEST(XlaCompilationTest, DontCreateCombinedCpuUnknownClusters) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("test/b"), DT_FLOAT);

  Output x = ops::Add(root.WithOpName("test/x"), a, b);
  Output y = ops::Add(root.WithOpName("test/y"), x, x);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/x")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), "test/y")->set_assigned_device_name(kXLA_GPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/x"], "");
  EXPECT_EQ(clusters["test/y"], "");
}

TEST(XlaCompilationTest, ClusterResourceOpsWhenSafe) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Node* var_handle;
  Node* resource_read = MakeRead(root, "read", &var_handle);
  Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a);

  string resource_read_name = resource_read->name();
  string var_handle_name = var_handle->name();

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), resource_read_name)
      ->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), var_handle_name)->set_assigned_device_name(kGPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_NE(clusters["test/b"], "");
  EXPECT_EQ(clusters["test/b"], clusters[resource_read_name]);
}

TEST(XlaCompilationTest, DontClusterResourceOpsWhenUnsafe) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output a = ops::Placeholder(root.WithOpName("test/a"), DT_FLOAT);
  Node* var_handle;
  Node* resource_read = MakeRead(root, "read", &var_handle);
  Output b = ops::Add(root.WithOpName("test/b"), Output(resource_read, 0), a);

  string resource_read_name = resource_read->name();
  string var_handle_name = var_handle->name();

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  FindNodeByName(graph.get(), "test/b")->set_assigned_device_name(kGPU0);
  FindNodeByName(graph.get(), resource_read_name)
      ->set_assigned_device_name(kCPU0);
  FindNodeByName(graph.get(), var_handle_name)->set_assigned_device_name(kCPU0);

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));

  std::unordered_map<string, string> clusters = GetClusters(*graph);

  EXPECT_EQ(clusters["test/b"], "");
  EXPECT_EQ(clusters[resource_read_name], "");
}

}  // namespace
}  // namespace tensorflow