• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/partially_decluster_pass.h"
17 
18 #include "absl/memory/memory.h"
19 #include "tensorflow/cc/framework/ops.h"
20 #include "tensorflow/cc/ops/array_ops.h"
21 #include "tensorflow/cc/ops/const_op.h"
22 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
23 #include "tensorflow/cc/ops/function_ops.h"
24 #include "tensorflow/cc/ops/sendrecv_ops.h"
25 #include "tensorflow/cc/ops/standard_ops.h"
26 #include "tensorflow/compiler/jit/defs.h"
27 #include "tensorflow/compiler/jit/test_util.h"
28 #include "tensorflow/compiler/jit/xla_cluster_util.h"
29 #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
30 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
31 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
32 #include "tensorflow/core/common_runtime/graph_constructor.h"
33 #include "tensorflow/core/common_runtime/graph_def_builder_util.h"
34 #include "tensorflow/core/framework/function.h"
35 #include "tensorflow/core/framework/function.pb.h"
36 #include "tensorflow/core/framework/node_def_util.h"
37 #include "tensorflow/core/framework/op.h"
38 #include "tensorflow/core/graph/algorithm.h"
39 #include "tensorflow/core/graph/graph_def_builder.h"
40 #include "tensorflow/core/lib/core/status_test_util.h"
41 #include "tensorflow/core/platform/test.h"
42 
43 namespace tensorflow {
44 namespace {
45 REGISTER_OP("FakeNullary").Output("out: int32");
46 
47 REGISTER_OP("FakeBinary")
48     .Input("host_in: int32")
49     .Input("device_in: int32")
50     .Output("host_out: int32")
51     .Output("device_out: int32");
52 
53 REGISTER_OP("FakeResourceVar").Output("out: resource");
54 
55 REGISTER_OP("FakeResourceUpdate")
56     .Input("in: resource")
57     .Output("out: resource")
58     .Output("something_else: int32");
59 
60 class FakeBinaryOp : public OpKernel {
61  public:
FakeBinaryOp(OpKernelConstruction * context)62   explicit FakeBinaryOp(OpKernelConstruction* context) : OpKernel(context) {}
63 
Compute(OpKernelContext * ctx)64   void Compute(OpKernelContext* ctx) override { CHECK(false); }
65 };
66 
67 class FakeResourceUpdateOp : public OpKernel {
68  public:
FakeResourceUpdateOp(OpKernelConstruction * context)69   explicit FakeResourceUpdateOp(OpKernelConstruction* context)
70       : OpKernel(context) {}
71 
Compute(OpKernelContext * ctx)72   void Compute(OpKernelContext* ctx) override { CHECK(false); }
73 };
74 
75 REGISTER_KERNEL_BUILDER(Name("FakeBinary")
76                             .Device(DEVICE_CPU)
77                             .HostMemory("host_in")
78                             .HostMemory("host_out"),
79                         FakeBinaryOp);
80 
81 REGISTER_KERNEL_BUILDER(
82     Name("FakeResourceUpdate").Device(DEVICE_CPU).HostMemory("something_else"),
83     FakeResourceUpdateOp);
84 
PartiallyDecluster(std::unique_ptr<Graph> * graph)85 Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
86   FixupSourceAndSinkEdges(graph->get());
87   // Assign all nodes to the CPU device.
88   static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
89   for (Node* n : (*graph)->nodes()) {
90     if (n->assigned_device_name().empty()) {
91       n->set_assigned_device_name(kCpuDevice);
92     }
93   }
94 
95   GraphOptimizationPassWrapper wrapper;
96   GraphOptimizationPassOptions opt_options =
97       wrapper.CreateGraphOptimizationPassOptions(graph);
98 
99   PartiallyDeclusterPass pass;
100   return pass.Run(opt_options);
101 }
102 
FindNodeByName(const Graph & graph,const string & name)103 Node* FindNodeByName(const Graph& graph, const string& name) {
104   for (Node* node : graph.nodes()) {
105     if (node->name() == name) {
106       return node;
107     }
108   }
109   return nullptr;
110 }
111 
GetInputsForNode(const Graph & graph,const string & node_name,std::vector<Node * > * inputs)112 bool GetInputsForNode(const Graph& graph, const string& node_name,
113                       std::vector<Node*>* inputs) {
114   const Node* node = FindNodeByName(graph, node_name);
115   if (node == nullptr) {
116     return false;
117   }
118   for (const Edge* e : node->in_edges()) {
119     inputs->push_back(e->src());
120   }
121   std::sort(inputs->begin(), inputs->end(), NodeComparatorName());
122   return true;
123 }
124 
// ClusteredProducer (cluster_0) feeds one consumer outside any cluster and one
// consumer inside cluster_0. After the pass, the unclustered consumer is
// expected to read from "ClusteredProducer/declustered" while the clustered
// consumer still reads from "ClusteredProducer".
TEST(PartiallyDeclusterPassTest, ClusteredAndUnclustered) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* input =
        ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
    Node* clustered_producer =
        ops::BinaryOp("FakeBinary", input, input,
                      builder.opts().WithName("ClusteredProducer"));
    ops::BinaryOp("FakeBinary", clustered_producer, input,
                  builder.opts().WithName("UnclusteredConsumer"));
    Node* clustered_consumer =
        ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input,
                      builder.opts().WithName("ClusteredConsumer"));
    clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
    clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
    // Fatal on failure: nothing below is meaningful without the graph.
    TF_ASSERT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(PartiallyDecluster(&graph));

  // Inputs come back sorted by node name.
  std::vector<Node*> unclustered_consumer_inputs;
  ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer",
                               &unclustered_consumer_inputs));
  ASSERT_EQ(unclustered_consumer_inputs.size(), 2);
  EXPECT_EQ(unclustered_consumer_inputs[0]->name(),
            "ClusteredProducer/declustered");
  EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input");

  std::vector<Node*> clustered_consumer_inputs;
  ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredConsumer",
                               &clustered_consumer_inputs));
  ASSERT_EQ(clustered_consumer_inputs.size(), 2);
  EXPECT_EQ(clustered_consumer_inputs[0]->name(), "ClusteredProducer");
  EXPECT_EQ(clustered_consumer_inputs[1]->name(), "Input");
}
160 
// ClusteredProducer (cluster_0) feeds the host-memory input of a consumer in
// cluster_1. After the pass, that consumer is expected to read from
// "ClusteredProducer/declustered".
TEST(PartiallyDeclusterPassTest, DifferentClusters) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* input =
        ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
    Node* clustered_producer =
        ops::BinaryOp("FakeBinary", input, input,
                      builder.opts().WithName("ClusteredProducer"));
    Node* consumer_in_different_cluster =
        ops::BinaryOp("FakeBinary", clustered_producer, input,
                      builder.opts().WithName("ConsumerInDifferentCluster"));
    Node* clustered_consumer =
        ops::BinaryOp("FakeBinary", input, {clustered_producer, 1},
                      builder.opts().WithName("ClusteredConsumer"));
    clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
    clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
    consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1");
    // Fatal on failure: nothing below is meaningful without the graph.
    TF_ASSERT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(PartiallyDecluster(&graph));

  // Inputs come back sorted by node name.
  std::vector<Node*> inputs;
  ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs));
  ASSERT_EQ(inputs.size(), 2);
  EXPECT_EQ(inputs[0]->name(), "ClusteredProducer/declustered");
  EXPECT_EQ(inputs[1]->name(), "Input");
}
189 
// ClusteredProducer (cluster_0) feeds only the device-memory input of a
// consumer in cluster_1. After the pass, that consumer is expected to still
// read from the original "ClusteredProducer" node.
TEST(PartiallyDeclusterPassTest, DontDeclusterIfUserIsDeviceMem) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* input =
        ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
    Node* clustered_producer =
        ops::BinaryOp("FakeBinary", input, input,
                      builder.opts().WithName("ClusteredProducer"));
    // The first input is hostmem and the second input is devicemem.
    Node* consumer_in_different_cluster =
        ops::BinaryOp("FakeBinary", input, clustered_producer,
                      builder.opts().WithName("ConsumerInDifferentCluster"));
    Node* clustered_consumer =
        ops::BinaryOp("FakeBinary", input, {clustered_producer, 1},
                      builder.opts().WithName("ClusteredConsumer"));
    clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
    clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
    consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1");
    // Fatal on failure: nothing below is meaningful without the graph.
    TF_ASSERT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(PartiallyDecluster(&graph));

  // Inputs come back sorted by node name.
  std::vector<Node*> inputs;
  ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs));
  ASSERT_EQ(inputs.size(), 2);
  EXPECT_EQ(inputs[0]->name(), "ClusteredProducer");
  EXPECT_EQ(inputs[1]->name(), "Input");
}
219 
// ClusteredProducer is a FakeResourceUpdate op (cluster_0) fed by a resource
// variable; its int32 output feeds a consumer in cluster_1. After the pass,
// that consumer is expected to still read from the original
// "ClusteredProducer" node.
TEST(PartiallyDeclusterPassTest, DontDuplicateResourceVarOps) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* input =
        ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
    Node* resource_var = ops::SourceOp("FakeResourceVar",
                                       builder.opts().WithName("ResourceVar"));
    Node* clustered_producer =
        ops::UnaryOp("FakeResourceUpdate", resource_var,
                     builder.opts().WithName("ClusteredProducer"));
    Node* consumer_in_different_cluster =
        ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input,
                      builder.opts().WithName("ConsumerInDifferentCluster"));
    Node* clustered_consumer =
        ops::BinaryOp("FakeBinary", input, {clustered_producer, 1},
                      builder.opts().WithName("ClusteredConsumer"));
    clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
    clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
    consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1");
    // Fatal on failure: nothing below is meaningful without the graph.
    TF_ASSERT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(PartiallyDecluster(&graph));

  // Inputs come back sorted by node name.
  std::vector<Node*> inputs;
  ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs));
  ASSERT_EQ(inputs.size(), 2);
  EXPECT_EQ(inputs[0]->name(), "ClusteredProducer");
  EXPECT_EQ(inputs[1]->name(), "Input");
}
250 
// ClusteredProducer0 -> ClusteredProducer1 -> UnclusteredConsumer, with both
// producers in cluster_0. After the pass, the unclustered consumer is expected
// to read from "ClusteredProducer1/declustered", which in turn is expected to
// read from "ClusteredProducer0/declustered".
TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* input =
        ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
    Node* clustered_producer_0 =
        ops::BinaryOp("FakeBinary", input, input,
                      builder.opts().WithName("ClusteredProducer0"));
    Node* clustered_producer_1 =
        ops::BinaryOp("FakeBinary", clustered_producer_0, input,
                      builder.opts().WithName("ClusteredProducer1"));
    ops::BinaryOp("FakeBinary", clustered_producer_1, input,
                  builder.opts().WithName("UnclusteredConsumer"));
    Node* clustered_consumer =
        ops::BinaryOp("FakeBinary", {clustered_producer_1, 1}, input,
                      builder.opts().WithName("ClusteredConsumer"));
    clustered_producer_0->AddAttr(kXlaClusterAttr, "cluster_0");
    clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0");
    clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
    // Fatal on failure: nothing below is meaningful without the graph.
    TF_ASSERT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(PartiallyDecluster(&graph));

  // Inputs come back sorted by node name.
  std::vector<Node*> unclustered_consumer_inputs;
  ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer",
                               &unclustered_consumer_inputs));
  ASSERT_EQ(unclustered_consumer_inputs.size(), 2);
  EXPECT_EQ(unclustered_consumer_inputs[0]->name(),
            "ClusteredProducer1/declustered");
  EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input");

  std::vector<Node*> declustered_producer_1_inputs;
  ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredProducer1/declustered",
                               &declustered_producer_1_inputs));
  ASSERT_EQ(declustered_producer_1_inputs.size(), 2);
  EXPECT_EQ(declustered_producer_1_inputs[0]->name(),
            "ClusteredProducer0/declustered");
  EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input");
}
291 
AddToCluster(absl::Span<Node * const> nodes,absl::string_view cluster_name)292 void AddToCluster(absl::Span<Node* const> nodes,
293                   absl::string_view cluster_name) {
294   for (Node* n : nodes) {
295     n->AddAttr(kXlaClusterAttr, string(cluster_name));
296   }
297 }
298 
// "shape" (an Add of two int32 placeholders) is the shape operand of "reshape"
// and both are placed in cluster_0. After the pass, "shape" is expected to no
// longer carry a cluster assignment.
TEST(PartiallyDeclusterPassTest, DeclusterMustBeConstantNodes) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  // Operands that are summed to form the reshape's target shape.
  Output shape_a = ops::Placeholder(scope.WithOpName("shape_a"), DT_INT32,
                                    ops::Placeholder::Attrs{});
  Output shape_b = ops::Placeholder(scope.WithOpName("shape_b"), DT_INT32,
                                    ops::Placeholder::Attrs{});
  Output shape = ops::Add(scope.WithOpName("shape"), shape_a, shape_b);

  Output reshape_input = ops::Placeholder(scope.WithOpName("reshape_input"),
                                          DT_FLOAT, ops::Placeholder::Attrs{});
  Output reshape =
      ops::Reshape(scope.WithOpName("reshape"), reshape_input, shape);

  AddToCluster({shape.node(), reshape.node()}, "cluster_0");

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(graph.get()));
  TF_ASSERT_OK(PartiallyDecluster(&graph));

  const Node* shape_node = FindNodeByName(*graph, "shape");
  ASSERT_NE(shape_node, nullptr);

  EXPECT_EQ(GetXlaClusterForNode(*shape_node), std::nullopt);
}
322 
// "shape" is computed as Add(Shape(Mul(input_b, input_b)), input_a) and is the
// shape operand of "reshape"; mul, shape_of_mul, shape and reshape are all in
// cluster_0. After the pass, "shape" is expected to remain in cluster_0.
TEST(PartiallyDeclusterPassTest, DeclusteringStopsAtMetadataOps) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output input_a = ops::Placeholder(s.WithOpName("input_a"), DT_INT32,
                                    ops::Placeholder::Attrs{});
  // Op name matches the variable name (it was previously "shape_b", a
  // copy-paste leftover; no assertion below looks this node up by name).
  Output input_b = ops::Placeholder(s.WithOpName("input_b"), DT_FLOAT,
                                    ops::Placeholder::Attrs{});
  Output mul = ops::Mul(s.WithOpName("mul"), input_b, input_b);
  Output shape_of_mul = ops::Shape(s.WithOpName("shape_of_mul"), mul);

  Output shape = ops::Add(s.WithOpName("shape"), shape_of_mul, input_a);

  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
                                          DT_FLOAT, ops::Placeholder::Attrs{});
  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);

  AddToCluster({mul.node(), shape_of_mul.node(), shape.node(), reshape.node()},
               "cluster_0");

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(s.ToGraph(graph.get()));
  TF_ASSERT_OK(PartiallyDecluster(&graph));

  const Node* n = FindNodeByName(*graph, "shape");
  ASSERT_NE(n, nullptr);

  EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
}
350 
// "shape" (cluster_1) is the shape operand of "reshape" (cluster_0). After the
// pass, "shape" is expected to still be in cluster_1.
TEST(PartiallyDeclusterPassTest, EdgeAcrossDifferentClusters) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  // Operands that are summed to form the reshape's target shape.
  Output shape_a = ops::Placeholder(scope.WithOpName("shape_a"), DT_INT32,
                                    ops::Placeholder::Attrs{});
  Output shape_b = ops::Placeholder(scope.WithOpName("shape_b"), DT_INT32,
                                    ops::Placeholder::Attrs{});
  Output shape = ops::Add(scope.WithOpName("shape"), shape_a, shape_b);

  Output reshape_input = ops::Placeholder(scope.WithOpName("reshape_input"),
                                          DT_FLOAT, ops::Placeholder::Attrs{});
  Output reshape =
      ops::Reshape(scope.WithOpName("reshape"), reshape_input, shape);

  // Producer and consumer deliberately live in different clusters.
  AddToCluster({reshape.node()}, "cluster_0");
  AddToCluster({shape.node()}, "cluster_1");

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(graph.get()));
  TF_ASSERT_OK(PartiallyDecluster(&graph));

  const Node* shape_node = FindNodeByName(*graph, "shape");
  ASSERT_NE(shape_node, nullptr);

  EXPECT_EQ(GetXlaClusterForNode(*shape_node), "cluster_1");
}
375 
// Same graph as a shape-feeding-reshape cluster, but "shape" is explicitly
// assigned to an XLA_GPU device before the pass runs. After the pass, "shape"
// is expected to remain in cluster_0.
TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  // Operands that are summed to form the reshape's target shape.
  Output shape_a = ops::Placeholder(scope.WithOpName("shape_a"), DT_INT32,
                                    ops::Placeholder::Attrs{});
  Output shape_b = ops::Placeholder(scope.WithOpName("shape_b"), DT_INT32,
                                    ops::Placeholder::Attrs{});
  Output shape = ops::Add(scope.WithOpName("shape"), shape_a, shape_b);

  Output reshape_input = ops::Placeholder(scope.WithOpName("reshape_input"),
                                          DT_FLOAT, ops::Placeholder::Attrs{});
  Output reshape =
      ops::Reshape(scope.WithOpName("reshape"), reshape_input, shape);

  AddToCluster({shape.node(), reshape.node()}, "cluster_0");

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // This is needed to register the XLA_GPU device.
  std::vector<std::unique_ptr<Device>> devices;
  TF_ASSERT_OK(DeviceFactory::AddDevices(
      SessionOptions(), "/job:localhost/replica:0/task:0", &devices));

  // Scope::ToGraph loses the assigned device name since it goes through
  // GraphDef/NodeDef which does not have a field for the assigned device name,
  // so the device is set on the converted graph's node instead.
  Node* shape_node = FindNodeByName(*graph, "shape");
  ASSERT_NE(shape_node, nullptr);
  shape_node->set_assigned_device_name(
      "/job:localhost/replica:0/task:0/device:XLA_GPU:0");

  TF_ASSERT_OK(PartiallyDecluster(&graph));

  EXPECT_EQ(GetXlaClusterForNode(*shape_node), "cluster_0");
}
409 
// ClusteredProducer0 -> ClusteredProducer1 -> UnclusteredConsumer, with both
// producers in cluster_0 and no clustered consumer at all. After the pass,
// neither original producer node is expected to be present in the graph.
TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) {
  const char* const kClusteredProducer0Name = "ClusteredProducer0";
  const char* const kClusteredProducer1Name = "ClusteredProducer1";

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    Node* input =
        ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
    Node* clustered_producer_0 =
        ops::BinaryOp("FakeBinary", input, input,
                      builder.opts().WithName(kClusteredProducer0Name));
    Node* clustered_producer_1 =
        ops::BinaryOp("FakeBinary", clustered_producer_0, input,
                      builder.opts().WithName(kClusteredProducer1Name));
    ops::BinaryOp("FakeBinary", clustered_producer_1, input,
                  builder.opts().WithName("UnclusteredConsumer"));
    clustered_producer_0->AddAttr(kXlaClusterAttr, "cluster_0");
    clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0");
    // Fatal on failure: with an unbuilt graph the "node is absent" checks
    // below would pass vacuously.
    TF_ASSERT_OK(GraphDefBuilderToGraph(builder, graph.get()));
  }

  TF_ASSERT_OK(PartiallyDecluster(&graph));
  EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer0Name), nullptr);
  EXPECT_EQ(FindNodeByName(*graph, kClusteredProducer1Name), nullptr);
}
436 
// A chain Shape -> Rank -> Size -> Shape ("b" through "e"), all in cluster_0
// and fed by an unclustered placeholder. After the pass, none of the four
// nodes is expected to carry a cluster assignment.
TEST(PartiallyDeclusterPassTest, MetadataOpsDontStartClusters) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  tensorflow::Scope clustered = root.WithXlaCluster("cluster_0");

  Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
  Output b = ops::Shape(clustered.WithOpName("b"), a);
  Output c = ops::Rank(clustered.WithOpName("c"), b);
  Output d = ops::Size(clustered.WithOpName("d"), c);
  (void)ops::Shape(clustered.WithOpName("e"), d);

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  TF_ASSERT_OK(PartiallyDecluster(&graph));

  // Same check for every node in the chain, in b, c, d, e order.
  for (const char* node_name : {"b", "c", "d", "e"}) {
    SCOPED_TRACE(node_name);
    Node* node = FindNodeByName(*graph, node_name);
    ASSERT_NE(node, nullptr);
    EXPECT_EQ(GetXlaClusterForNode(*node), std::nullopt);
  }
}
468 
// "b" (Add) and "c" (Rank of b) are in cluster_0; "c" feeds both inputs of an
// unclustered FakeBinary node "e". After the pass, both "b" and "c" are
// expected to remain in cluster_0.
TEST(PartiallyDeclusterPassTest, MetaConsumersArentDeclustered) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  tensorflow::Scope clustered = root.WithXlaCluster("cluster_0");

  Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
  Output b = ops::Add(clustered.WithOpName("b"), a, a);
  Output c = ops::Rank(clustered.WithOpName("c"), b);

  // Unclustered consumer of the metadata op's output.
  Output e;
  TF_ASSERT_OK(
      CreateOutputWithScope("FakeBinary", {c, c}, root.WithOpName("e"), &e));

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(PartiallyDecluster(&graph));

  // Same check for "b" and then "c".
  for (const char* node_name : {"b", "c"}) {
    SCOPED_TRACE(node_name);
    Node* node = FindNodeByName(*graph, node_name);
    ASSERT_NE(node, nullptr);
    EXPECT_EQ(GetXlaClusterForNode(*node), "cluster_0");
  }
}
492 
// Constants "b" and "c" are the begin/size operands of Slice "d"; all three
// are in cluster_0. After the pass, both constants are expected to remain in
// cluster_0.
TEST(PartiallyDeclusterPassTest, ConstInputsToSliceArentDeclustered) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  tensorflow::Scope clustered = root.WithXlaCluster("cluster_0");

  Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT,
                              ops::Placeholder::Attrs{{4}});
  Output b = ops::Const(clustered.WithOpName("b"), {1});
  Output c = ops::Const(clustered.WithOpName("c"), {2});
  Output d = ops::Slice(clustered.WithOpName("d"), a, b, c);

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  TF_ASSERT_OK(PartiallyDecluster(&graph));

  // Same check for "b" and then "c".
  for (const char* node_name : {"b", "c"}) {
    SCOPED_TRACE(node_name);
    Node* node = FindNodeByName(*graph, node_name);
    ASSERT_NE(node, nullptr);
    EXPECT_EQ(GetXlaClusterForNode(*node), "cluster_0");
  }
}
514 
// Constants "b" and "c" feed Slice "slice" (all in cluster_0). "b" also has a
// control input from an Identity that hangs off a Switch driven by a LoopCond,
// and that Identity is assigned to an XLA_GPU device. After the pass, "b" is
// expected to have no cluster while "c" is expected to remain in cluster_0.
TEST(PartiallyDeclusterPassTest,
     ConstInLoopWithCrossDeviceControlInputsAreDeclustered) {
  // Based on DontClusterTheSpecialIdentityDrivingConstsInLoop in
  // mark_for_compilation_pass_test.cc
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();
  tensorflow::Scope clustered = root.WithXlaCluster("cluster_0");

  // Clustered slice with constant begin/size operands.
  Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT,
                              ops::Placeholder::Attrs{{4}});
  Output b = ops::Const(clustered.WithOpName("b"), {1});
  Output c = ops::Const(clustered.WithOpName("c"), {2});
  Output slice = ops::Slice(clustered.WithOpName("slice"), a, b, c);

  // Unclustered loop-condition machinery whose Identity drives "b" through a
  // control edge.
  Output cond = ops::Placeholder(root.WithOpName("cond"), DT_BOOL);
  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  Output loop_cond = ops::LoopCond(root.WithOpName("loop_cond"), cond);
  ops::Switch switch_node(root.WithOpName("switch"), value, loop_cond);
  Output identity =
      ops::Identity(root.WithOpName("identity"), switch_node.output_true);
  root.graph()->AddControlEdge(identity.node(), b.node());

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // This is needed to register the XLA_GPU device.
  std::vector<std::unique_ptr<Device>> devices;
  TF_ASSERT_OK(DeviceFactory::AddDevices(
      SessionOptions(), "/job:localhost/replica:0/task:0", &devices));

  // Scope::ToGraph loses the assigned device name since it goes through
  // GraphDef/NodeDef which does not have a field for the assigned device name,
  // so the device is set on the converted graph's node instead.
  Node* identity_node = FindNodeByName(*graph, "identity");
  ASSERT_NE(identity_node, nullptr);
  identity_node->set_assigned_device_name(
      "/job:localhost/replica:0/task:0/device:XLA_GPU:0");

  TF_ASSERT_OK(PartiallyDecluster(&graph));

  Node* const_b = FindNodeByName(*graph, "b");
  ASSERT_NE(const_b, nullptr);
  EXPECT_EQ(GetXlaClusterForNode(*const_b), std::nullopt);

  Node* const_c = FindNodeByName(*graph, "c");
  ASSERT_NE(const_c, nullptr);
  EXPECT_EQ(GetXlaClusterForNode(*const_c), "cluster_0");
}
559 
560 }  // namespace
561 }  // namespace tensorflow
562