1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/partially_decluster_pass.h"
17
18 #include "absl/memory/memory.h"
19 #include "tensorflow/cc/framework/ops.h"
20 #include "tensorflow/cc/ops/array_ops.h"
21 #include "tensorflow/cc/ops/const_op.h"
22 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
23 #include "tensorflow/cc/ops/function_ops.h"
24 #include "tensorflow/cc/ops/sendrecv_ops.h"
25 #include "tensorflow/cc/ops/standard_ops.h"
26 #include "tensorflow/compiler/jit/defs.h"
27 #include "tensorflow/compiler/jit/test_util.h"
28 #include "tensorflow/compiler/jit/xla_cluster_util.h"
29 #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
30 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
31 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
32 #include "tensorflow/core/common_runtime/graph_constructor.h"
33 #include "tensorflow/core/common_runtime/graph_def_builder_util.h"
34 #include "tensorflow/core/framework/function.h"
35 #include "tensorflow/core/framework/function.pb.h"
36 #include "tensorflow/core/framework/node_def_util.h"
37 #include "tensorflow/core/framework/op.h"
38 #include "tensorflow/core/graph/algorithm.h"
39 #include "tensorflow/core/graph/graph_def_builder.h"
40 #include "tensorflow/core/lib/core/status_test_util.h"
41 #include "tensorflow/core/platform/test.h"
42
43 namespace tensorflow {
44 namespace {
// Test-only op with no inputs and a single int32 output; used as a graph
// source in the tests below.
REGISTER_OP("FakeNullary").Output("out: int32");

// Test-only binary op. The CPU kernel registration below pins "host_in" and
// "host_out" to host memory; "device_in"/"device_out" stay in device memory,
// which is what lets the tests distinguish host- from device-memory edges.
REGISTER_OP("FakeBinary")
    .Input("host_in: int32")
    .Input("device_in: int32")
    .Output("host_out: int32")
    .Output("device_out: int32");

// Test-only source op producing a resource handle.
REGISTER_OP("FakeResourceVar").Output("out: resource");

// Test-only op that consumes and re-emits a resource, plus an int32 side
// output ("something_else") that the kernel registration pins to host memory.
REGISTER_OP("FakeResourceUpdate")
    .Input("in: resource")
    .Output("out: resource")
    .Output("something_else: int32");
59
// Kernel for FakeBinary. These tests only exercise the graph rewrite and
// never execute kernels, so Compute aborts if it is ever reached.
class FakeBinaryOp : public OpKernel {
 public:
  explicit FakeBinaryOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override { CHECK(false); }
};
66
// Kernel for FakeResourceUpdate. Like FakeBinaryOp, it must never actually
// run; the tests only inspect the rewritten graph.
class FakeResourceUpdateOp : public OpKernel {
 public:
  explicit FakeResourceUpdateOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override { CHECK(false); }
};
74
// CPU kernel for FakeBinary: "host_in"/"host_out" live in host memory, the
// remaining input/output in device memory.
REGISTER_KERNEL_BUILDER(Name("FakeBinary")
                            .Device(DEVICE_CPU)
                            .HostMemory("host_in")
                            .HostMemory("host_out"),
                        FakeBinaryOp);

// CPU kernel for FakeResourceUpdate: only "something_else" is in host memory.
REGISTER_KERNEL_BUILDER(
    Name("FakeResourceUpdate").Device(DEVICE_CPU).HostMemory("something_else"),
    FakeResourceUpdateOp);
84
PartiallyDecluster(std::unique_ptr<Graph> * graph)85 Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
86 FixupSourceAndSinkEdges(graph->get());
87 // Assign all nodes to the CPU device.
88 static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
89 for (Node* n : (*graph)->nodes()) {
90 if (n->assigned_device_name().empty()) {
91 n->set_assigned_device_name(kCpuDevice);
92 }
93 }
94
95 GraphOptimizationPassWrapper wrapper;
96 GraphOptimizationPassOptions opt_options =
97 wrapper.CreateGraphOptimizationPassOptions(graph);
98
99 PartiallyDeclusterPass pass;
100 return pass.Run(opt_options);
101 }
102
FindNodeByName(const Graph & graph,const string & name)103 Node* FindNodeByName(const Graph& graph, const string& name) {
104 for (Node* node : graph.nodes()) {
105 if (node->name() == name) {
106 return node;
107 }
108 }
109 return nullptr;
110 }
111
GetInputsForNode(const Graph & graph,const string & node_name,std::vector<Node * > * inputs)112 bool GetInputsForNode(const Graph& graph, const string& node_name,
113 std::vector<Node*>* inputs) {
114 const Node* node = FindNodeByName(graph, node_name);
115 if (node == nullptr) {
116 return false;
117 }
118 for (const Edge* e : node->in_edges()) {
119 inputs->push_back(e->src());
120 }
121 std::sort(inputs->begin(), inputs->end(), NodeComparatorName());
122 return true;
123 }
124
// A clustered producer whose host-memory output (slot 0, "host_out") feeds an
// unclustered consumer gets a "/declustered" copy outside the cluster; the
// unclustered consumer is rewired to the copy, while the in-cluster consumer
// (which reads device-memory slot 1) keeps the original.
TEST(PartiallyDeclusterPassTest, ClusteredAndUnclustered) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* input =
        ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
    Node* clustered_producer =
        ops::BinaryOp("FakeBinary", input, input,
                      builder.opts().WithName("ClusteredProducer"));
    // Unclustered consumer reading the producer's host output (slot 0).
    ops::BinaryOp("FakeBinary", clustered_producer, input,
                  builder.opts().WithName("UnclusteredConsumer"));
    // In-cluster consumer reading the producer's device output (slot 1).
    Node* clustered_consumer =
        ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input,
                      builder.opts().WithName("ClusteredConsumer"));
    clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
    clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(PartiallyDecluster(&graph));
  std::vector<Node*> unclustered_consumer_inputs;
  ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer",
                               &unclustered_consumer_inputs));
  ASSERT_EQ(unclustered_consumer_inputs.size(), 2);
  EXPECT_EQ(unclustered_consumer_inputs[0]->name(),
            "ClusteredProducer/declustered");
  EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input");

  std::vector<Node*> clustered_consumer_inputs;
  ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredConsumer",
                               &clustered_consumer_inputs));
  ASSERT_EQ(clustered_consumer_inputs.size(), 2);
  EXPECT_EQ(clustered_consumer_inputs[0]->name(), "ClusteredProducer");
  EXPECT_EQ(clustered_consumer_inputs[1]->name(), "Input");
}
160
// A consumer in a *different* cluster reading the producer's host-memory
// output behaves like an unclustered consumer: it is rewired to the
// "/declustered" copy of the producer.
TEST(PartiallyDeclusterPassTest, DifferentClusters) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* input =
        ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
    Node* clustered_producer =
        ops::BinaryOp("FakeBinary", input, input,
                      builder.opts().WithName("ClusteredProducer"));
    // Reads the producer's host output (slot 0) from another cluster.
    Node* consumer_in_different_cluster =
        ops::BinaryOp("FakeBinary", clustered_producer, input,
                      builder.opts().WithName("ConsumerInDifferentCluster"));
    Node* clustered_consumer =
        ops::BinaryOp("FakeBinary", input, {clustered_producer, 1},
                      builder.opts().WithName("ClusteredConsumer"));
    clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
    clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
    consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1");
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(PartiallyDecluster(&graph));
  std::vector<Node*> inputs;
  ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs));
  ASSERT_EQ(inputs.size(), 2);
  EXPECT_EQ(inputs[0]->name(), "ClusteredProducer/declustered");
  EXPECT_EQ(inputs[1]->name(), "Input");
}
189
// The cross-cluster consumer reads the producer through its device-memory
// input ("device_in"), so declustering the producer would not avoid a
// device-to-host copy; the consumer must keep reading the original node.
TEST(PartiallyDeclusterPassTest, DontDeclusterIfUserIsDeviceMem) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* input =
        ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
    Node* clustered_producer =
        ops::BinaryOp("FakeBinary", input, input,
                      builder.opts().WithName("ClusteredProducer"));
    // The first input is hostmem and the second input is devicemem.
    Node* consumer_in_different_cluster =
        ops::BinaryOp("FakeBinary", input, clustered_producer,
                      builder.opts().WithName("ConsumerInDifferentCluster"));
    Node* clustered_consumer =
        ops::BinaryOp("FakeBinary", input, {clustered_producer, 1},
                      builder.opts().WithName("ClusteredConsumer"));
    clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
    clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
    consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1");
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(PartiallyDecluster(&graph));
  std::vector<Node*> inputs;
  ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs));
  ASSERT_EQ(inputs.size(), 2);
  // No "/declustered" copy: the original producer is still the input.
  EXPECT_EQ(inputs[0]->name(), "ClusteredProducer");
  EXPECT_EQ(inputs[1]->name(), "Input");
}
219
// The producer here is a resource-touching op (FakeResourceUpdate). Cloning
// it out of the cluster would duplicate the resource operation, so the
// cross-cluster consumer must keep reading the original node.
TEST(PartiallyDeclusterPassTest, DontDuplicateResourceVarOps) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* input =
        ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
    Node* resource_var = ops::SourceOp("FakeResourceVar",
                                       builder.opts().WithName("ResourceVar"));
    Node* clustered_producer =
        ops::UnaryOp("FakeResourceUpdate", resource_var,
                     builder.opts().WithName("ClusteredProducer"));
    // Reads the host-memory side output (slot 1, "something_else").
    Node* consumer_in_different_cluster =
        ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input,
                      builder.opts().WithName("ConsumerInDifferentCluster"));
    Node* clustered_consumer =
        ops::BinaryOp("FakeBinary", input, {clustered_producer, 1},
                      builder.opts().WithName("ClusteredConsumer"));
    clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
    clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
    consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1");
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(PartiallyDecluster(&graph));
  std::vector<Node*> inputs;
  ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs));
  ASSERT_EQ(inputs.size(), 2);
  // No "/declustered" copy was made for the resource op.
  EXPECT_EQ(inputs[0]->name(), "ClusteredProducer");
  EXPECT_EQ(inputs[1]->name(), "Input");
}
250
// Declustering a node pulls its clustered dependencies along: both producers
// get "/declustered" copies, and the copy of producer 1 reads from the copy
// of producer 0 rather than from the in-cluster original.
TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* input =
        ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
    Node* clustered_producer_0 =
        ops::BinaryOp("FakeBinary", input, input,
                      builder.opts().WithName("ClusteredProducer0"));
    Node* clustered_producer_1 =
        ops::BinaryOp("FakeBinary", clustered_producer_0, input,
                      builder.opts().WithName("ClusteredProducer1"));
    ops::BinaryOp("FakeBinary", clustered_producer_1, input,
                  builder.opts().WithName("UnclusteredConsumer"));
    Node* clustered_consumer =
        ops::BinaryOp("FakeBinary", {clustered_producer_1, 1}, input,
                      builder.opts().WithName("ClusteredConsumer"));
    clustered_producer_0->AddAttr(kXlaClusterAttr, "cluster_0");
    clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0");
    clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(PartiallyDecluster(&graph));
  std::vector<Node*> unclustered_consumer_inputs, declustered_producer_1_inputs;

  ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer",
                               &unclustered_consumer_inputs));
  ASSERT_EQ(unclustered_consumer_inputs.size(), 2);
  EXPECT_EQ(unclustered_consumer_inputs[0]->name(),
            "ClusteredProducer1/declustered");
  EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input");

  // The declustered copy of producer 1 must itself feed off the declustered
  // copy of producer 0.
  ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredProducer1/declustered",
                               &declustered_producer_1_inputs));
  ASSERT_EQ(declustered_producer_1_inputs.size(), 2);
  EXPECT_EQ(declustered_producer_1_inputs[0]->name(),
            "ClusteredProducer0/declustered");
  EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input");
}
291
AddToCluster(absl::Span<Node * const> nodes,absl::string_view cluster_name)292 void AddToCluster(absl::Span<Node* const> nodes,
293 absl::string_view cluster_name) {
294 for (Node* n : nodes) {
295 n->AddAttr(kXlaClusterAttr, string(cluster_name));
296 }
297 }
298
// "shape" feeds Reshape's must-be-compile-time-constant operand, and is
// itself computed from values produced outside the cluster; the pass removes
// it from the cluster (its cluster attribute becomes nullopt).
TEST(PartiallyDeclusterPassTest, DeclusterMustBeConstantNodes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
                                    ops::Placeholder::Attrs{});
  Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
                                    ops::Placeholder::Attrs{});
  Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);

  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
                                          DT_FLOAT, ops::Placeholder::Attrs{});
  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);

  AddToCluster({shape.node(), reshape.node()}, "cluster_0");

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(s.ToGraph(graph.get()));
  TF_ASSERT_OK(PartiallyDecluster(&graph));

  const Node* n = FindNodeByName(*graph, "shape");
  ASSERT_NE(n, nullptr);

  // "shape" should have been kicked out of cluster_0.
  EXPECT_EQ(GetXlaClusterForNode(*n), std::nullopt);
}
322
// Unlike the test above, "shape" here depends on shape_of_mul, a
// metadata-only op that is itself inside the cluster, so it is not
// declustered: its cluster attribute must be preserved.
TEST(PartiallyDeclusterPassTest, DeclusteringStopsAtMetadataOps) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input_a = ops::Placeholder(s.WithOpName("input_a"), DT_INT32,
                                    ops::Placeholder::Attrs{});
  // NOTE(review): the node name "shape_b" for `input_b` looks like a
  // copy-paste slip from the previous test; nothing references it by name.
  Output input_b = ops::Placeholder(s.WithOpName("shape_b"), DT_FLOAT,
                                    ops::Placeholder::Attrs{});
  Output mul = ops::Mul(s.WithOpName("mul"), input_b, input_b);
  Output shape_of_mul = ops::Shape(s.WithOpName("shape_of_mul"), mul);

  Output shape = ops::Add(s.WithOpName("shape"), shape_of_mul, input_a);

  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
                                          DT_FLOAT, ops::Placeholder::Attrs{});
  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);

  AddToCluster({mul.node(), shape_of_mul.node(), shape.node(), reshape.node()},
               "cluster_0");

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(s.ToGraph(graph.get()));
  TF_ASSERT_OK(PartiallyDecluster(&graph));

  const Node* n = FindNodeByName(*graph, "shape");
  ASSERT_NE(n, nullptr);

  EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
}
350
// "shape" already lives in a different cluster (cluster_1) than its consumer
// (cluster_0); the pass leaves its cluster assignment untouched.
TEST(PartiallyDeclusterPassTest, EdgeAcrossDifferentClusters) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
                                    ops::Placeholder::Attrs{});
  Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
                                    ops::Placeholder::Attrs{});
  Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);

  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
                                          DT_FLOAT, ops::Placeholder::Attrs{});
  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);

  AddToCluster({reshape.node()}, "cluster_0");
  AddToCluster({shape.node()}, "cluster_1");

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(s.ToGraph(graph.get()));
  TF_ASSERT_OK(PartiallyDecluster(&graph));

  const Node* n = FindNodeByName(*graph, "shape");
  ASSERT_NE(n, nullptr);

  EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_1");
}
375
// Same graph shape as DeclusterMustBeConstantNodes, but "shape" is assigned
// to an XLA_GPU device; nodes placed on XLA devices must stay in their
// cluster.
TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
                                    ops::Placeholder::Attrs{});
  Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
                                    ops::Placeholder::Attrs{});
  Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);

  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
                                          DT_FLOAT, ops::Placeholder::Attrs{});
  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);

  AddToCluster({shape.node(), reshape.node()}, "cluster_0");

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(s.ToGraph(graph.get()));

  // This is needed to register the XLA_GPU device.
  std::vector<std::unique_ptr<Device>> devices;
  TF_ASSERT_OK(DeviceFactory::AddDevices(
      SessionOptions(), "/job:localhost/replica:0/task:0", &devices));

  // Scope::ToGraph loses the assigned device name since it goes through
  // GraphDef/NodeDef which does not have a field for the assigned device name.
  Node* n = FindNodeByName(*graph, "shape");
  ASSERT_NE(n, nullptr);
  n->set_assigned_device_name(
      "/job:localhost/replica:0/task:0/device:XLA_GPU:0");

  TF_ASSERT_OK(PartiallyDecluster(&graph));

  EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
}
409
// Once the only consumer is rewired to "/declustered" copies, the original
// clustered producers have no users left and must be erased from the graph.
TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) {
  const char* const kClusteredProducer0Name = "ClusteredProducer0";
  const char* const kClusteredProducer1Name = "ClusteredProducer1";

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* input =
        ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
    Node* clustered_producer_0 =
        ops::BinaryOp("FakeBinary", input, input,
                      builder.opts().WithName(kClusteredProducer0Name));
    Node* clustered_producer_1 =
        ops::BinaryOp("FakeBinary", clustered_producer_0, input,
                      builder.opts().WithName(kClusteredProducer1Name));
    // Sole consumer; note there is no in-cluster consumer in this test.
    ops::BinaryOp("FakeBinary", clustered_producer_1, input,
                  builder.opts().WithName("UnclusteredConsumer"));
    clustered_producer_0->AddAttr(kXlaClusterAttr, "cluster_0");
    clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0");
    TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(PartiallyDecluster(&graph));
  EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer0Name), nullptr);
  EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer1Name), nullptr);
}
436
// A chain of metadata-only ops (Shape, Rank, Size) hanging off an unclustered
// input should not remain clustered: the pass strips the cluster attribute
// from the entire chain, b through e.
TEST(PartiallyDeclusterPassTest, MetadataOpsDontStartClusters) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  tensorflow::Scope in_cluster_and = root.WithXlaCluster("cluster_0");

  // "a" is outside the cluster; everything downstream starts in cluster_0.
  Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
  Output b = ops::Shape(in_cluster_and.WithOpName("b"), a);
  Output c = ops::Rank(in_cluster_and.WithOpName("c"), b);
  Output d = ops::Size(in_cluster_and.WithOpName("d"), c);
  (void)ops::Shape(in_cluster_and.WithOpName("e"), d);

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(PartiallyDecluster(&graph));

  Node* n_b = FindNodeByName(*graph, "b");
  ASSERT_NE(n_b, nullptr);
  EXPECT_EQ(GetXlaClusterForNode(*n_b), std::nullopt);

  Node* n_c = FindNodeByName(*graph, "c");
  ASSERT_NE(n_c, nullptr);
  EXPECT_EQ(GetXlaClusterForNode(*n_c), std::nullopt);

  Node* n_d = FindNodeByName(*graph, "d");
  ASSERT_NE(n_d, nullptr);
  EXPECT_EQ(GetXlaClusterForNode(*n_d), std::nullopt);

  Node* n_e = FindNodeByName(*graph, "e");
  ASSERT_NE(n_e, nullptr);
  EXPECT_EQ(GetXlaClusterForNode(*n_e), std::nullopt);
}
468
// "c" (Rank) is a metadata-only consumer of the clustered "b" (Add). Even
// though c's result is consumed outside the cluster, neither b nor c should
// be declustered.
TEST(PartiallyDeclusterPassTest, MetaConsumersArentDeclustered) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  tensorflow::Scope in_cluster_and = root.WithXlaCluster("cluster_0");
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
  Output b = ops::Add(in_cluster_and.WithOpName("b"), a, a);
  Output c = ops::Rank(in_cluster_and.WithOpName("c"), b);

  // Out-of-cluster consumer of c's output.
  Output e;
  TF_ASSERT_OK(
      CreateOutputWithScope("FakeBinary", {c, c}, root.WithOpName("e"), &e));

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(PartiallyDecluster(&graph));

  Node* n_b = FindNodeByName(*graph, "b");
  ASSERT_NE(n_b, nullptr);
  EXPECT_EQ(GetXlaClusterForNode(*n_b), "cluster_0");

  Node* n_c = FindNodeByName(*graph, "c");
  ASSERT_NE(n_c, nullptr);
  EXPECT_EQ(GetXlaClusterForNode(*n_c), "cluster_0");
}
492
// The Const nodes feeding Slice's must-be-constant operands are trivially
// evaluable at compile time, so there is no benefit to declustering them;
// both stay in cluster_0.
TEST(PartiallyDeclusterPassTest, ConstInputsToSliceArentDeclustered) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  tensorflow::Scope in_cluster_and = root.WithXlaCluster("cluster_0");
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT,
                              ops::Placeholder::Attrs{{4}});
  Output b = ops::Const(in_cluster_and.WithOpName("b"), {1});
  Output c = ops::Const(in_cluster_and.WithOpName("c"), {2});
  Output d = ops::Slice(in_cluster_and.WithOpName("d"), a, b, c);

  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(PartiallyDecluster(&graph));

  Node* n_b = FindNodeByName(*graph, "b");
  ASSERT_NE(n_b, nullptr);
  EXPECT_EQ(GetXlaClusterForNode(*n_b), "cluster_0");

  Node* n_c = FindNodeByName(*graph, "c");
  ASSERT_NE(n_c, nullptr);
  EXPECT_EQ(GetXlaClusterForNode(*n_c), "cluster_0");
}
514
// Const "b" receives a control edge from an Identity assigned to a different
// (XLA_GPU) device inside a control-flow loop; such Consts must be
// declustered, while the sibling Const "c" (no cross-device control input)
// stays in the cluster.
TEST(PartiallyDeclusterPassTest,
     ConstInLoopWithCrossDeviceControlInputsAreDeclustered) {
  // Based on DontClusterTheSpecialIdentityDrivingConstsInLoop in
  // mark_for_compilation_pass_test.cc
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  tensorflow::Scope in_cluster_and = root.WithXlaCluster("cluster_0");
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT,
                              ops::Placeholder::Attrs{{4}});
  Output b = ops::Const(in_cluster_and.WithOpName("b"), {1});
  Output c = ops::Const(in_cluster_and.WithOpName("c"), {2});
  Output slice = ops::Slice(in_cluster_and.WithOpName("slice"), a, b, c);
  // Minimal loop construct: LoopCond -> Switch -> Identity.
  Output cond = ops::Placeholder(root.WithOpName("cond"), DT_BOOL);
  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  Output loop_cond = ops::LoopCond(root.WithOpName("loop_cond"), cond);
  ops::Switch switch_node(root.WithOpName("switch"), value, loop_cond);
  Output identity =
      ops::Identity(root.WithOpName("identity"), switch_node.output_true);
  // The cross-device control edge driving "b".
  root.graph()->AddControlEdge(identity.node(), b.node());

  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // This is needed to register the XLA_GPU device.
  std::vector<std::unique_ptr<Device>> devices;
  TF_ASSERT_OK(DeviceFactory::AddDevices(
      SessionOptions(), "/job:localhost/replica:0/task:0", &devices));

  // Scope::ToGraph loses the assigned device name since it goes through
  // GraphDef/NodeDef which does not have a field for the assigned device name.
  Node* identity_node = FindNodeByName(*graph, "identity");
  ASSERT_NE(identity_node, nullptr);
  identity_node->set_assigned_device_name(
      "/job:localhost/replica:0/task:0/device:XLA_GPU:0");

  TF_ASSERT_OK(PartiallyDecluster(&graph));

  Node* n_b = FindNodeByName(*graph, "b");
  ASSERT_NE(n_b, nullptr);
  EXPECT_EQ(GetXlaClusterForNode(*n_b), std::nullopt);

  Node* n_c = FindNodeByName(*graph, "c");
  ASSERT_NE(n_c, nullptr);
  EXPECT_EQ(GetXlaClusterForNode(*n_c), "cluster_0");
}
559
560 } // namespace
561 } // namespace tensorflow
562