1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
17 
18 #include "absl/strings/match.h"
19 #include "tensorflow/cc/framework/ops.h"
20 #include "tensorflow/cc/ops/const_op.h"
21 #include "tensorflow/cc/ops/data_flow_ops.h"
22 #include "tensorflow/cc/ops/function_ops.h"
23 #include "tensorflow/cc/ops/functional_ops.h"
24 #include "tensorflow/cc/ops/list_ops.h"
25 #include "tensorflow/cc/ops/math_ops.h"
26 #include "tensorflow/cc/ops/resource_variable_ops.h"
27 #include "tensorflow/cc/ops/standard_ops.h"
28 #include "tensorflow/compiler/tf2xla/shape_util.h"
29 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
30 #include "tensorflow/compiler/tf2xla/type_util.h"
31 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
32 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
33 #include "tensorflow/compiler/xla/client/client_library.h"
34 #include "tensorflow/compiler/xla/client/local_client.h"
35 #include "tensorflow/compiler/xla/client/xla_builder.h"
36 #include "tensorflow/compiler/xla/literal.h"
37 #include "tensorflow/compiler/xla/service/hlo.pb.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/status_macros.h"
40 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
41 #include "tensorflow/core/common_runtime/function.h"
42 #include "tensorflow/core/common_runtime/graph_constructor.h"
43 #include "tensorflow/core/framework/common_shape_fns.h"
44 #include "tensorflow/core/framework/function.h"
45 #include "tensorflow/core/framework/function.pb.h"
46 #include "tensorflow/core/framework/function_testlib.h"
47 #include "tensorflow/core/framework/graph_to_functiondef.h"
48 #include "tensorflow/core/framework/node_def_util.h"
49 #include "tensorflow/core/framework/resource_mgr.h"
50 #include "tensorflow/core/framework/tensor.h"
51 #include "tensorflow/core/framework/tensor_testutil.h"
52 #include "tensorflow/core/framework/types.pb.h"
53 #include "tensorflow/core/graph/algorithm.h"
54 #include "tensorflow/core/graph/graph.h"
55 #include "tensorflow/core/lib/core/status_test_util.h"
56 #include "tensorflow/core/platform/test.h"
57 #include "tensorflow/core/public/version.h"
58 
59 namespace tensorflow {
60 
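// Test fixture that provides an XLA local client, registers the XLA
// compilation kernels, and creates an empty function library for each test.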
61 class XlaCompilerTest : public ::testing::Test {
62  protected:
63   void SetUp() override {
64     client_ = xla::ClientLibrary::LocalClientOrDie();
65 
66     XlaOpRegistry::RegisterCompilationKernels();
67 
68     FunctionDefLibrary flib;
69     flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
70   }
71 
72   XlaCompiler::Options DefaultOptions() {
73     XlaCompiler::Options options;
74     options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
75     options.client = client_;
76     options.flib_def = flib_def_.get();
77     return options;
78   }
79 
80   FunctionLibraryDefinition* LocalFlibDef(XlaCompiler* compiler) {
81     return compiler->local_flib_def_.get();
82   }
83 
84   xla::Client* client_;
85   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
86 };
87 
88 namespace {
89 
90 // Helper class to test the ability to pass resources through to XLA
91 // compiled kernels.
92 class DummyResourceForTest : public ResourceBase {
93  public:
94   string DebugString() const override { return "dummy"; }
95   void Increment() { ++value_; }
96   int Get() { return value_; }
97 
98  private:
99   int value_ = 0;
100 };
101 
102 class DummyReadResourceOp : public XlaOpKernel {
103  public:
104   explicit DummyReadResourceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
105   void Compile(XlaOpKernelContext* ctx) override {
106     ResourceMgr* rm = ctx->op_kernel_context()->resource_manager();
107     OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
108     DummyResourceForTest* dummy;
109     OP_REQUIRES_OK(ctx, rm->Lookup<DummyResourceForTest>(
110                             rm->default_container(), "dummy", &dummy));
111     dummy->Increment();
112     dummy->Unref();
113 
114     ctx->SetOutput(0, ctx->Input(0));
115     ctx->SetOutput(1, ctx->Input(0));
116   }
117 };
118 
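// Hand-written analogue of a generated C++ op wrapper: adds a
// DummyReadResource node to the scope's graph and exposes its two outputs.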
119 class DummyReadResourceCC {
120  public:
121   DummyReadResourceCC(const Scope& scope, const Input& value) {
122     if (!scope.ok()) return;
123     auto _value = ops::AsNodeOut(scope, value);
124     if (!scope.ok()) return;
125     Node* ret;
126     const auto unique_name = scope.GetUniqueNameForOp("DummyReadResource");
127     auto builder = NodeBuilder(unique_name, "DummyReadResource").Input(_value);
128     scope.UpdateBuilder(&builder);
129     scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
130     if (!scope.ok()) return;
131     scope.UpdateStatus(scope.DoShapeInference(ret));
132     if (!scope.ok()) return;
133     this->output1_ = Output(ret, 0);
134     this->output2_ = Output(ret, 1);
135   }
136 
137   Output output1_;
138   Output output2_;
139 };
140 
141 REGISTER_OP("DummyReadResource")
142     .Input("input: int32")
143     .Output("output1: int32")
144     .Output("output2: int32")
145     .SetShapeFn(shape_inference::UnknownShape)
146     .Doc(R"doc(
147 A dummy Op.
148 
149 input: dummy input.
150 output1: dummy output.
151 output2: dummy output.
152 )doc");
153 
154 REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp);
155 
156 // DummyDuplicateOp is present purely to test multiple REGISTER_XLA_OP calls
157 // on the same Op name below.
158 class DummyDuplicateOp : public XlaOpKernel {
159  public:
160   explicit DummyDuplicateOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
161   void Compile(XlaOpKernelContext* ctx) override {
162     ctx->SetOutput(0, ctx->Input(0));
163   }
164 };
165 
166 REGISTER_OP("DummyDuplicateOp")
167     .Input("input: int32")
168     .Output("output: int32")
169     .Doc(R"doc(
170 A dummy Op.
171 
172 input: dummy input.
173 output: dummy output.
174 )doc");
175 
176 REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT),
177                 DummyDuplicateOp);
178 REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT),
179                 DummyDuplicateOp);
180 
181 // Tests compilation and execution of an empty graph.
182 TEST_F(XlaCompilerTest, EmptyReturnValues) {
183   XlaCompiler compiler(DefaultOptions());
184 
185   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
186   XlaCompiler::CompilationResult result;
187   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
188                                      std::move(graph),
189                                      /*args=*/{}, &result));
190 
191   TF_ASSERT_OK(client_->Execute(*result.computation, {}).status());
192 }
193 
194 // Tests compilation and execution of a graph that adds two tensors.
195 TEST_F(XlaCompilerTest, Simple) {
196   // Builds a graph that adds two Tensors.
197   Scope scope = Scope::NewRootScope().ExitOnError();
198   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
199   auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
200   auto c = ops::Add(scope.WithOpName("C"), a, b);
201   auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
202   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
203   TF_ASSERT_OK(scope.ToGraph(graph.get()));
204 
205   // Builds a description of the arguments.
206   std::vector<XlaCompiler::Argument> args(2);
207   args[0].kind = XlaCompiler::Argument::kParameter;
208   args[0].type = DT_INT32;
209   args[0].shape = TensorShape({2});
210   args[1].kind = XlaCompiler::Argument::kParameter;
211   args[1].type = DT_INT32;
212   args[1].shape = TensorShape({2});
213 
214   // Compiles the graph.
215   XlaCompiler compiler(DefaultOptions());
216 
217   XlaCompiler::CompilationResult result;
218   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
219                                      std::move(graph), args, &result));
220 
221   // Tests that the generated computation works.
222   xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
223   xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
224   std::unique_ptr<xla::GlobalData> param0_data =
225       client_->TransferToServer(param0_literal).value();
226   std::unique_ptr<xla::GlobalData> param1_data =
227       client_->TransferToServer(param1_literal).value();
228 
229   std::unique_ptr<xla::GlobalData> actual =
230       client_
231           ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
232           .value();
233   xla::Literal actual_literal = client_->Transfer(*actual).value();
234 
235   xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({4, 143});
236   xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
237   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
238 }
239 
240 // Tests compilation of a graph where the _Retval node is not necessarily last
241 // amongst the graph nodes in construction order, and always_return_tuple is
242 // false. Regression test for bug where the wrong value was returned.
243 TEST_F(XlaCompilerTest, OutOfOrderGraph) {
244   Scope scope = Scope::NewRootScope().ExitOnError();
245   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
246   auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
247   // The _Retval node is not last in construction order.
248   auto d = ops::_Retval(scope.WithOpName("D"), a, 0);
249   auto c = ops::Add(scope.WithOpName("C"), a, b);
250 
251   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
252   TF_ASSERT_OK(scope.ToGraph(graph.get()));
253 
254   // Builds a description of the arguments.
255   std::vector<XlaCompiler::Argument> args(2);
256   args[0].kind = XlaCompiler::Argument::kParameter;
257   args[0].type = DT_INT32;
258   args[0].shape = TensorShape({2});
259   args[1].kind = XlaCompiler::Argument::kParameter;
260   args[1].type = DT_INT32;
261   args[1].shape = TensorShape({2});
262 
263   // Compiles the graph.
264   XlaCompiler compiler(DefaultOptions());
265 
266   XlaCompiler::CompileOptions compile_options;
267   compile_options.always_return_tuple = false;
268   XlaCompiler::CompilationResult result;
269   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
270                                      args, &result));
271 
272   // Tests that the generated computation works.
273   xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
274   xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
275   std::unique_ptr<xla::GlobalData> param0_data =
276       client_->TransferToServer(param0_literal).value();
277   std::unique_ptr<xla::GlobalData> param1_data =
278       client_->TransferToServer(param1_literal).value();
279 
280   std::unique_ptr<xla::GlobalData> actual =
281       client_
282           ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
283           .value();
284   xla::Literal actual_literal = client_->Transfer(*actual).value();
285 
286   EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal));
287 }
288 
289 // Tests that the compiler can correctly propagate the layout assigned by
290 // shape_representation_fn_ to resource returns that have not been written to.
291 TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForUnwrittenResource) {
292   Scope scope = Scope::NewRootScope().ExitOnError();
293   auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 0);
294   auto d = ops::_Retval(scope.WithOpName("D"), var, 0);
295   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
296   TF_ASSERT_OK(scope.ToGraph(graph.get()));
297 
298   // Builds a description of the arguments.
299   std::vector<XlaCompiler::Argument> args(1);
300   args[0].kind = XlaCompiler::Argument::kResource;
301   args[0].resource_kind = XlaResource::kVariable;
302   args[0].initialized = true;
303   args[0].type = DT_INT32;
304   args[0].shape = TensorShape({2, 3});
305 
306   auto options = DefaultOptions();
307   XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns;
308   shape_determination_fns.shape_representation_fn =
309       [](const TensorShape& shape, DataType dt, bool use_fast_memory,
310          XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
311     xla::Shape xla_shape;
312     TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
313     *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
314     return xla_shape;
315   };
316   options.shape_determination_fns = shape_determination_fns;
317   // Compiles the graph.
318   XlaCompiler compiler(options);
319 
320   XlaCompiler::CompilationResult result;
321   XlaCompiler::CompileOptions compile_options;
322   compile_options.return_updated_values_for_all_resources = true;
323   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
324                                      args, &result));
325   xla::Shape transposed =
326       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
327   // Check that the return shapes are correctly transposed.
328   EXPECT_EQ(result.xla_output_shape,
329             xla::ShapeUtil::MakeTupleShape({transposed}));
330 }
331 
332 // Tests that the compiler correctly propagates the fast-memory attribute for
333 // an input resource variable.
334 TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForFastMemVar) {
335   Scope scope = Scope::NewRootScope().ExitOnError();
336   auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 0);
337   auto d = ops::_Retval(scope.WithOpName("D"), var, 0);
338   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
339   TF_ASSERT_OK(scope.ToGraph(graph.get()));
340 
341   // Builds a description of the arguments.
342   std::vector<XlaCompiler::Argument> args(1);
343   args[0].kind = XlaCompiler::Argument::kResource;
344   args[0].resource_kind = XlaResource::kVariable;
345   args[0].initialized = true;
346   args[0].type = DT_INT32;
347   args[0].shape = TensorShape({2, 3});
348   args[0].fast_mem = true;
349 
350   auto options = DefaultOptions();
351   int fast_mem_arg_count = 0;
352   XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns;
353   shape_determination_fns.shape_representation_fn =
354       [&fast_mem_arg_count](
355           const TensorShape& shape, DataType dt, bool use_fast_memory,
356           XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
357     xla::Shape xla_shape;
358     TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
359     *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
360     if (use_fast_memory) {
361       fast_mem_arg_count++;
362     }
363     return xla_shape;
364   };
365   options.shape_determination_fns = shape_determination_fns;
366   // Compiles the graph.
367   XlaCompiler compiler(options);
368 
369   XlaCompiler::CompilationResult result;
370   XlaCompiler::CompileOptions compile_options;
371   compile_options.return_updated_values_for_all_resources = true;
372   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
373                                      args, &result));
374   // Count 2: one for argument, one for the return value.
375   EXPECT_EQ(fast_mem_arg_count, 2);
376 }
377 
378 // Tests that the compiler can correctly propagate the layout assigned by
379 // shape_representation_fn_ to return types.
380 TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
381   Scope scope = Scope::NewRootScope().ExitOnError();
382   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
383   auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
384   // Adds an identity op around the resource to make sure identity ops propagate
385   // resources correctly.
386   auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
387   auto write = ops::AssignAddVariableOp(scope, identity, a);
388   auto read = ops::ReadVariableOp(
389       scope.WithControlDependencies(std::vector<Operation>{write}), var,
390       DT_INT32);
391   auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
392   auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
393   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
394   TF_ASSERT_OK(scope.ToGraph(graph.get()));
395 
396   // Builds a description of the arguments.
397   std::vector<XlaCompiler::Argument> args(2);
398   args[0].kind = XlaCompiler::Argument::kParameter;
399   args[0].type = DT_INT32;
400   args[0].shape = TensorShape({2, 3});
401   args[1].kind = XlaCompiler::Argument::kResource;
402   args[1].resource_kind = XlaResource::kVariable;
403   args[1].initialized = true;
404   args[1].type = DT_INT32;
405   args[1].shape = TensorShape({2, 3});
406 
407   auto options = DefaultOptions();
408   XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns;
409   shape_determination_fns.shape_representation_fn =
410       [](const TensorShape& shape, DataType dt, bool use_fast_memory,
411          XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
412     xla::Shape xla_shape;
413     TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
414     *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
415     return xla_shape;
416   };
417   options.shape_determination_fns = shape_determination_fns;
418   // Compiles the graph.
419   XlaCompiler compiler(options);
420 
421   XlaCompiler::CompilationResult result;
422   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
423                                      std::move(graph), args, &result));
424   xla::Shape transposed =
425       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
426   // Check that the return shapes are correctly transposed.
427   EXPECT_EQ(result.xla_output_shape,
428             xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
429   EXPECT_EQ(result.computation->GetProgramShape().value().result(),
430             xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
431 }
432 
433 // The layout of a resource variable shouldn't change after a transpose.
434 TEST_F(XlaCompilerTest, TransposeVariables) {
435   Scope scope = Scope::NewRootScope().ExitOnError();
436   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
437   auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
438   // Adds an identity op around the resource to make sure identity ops propagate
439   // resources correctly.
440   auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
441   auto write = ops::AssignAddVariableOp(scope, identity, a);
442   auto read = ops::ReadVariableOp(
443       scope.WithControlDependencies(std::vector<Operation>{write}), var,
444       DT_INT32);
445   auto transposed_read = ops::Transpose(scope, read, {1, 0});
446   auto reshape = ops::Reshape(scope, transposed_read, {2, 3});
447   auto d = ops::_Retval(scope.WithOpName("D"), reshape, 0);
448   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
449   TF_ASSERT_OK(scope.ToGraph(graph.get()));
450 
451   // Builds a description of the arguments.
452   std::vector<XlaCompiler::Argument> args(2);
453   args[0].kind = XlaCompiler::Argument::kParameter;
454   args[0].type = DT_INT32;
455   args[0].shape = TensorShape({2, 3});
456   args[1].kind = XlaCompiler::Argument::kResource;
457   args[1].resource_kind = XlaResource::kVariable;
458   args[1].initialized = true;
459   args[1].type = DT_INT32;
460   args[1].shape = TensorShape({2, 3});
461   // Compiles the graph.
462   XlaCompiler compiler(DefaultOptions());
463 
464   XlaCompiler::CompilationResult result;
465   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose",
466                                      std::move(graph), args, &result));
467   xla::Shape transposed =
468       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0});
469   // Check that the return shapes are correctly transposed.
470   EXPECT_EQ(result.xla_output_shape,
471             xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
472 }
473 
474 // An unranked FakeParam returns a zero-element tensor.
475 TEST_F(XlaCompilerTest, UnrankedFakeParam) {
476   Scope scope = Scope::NewRootScope().ExitOnError();
477   PartialTensorShape shape;
478   auto a = ops::FakeParam(scope, DT_INT32, shape);
479   auto ret = ops::_Retval(scope.WithOpName("D"), a, 0);
480   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
481   TF_ASSERT_OK(scope.ToGraph(graph.get()));
482 
483   // Compiles the graph.
484   XlaCompiler compiler(DefaultOptions());
485 
486   XlaCompiler::CompilationResult result;
487   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "compile",
488                                      std::move(graph), {}, &result));
489   // Check that the return shape is a tuple containing a zero-element vector.
490   EXPECT_EQ(result.xla_output_shape,
491             xla::ShapeUtil::MakeTupleShape(
492                 {xla::ShapeUtil::MakeShape(xla::S32, {0})}));
493 }
494 
495 // Tests that the compiler doesn't reorder the parameters.
496 TEST_F(XlaCompilerTest, MixedOrderArguments) {
497   for (bool swap_order : {false, true}) {
498     Scope scope = Scope::NewRootScope().ExitOnError();
499     auto var =
500         ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, swap_order ? 0 : 1);
501     auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, swap_order ? 1 : 0);
502     // Adds an identity op around the resource to make sure identity ops
503     // propagate resources correctly.
504     auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
505     auto write = ops::AssignAddVariableOp(scope, identity, a);
506     auto read = ops::ReadVariableOp(
507         scope.WithControlDependencies(std::vector<Operation>{write}), var,
508         DT_INT32);
509     auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
510     auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
511     std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
512     TF_ASSERT_OK(scope.ToGraph(graph.get()));
513 
514     // Builds a description of the arguments.
515     std::vector<XlaCompiler::Argument> args(2);
516     args[0].kind = XlaCompiler::Argument::kParameter;
517     args[0].type = DT_INT32;
518     args[0].shape = TensorShape({2});
519     args[1].kind = XlaCompiler::Argument::kResource;
520     args[1].resource_kind = XlaResource::kVariable;
521     args[1].initialized = true;
522     args[1].type = DT_INT32;
523     args[1].shape = TensorShape({2});
524 
525     if (swap_order) {
526       // Even after swapping arguments, the compiler should maintain the new
527       // ordering of parameters.
528       std::swap(args[0], args[1]);
529     }
530     // Compiles the graph.
531     XlaCompiler compiler(DefaultOptions());
532 
533     XlaCompiler::CompileOptions compile_options;
534     compile_options.always_return_tuple = false;
535     XlaCompiler::CompilationResult result;
536     TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
537                                        args, &result));
538 
539     EXPECT_THAT(result.input_mapping, ::testing::ElementsAre(0, 1));
540   }
541 }
542 
543 TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
544   // Builds a graph that reshapes a tensor, but with the shape not
545   // statically known.
546   Scope scope = Scope::NewRootScope().ExitOnError();
547   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
548   auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
549   auto c = ops::Reshape(scope.WithOpName("C"), a, b);
550   auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
551   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
552   TF_ASSERT_OK(scope.ToGraph(graph.get()));
553 
554   // Builds a description of the arguments.
555   std::vector<XlaCompiler::Argument> args(2);
556   args[0].kind = XlaCompiler::Argument::kParameter;
557   args[0].type = DT_INT32;
558   args[0].shape = TensorShape({2});
559   args[1].kind = XlaCompiler::Argument::kParameter;
560   args[1].type = DT_INT32;
561   args[1].shape = TensorShape({2});
562 
563   // Compiles the graph.
564   XlaCompiler compiler(DefaultOptions());
565 
566   XlaCompiler::CompilationResult result;
567   Status status =
568       compiler.CompileGraph(XlaCompiler::CompileOptions(), "reshape",
569                             std::move(graph), args, &result);
570   EXPECT_FALSE(status.ok());
571   EXPECT_TRUE(
572       absl::StrContains(status.error_message(), "depends on a parameter"))
573       << status.error_message();
574   EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node C}}"))
575       << status.error_message();
576   EXPECT_TRUE(absl::StrContains(status.error_message(),
577                                 "must be a compile-time constant"))
578       << status.error_message();
579 }
580 
581 // Tests handling of compile-time constant outputs.
582 TEST_F(XlaCompilerTest, ConstantOutputs) {
583   // Builds a graph with one compile-time constant output and one data-dependent
584   // output, i.e.,
585   // func(a) { b=7; c=-a; return b, c; }
586   Scope scope = Scope::NewRootScope().ExitOnError();
587   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
588   auto b = ops::Const<int32>(scope.WithOpName("B"), 7);
589   auto c = ops::Neg(scope.WithOpName("C"), a);
590   auto d = ops::_Retval(scope.WithOpName("D"), b, 0);
591   auto e = ops::_Retval(scope.WithOpName("E"), c, 1);
592   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
593   TF_ASSERT_OK(scope.ToGraph(graph.get()));
594 
595   // Builds a description of the arguments.
596   std::vector<XlaCompiler::Argument> args(1);
597   args[0].kind = XlaCompiler::Argument::kParameter;
598   args[0].type = DT_INT32;
599   args[0].shape = TensorShape({2});
600 
601   XlaCompiler::Options options = DefaultOptions();
602   XlaCompiler compiler(options);
603 
604   {
605     std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
606     CopyGraph(*graph, graph_copy.get());
607 
608     XlaCompiler::CompileOptions compile_options;
609     XlaCompiler::CompilationResult result;
610     TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
611                                        std::move(graph_copy), args, &result));
612 
613     ASSERT_EQ(2, result.outputs.size());
614     EXPECT_FALSE(result.outputs[0].is_constant);
615     EXPECT_FALSE(result.outputs[1].is_constant);
616 
617     // Tests that the generated computation works.
618     xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
619     std::unique_ptr<xla::GlobalData> param0_data =
620         client_->TransferToServer(param0_literal).value();
621 
622     std::unique_ptr<xla::GlobalData> actual =
623         client_->Execute(*result.computation, {param0_data.get()}).value();
624     xla::Literal actual_literal = client_->Transfer(*actual).value();
625 
626     xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(7);
627     xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
628     xla::Literal expected =
629         xla::LiteralUtil::MakeTuple({&expected0, &expected1});
630     EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal));
631   }
632 }
633 
634 TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) {
635   // Define a function with one compile-time constant output and one
636   // data-dependent output.
637   // @function.Defun(noinline=True)
638   // foo(a) {b=7; return b, a; }
639   const Tensor seven = test::AsScalar<int>(7);
640   FunctionDef fdef = FunctionDefHelper::Create(
641       "foo", {"a_0:int32"}, {"const:int32", "a:int32"}, {},
642       {
643           {{"Const"}, "Const", {}, {{"dtype", DT_INT32}, {"value", seven}}},
644       },
645       {{"a", "a_0"}, {"const", "Const:output:0"}});
646   (*fdef.mutable_attr())["_noinline"].set_b(true);
647   FunctionDefLibrary fdef_lib;
648   *(fdef_lib.add_function()) = fdef;
649   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
650   {
651     Scope scope = Scope::NewRootScope().ExitOnError();
652     TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
653     auto arg = ops::_Arg(scope.WithOpName("input_arg"), DT_INT32, 0);
654     NodeDef foo;
655     foo.set_name("foo");
656     foo.set_op("foo");
657     *foo.add_input() = "input_arg";
658     Status status;
659     scope.graph()->AddNode(foo, &status);
660     TF_ASSERT_OK(status);
661     NodeDef retval_1;
662     retval_1.set_name("retval_0");
663     retval_1.set_op(FunctionLibraryDefinition::kRetOp);
664     *retval_1.add_input() = "foo";
665     (*retval_1.mutable_attr())["T"].set_type(DT_INT32);
666     (*retval_1.mutable_attr())["index"].set_i(0);
667     scope.graph()->AddNode(retval_1, &status);
668     TF_ASSERT_OK(status);
669     NodeDef retval_2;
670     retval_2.set_name("retval_1");
671     retval_2.set_op(FunctionLibraryDefinition::kRetOp);
672     *retval_2.add_input() = "foo:1";
673     (*retval_2.mutable_attr())["T"].set_type(DT_INT32);
674     (*retval_2.mutable_attr())["index"].set_i(1);
675     scope.graph()->AddNode(retval_2, &status);
676     TF_ASSERT_OK(status);
677     TF_ASSERT_OK(scope.ToGraph(graph.get()));
678   }
679 
680   // Builds a description of the arguments.
681   std::vector<XlaCompiler::Argument> args(1);
682   args[0].kind = XlaCompiler::Argument::kParameter;
683   args[0].type = DT_INT32;
684   args[0].shape = TensorShape({1});
685 
686   XlaCompiler::Options options = DefaultOptions();
687   FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
688   options.flib_def = &flib_def;
689   XlaCompiler compiler(options);
690 
691   XlaCompiler::CompileOptions compile_options;
692   XlaCompiler::CompilationResult result;
693   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
694                                      std::move(graph), args, &result));
695 
696   ASSERT_EQ(2, result.outputs.size());
697   EXPECT_FALSE(result.outputs[1].is_constant);
698 }
699 
700 // Tests that resources installed via populate_resource_manager are visible to compiled kernels.
701 TEST_F(XlaCompilerTest, ResourceManager) {
702   // Builds a graph that calls the dummy resource Op.
703   Scope scope = Scope::NewRootScope().ExitOnError();
704   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
705   auto b = DummyReadResourceCC(scope.WithOpName("B"), a);
706   auto c = ops::Add(scope.WithOpName("C"), b.output2_, b.output1_);
707   auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
708   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
709   TF_ASSERT_OK(scope.ToGraph(graph.get()));
710 
711   // Builds a description of the argument.
712   std::vector<XlaCompiler::Argument> args(1);
713   args[0].kind = XlaCompiler::Argument::kParameter;
714   args[0].type = DT_INT32;
715   args[0].shape = TensorShape({2});
716 
717   DummyResourceForTest* resource = new DummyResourceForTest();
718 
719   // Compiles the graph.
720   auto options = DefaultOptions();
721   std::function<Status(ResourceMgr*)> populate_function =
722       [resource](ResourceMgr* rm) {
723         resource->Ref();
724         return rm->Create(rm->default_container(), "dummy", resource);
725       };
726   options.populate_resource_manager = &populate_function;
727   XlaCompiler compiler(options);
728 
729   EXPECT_EQ(0, resource->Get());
730 
731   XlaCompiler::CompilationResult result;
732   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
733                                      std::move(graph), args, &result));
734 
735   EXPECT_EQ(1, resource->Get());
736 
737   resource->Unref();
738 }
739 
740 // Tests that compiling the same graph repeatedly produces identical HLO.
741 TEST_F(XlaCompilerTest, DeterministicCompilation) {
742   // Builds a graph that contains a node with two output edges. The compiler
743   // should always traverse them in the same order.
744   const int64_t test_count = 2;
745 
746   std::vector<XlaCompiler::CompilationResult> results(test_count);
747 
748   for (int64_t i = 0; i < test_count; ++i) {
749     Scope scope = Scope::NewRootScope().ExitOnError();
750     auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
751     auto b = ops::Neg(scope.WithOpName("B"), a);
752     auto c = ops::Neg(scope.WithOpName("C"), a);
753     auto d = ops::Add(scope.WithOpName("D"), b, c);
754     auto e = ops::_Retval(scope.WithOpName("E"), d, 0);
755     std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
756     TF_ASSERT_OK(scope.ToGraph(graph.get()));
757 
758     // Builds a description of the argument.
759     std::vector<XlaCompiler::Argument> args(1);
760     args[0].kind = XlaCompiler::Argument::kParameter;
761     args[0].type = DT_INT32;
762     args[0].shape = TensorShape({2});
763 
764     // Compiles the graph.
765     auto options = DefaultOptions();
766     XlaCompiler compiler(options);
767 
768     TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
769                                        std::move(graph), args, &results[i]));
770   }
771 
772   for (int64_t i = 1; i < test_count; ++i) {
773     const auto& m1 = results[i - 1].computation->proto();
774     const auto& m2 = results[i].computation->proto();
775     ASSERT_EQ(m1.computations_size(), m2.computations_size());
776     // Check if every hlo computation is the same.
777     for (int k = 0; k < m1.computations_size(); k++) {
778       const auto& c1 = m1.computations(k);
779       const auto& c2 = m2.computations(k);
780       ASSERT_EQ(c1.instructions_size(), c2.instructions_size());
781       for (int j = 0; j < c1.instructions_size(); j++) {
782         auto instr1 = c1.instructions(j);
783         auto instr2 = c2.instructions(j);
784         instr1.clear_name();
785         instr1.clear_id();
786         instr1.clear_operand_ids();
787         instr2.clear_name();
788         instr2.clear_id();
789         instr2.clear_operand_ids();
790         // The names of instructions were uniquified by the XlaBuilder and the
791         // unique ids may be different; the rest of the fields should be
792         // identical.
793         string str1, str2;
794         LOG(INFO) << "instr1 = " << instr1.DebugString();
795         LOG(INFO) << "instr2 = " << instr2.DebugString();
796         instr1.AppendPartialToString(&str1);
797         instr2.AppendPartialToString(&str2);
798         EXPECT_EQ(str1, str2);
799       }
800     }
801   }
802 }
803 
804 // Tests a computation that receives a TensorArray resource as input and
805 // updates it.
806 TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
807   Scope scope = Scope::NewRootScope().ExitOnError();
808   auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
809   auto flow = ops::Const<float>(scope, {});
810   auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
811   auto grad2 = ops::TensorArrayGrad(scope, arg, grad1.flow_out, "grad2");
812   auto index = ops::Const<int32>(scope, 1);
813   auto write = ops::TensorArrayWrite(scope, grad1.grad_handle, index, index,
814                                      grad2.flow_out);
815   auto read = ops::TensorArrayRead(scope, arg, index, write.flow_out, DT_INT32);
816   auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
817   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
818   TF_ASSERT_OK(scope.ToGraph(graph.get()));
819 
820   // Builds a description of the arguments.
821   std::vector<XlaCompiler::Argument> args(1);
822   args[0].kind = XlaCompiler::Argument::kResource;
823   args[0].resource_kind = XlaResource::kTensorArray;
824   args[0].initialized = true;
825   args[0].type = DT_INT32;
826   args[0].shape = TensorShape({});
827   args[0].max_array_size = 2;
828   args[0].tensor_array_gradients = {"grad2"};
829 
830   // Compiles the graph.
831   XlaCompiler compiler(DefaultOptions());
832 
833   XlaCompiler::CompilationResult result;
834   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
835                                      std::move(graph), args, &result));
836 
837   ASSERT_EQ(1, result.resource_updates.size());
838   const XlaCompiler::ResourceUpdate& update = result.resource_updates[0];
839   EXPECT_EQ(0, update.input_index);
840   EXPECT_EQ(DT_INT32, update.type);
841   EXPECT_EQ((std::set<string>{"grad1", "grad2"}),
842             update.tensor_array_gradients_accessed);
843 
844   // Tests that the generated computation works.
845   xla::Literal input_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
846   xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
847   xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2});
848   std::unique_ptr<xla::GlobalData> param0_data =
849       client_->TransferToServer(input).value();
850 
851   std::unique_ptr<xla::GlobalData> actual =
852       client_->Execute(*result.computation, {param0_data.get()}).value();
853   xla::Literal actual_literal = client_->Transfer(*actual).value();
854 
855   xla::Literal output_read = xla::LiteralUtil::CreateR0<int32>(42);
856   xla::Literal output_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
857   xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32>({0, 1});
858   xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
859   xla::Literal output_resource =
860       xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2});
861   xla::Literal expected_literal =
862       xla::LiteralUtil::MakeTuple({&output_read, &output_resource});
863   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
864 }
865 
866 // Tests that unwritten TensorArray gradients are not returned as computation outputs.
867 TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
868   Scope scope = Scope::NewRootScope().ExitOnError();
869   auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
870   auto flow = ops::Const<float>(scope, {});
871   auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
872   auto index = ops::Const<int32>(scope, 1);
873   auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32);
874   auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
875   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
876   TF_ASSERT_OK(scope.ToGraph(graph.get()));
877 
878   // Builds a description of the arguments.
879   std::vector<XlaCompiler::Argument> args(1);
880   args[0].kind = XlaCompiler::Argument::kResource;
881   args[0].resource_kind = XlaResource::kTensorArray;
882   args[0].initialized = true;
883   args[0].type = DT_INT32;
884   args[0].shape = TensorShape({});
885   args[0].max_array_size = 2;
886   args[0].tensor_array_gradients = {"grad1"};
887 
888   // Compiles the graph.
889   XlaCompiler compiler(DefaultOptions());
890 
891   XlaCompiler::CompilationResult result;
892   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
893                                      std::move(graph), args, &result));
894 
895   EXPECT_EQ(0, result.resource_updates.size());
896 }
897 
898 // Tests that newly created TensorArray gradients are returned as computation outputs.
899 TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
900   Scope scope = Scope::NewRootScope().ExitOnError();
901   auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
902   auto flow = ops::Const<float>(scope, {});
903   auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad2");
904   auto index = ops::Const<int32>(scope, 1);
905   auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32);
906   auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
907   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
908   TF_ASSERT_OK(scope.ToGraph(graph.get()));
909 
910   // Builds a description of the arguments.
911   std::vector<XlaCompiler::Argument> args(1);
912   args[0].kind = XlaCompiler::Argument::kResource;
913   args[0].resource_kind = XlaResource::kTensorArray;
914   args[0].initialized = true;
915   args[0].type = DT_INT32;
916   args[0].shape = TensorShape({});
917   args[0].max_array_size = 2;
918   args[0].tensor_array_gradients = {"grad1"};
919 
920   // Compiles the graph.
921   XlaCompiler compiler(DefaultOptions());
922 
923   XlaCompiler::CompilationResult result;
924   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
925                                      std::move(graph), args, &result));
926 
927   EXPECT_EQ(1, result.resource_updates.size());
928 }
929 
930 // Tests that CompileFunction fails for an undefined function.
931 TEST_F(XlaCompilerTest, UndefinedFunctionFails) {
932   XlaCompiler compiler(DefaultOptions());
933 
934   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
935   XlaCompiler::CompilationResult result;
936   NameAttrList name_attr;
937   name_attr.set_name("Function_NotDefined_");
938   Status status =
939       compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
940                                /*args=*/{}, &result);
941   EXPECT_FALSE(status.ok());
942   EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
943       << status.error_message();
944 }
945 
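// Returns a FunctionDef whose body is a single Fill node; Fill requires its
// dims input to be a compile-time constant.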
946 FunctionDef FillFn() {
947   return FunctionDefHelper::Define(
948       // Name
949       "FillFn",
950       // Args
951       {"x: T", "dims: int32"},
952       // Return values
953       {"y: T"},
954       // Attr def
955       {"T: {float, double, int32, int64}"},
956       // Nodes
957       {{{"y"}, "Fill", {"dims", "x"}, {{"T", "$T"}}}});
958 }
959 
960 TEST_F(XlaCompilerTest, FunctionCallWithConstants) {
961   // Certain operations in a function, "Fill" for example, require one of the
962   // operator's arguments to be a compile-time constant instead of a parameter.
963   // This test checks that XlaCompiler can handle such operators inside
964   // function calls.
965   XlaCompiler compiler(DefaultOptions());
966 
967   FunctionDefLibrary flib;
968   *flib.add_function() = FillFn();
969 
970   TF_ASSERT_OK(flib_def_->AddFunctionDef(FillFn()));
971 
972   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
973 
974   Scope scope = Scope::NewRootScope().ExitOnError();
975   auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
976   auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
977   TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib));
978 
979   NodeDef def;
980   TF_ASSERT_OK(NodeDefBuilder("fill", "FillFn", flib_def_.get())
981                    .Input(value.name(), 0, DT_INT32)
982                    .Input(shape.name(), 1, DT_INT32)
983                    .Finalize(&def));
984   Status status;
985   Node* fill = scope.graph()->AddNode(def, &status);
986   TF_ASSERT_OK(status);
987   TF_ASSERT_OK(scope.DoShapeInference(fill));
988   scope.graph()->AddEdge(value.node(), 0, fill, 0);
989   scope.graph()->AddEdge(shape.node(), 0, fill, 1);
990 
991   auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0);
992 
993   TF_ASSERT_OK(scope.ToGraph(graph.get()));
994 
995   // Builds a description of the argument.
996   std::vector<XlaCompiler::Argument> args;
997 
998   XlaCompiler::CompilationResult result;
999   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
1000                                      std::move(graph), args, &result));
1001 }
1002 
1003 // Tests that when the local function lookup fails, CompileFunction fails with
1004 // an informative error about both lookups.
1005 TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
1006   XlaCompiler compiler(DefaultOptions());
1007 
1008   auto local_flib_def = LocalFlibDef(&compiler);
1009   TF_ASSERT_OK(local_flib_def->AddFunctionDef(test::function::XTimesTwo()));
1010 
1011   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1012   XlaCompiler::CompilationResult result;
1013   NameAttrList name_attr;
1014   name_attr.set_name("XTimesTwo");
1015   Status status =
1016       compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
1017                                /*args=*/{}, &result);
1018 
1019   ASSERT_FALSE(status.ok());
1020   // Flib lookup failure.
1021   EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
1022       << status.error_message();
1023   // Local flib lookup failure.
1024   EXPECT_TRUE(absl::StrContains(status.error_message(), "Attr T is not found"))
1025       << status.error_message();
1026 }
1027 
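// Returns a FunctionDef whose body is a single Slice node; Slice accepts
// dynamic begin indices.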
1028 FunctionDef SliceFn() {
1029   return FunctionDefHelper::Define(
1030       // Name
1031       "SliceFn",
1032       // Args
1033       {"x: T", "begin: Index", "size: Index"},
1034       // Return values
1035       {"y: T"},
1036       // Attr def
1037       {"T: {float, double, int32, int64}", "Index: {int32,int64}"},
1038       // Nodes
1039       {{{"y"},
1040         "Slice",
1041         {"x", "begin", "size"},
1042         {{"T", "$T"}, {"Index", "$Index"}}}});
1043 }
1044 
1045 TEST_F(XlaCompilerTest, SliceWithDynamicBegins) {
1046   // Certain operations in a function, "Slice" for example, support both dynamic
1047   // and static inputs. This test checks that dynamic inputs are also supported
1048   // in a function call.
1049   XlaCompiler compiler(DefaultOptions());
1050 
1051   FunctionDefLibrary flib;
1052   *flib.add_function() = SliceFn();
1053 
1054   TF_ASSERT_OK(flib_def_->AddFunctionDef(SliceFn()));
1055 
1056   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1057 
1058   Scope scope = Scope::NewRootScope().ExitOnError();
1059   auto value = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
1060   auto begin = ops::_Arg(scope.WithOpName("arg"), DT_INT32, 0);
1061   auto size = ops::Const<int32>(scope.WithOpName("value"), {1}, {1});
1062 
1063   TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib));
1064 
1065   NodeDef def;
1066   TF_ASSERT_OK(NodeDefBuilder("slice", "SliceFn", flib_def_.get())
1067                    .Input(value.name(), 0, DT_INT32)
1068                    .Input(begin.node()->name(), 1, DT_INT32)
1069                    .Input(size.name(), 2, DT_INT32)
1070                    .Finalize(&def));
1071   Status status;
1072   Node* slice = scope.graph()->AddNode(def, &status);
1073   TF_ASSERT_OK(status);
1074   TF_ASSERT_OK(scope.DoShapeInference(slice));
1075   scope.graph()->AddEdge(value.node(), 0, slice, 0);
1076   scope.graph()->AddEdge(begin.node(), 0, slice, 1);
1077   scope.graph()->AddEdge(size.node(), 0, slice, 2);
1078 
1079   auto retval = ops::_Retval(scope.WithOpName("retval"), Output(slice), 0);
1080 
1081   TF_ASSERT_OK(scope.ToGraph(graph.get()));
1082 
1083   // Builds a description of the argument.
1084   std::vector<XlaCompiler::Argument> args(1);
1085   args[0].kind = XlaCompiler::Argument::kParameter;
1086   args[0].type = DT_INT32;
1087   args[0].shape = TensorShape({1});
1088 
1089   XlaCompiler::CompilationResult result;
1090   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "slice",
1091                                      std::move(graph), args, &result));
1092 }
1093 
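// Runs the compiled variable read/write computation on inputs {7, 42} and
// {-3, 101} and checks both the returned value and the updated variable.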
1094 void RunAndCheckVariablesComputation(
1095     xla::Client* client, const XlaCompiler::CompilationResult& result) {
1096   xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
1097   xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
1098   std::unique_ptr<xla::GlobalData> param0_data =
1099       client->TransferToServer(param0_literal).value();
1100   std::unique_ptr<xla::GlobalData> param1_data =
1101       client->TransferToServer(param1_literal).value();
1102 
1103   std::unique_ptr<xla::GlobalData> actual =
1104       client
1105           ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
1106           .value();
1107   xla::Literal actual_literal = client->Transfer(*actual).value();
1108 
1109   xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({5, 144});
1110   xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({4, 143});
1111   xla::Literal expected_literal =
1112       xla::LiteralUtil::MakeTuple({&expected0, &expected1});
1113   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
1114 }
1115 
1116 // Tests a simple graph that reads and writes a variable.
1117 TEST_F(XlaCompilerTest, Variables) {
1118   Scope scope = Scope::NewRootScope().ExitOnError();
1119   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
1120   auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
1121   // Adds an identity op around the resource to make sure identity ops propagate
1122   // resources correctly.
1123   auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
1124   auto write = ops::AssignAddVariableOp(scope, identity, a);
1125   auto read = ops::ReadVariableOp(
1126       scope.WithControlDependencies(std::vector<Operation>{write}), var,
1127       DT_INT32);
1128   auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
1129   auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
1130   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1131   TF_ASSERT_OK(scope.ToGraph(graph.get()));
1132 
1133   // Builds a description of the arguments.
1134   std::vector<XlaCompiler::Argument> args(2);
1135   args[0].kind = XlaCompiler::Argument::kParameter;
1136   args[0].type = DT_INT32;
1137   args[0].shape = TensorShape({2});
1138   args[1].kind = XlaCompiler::Argument::kResource;
1139   args[1].resource_kind = XlaResource::kVariable;
1140   args[1].initialized = true;
1141   args[1].type = DT_INT32;
1142   args[1].shape = TensorShape({2});
1143 
1144   // Compiles the graph.
1145   XlaCompiler compiler(DefaultOptions());
1146 
1147   XlaCompiler::CompilationResult result;
1148   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
1149                                      std::move(graph), args, &result));
1150   RunAndCheckVariablesComputation(client_, result);
1151 }
1152 
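// Tests that the layout chosen by shape_representation_fn is applied to a
// single, non-tuple result.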
1153 TEST_F(XlaCompilerTest, ResultLayoutSingle) {
1154   Scope scope = Scope::NewRootScope().ExitOnError();
1155   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
1156   auto b = ops::_Retval(scope.WithOpName("RET"), a, 0);
1157 
1158   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1159   TF_ASSERT_OK(scope.ToGraph(graph.get()));
1160 
1161   // Builds a description of the arguments.
1162   std::vector<XlaCompiler::Argument> args(1);
1163   args[0].kind = XlaCompiler::Argument::kParameter;
1164   args[0].type = DT_INT32;
1165   args[0].shape = TensorShape({2, 3});
1166 
1167   auto options = DefaultOptions();
1168   // Sets the representation function to return a non-default layout.
1169   XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns;
1170   shape_determination_fns.shape_representation_fn =
1171       [](const TensorShape& shape, DataType type, bool use_fast_memory,
1172          XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
1173     xla::Shape xla_shape;
1174     TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
1175     *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
1176     return xla_shape;
1177   };
1178   options.shape_determination_fns = shape_determination_fns;
1179 
1180   // Compiles the graph.
1181   XlaCompiler compiler(options);
1182 
1183   XlaCompiler::CompilationResult result;
1184   auto compile_options = XlaCompiler::CompileOptions();
1185   compile_options.always_return_tuple = false;
1186   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "id", std::move(graph),
1187                                      args, &result));
1188   EXPECT_TRUE(xla::ShapeUtil::Equal(
1189       result.xla_output_shape,
1190       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1})));
1191   EXPECT_EQ(result.computation->GetProgramShape().value().result(),
1192             result.xla_output_shape);
1193 }
1194 
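// Tests that the layout chosen by shape_representation_fn is applied to every
// element of a tuple result.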
1195 TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
1196   Scope scope = Scope::NewRootScope().ExitOnError();
1197   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
1198   auto b = ops::_Retval(scope.WithOpName("RET1"), a, 0);
1199   auto c = ops::_Retval(scope.WithOpName("RET2"), a, 1);
1200 
1201   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1202   TF_ASSERT_OK(scope.ToGraph(graph.get()));
1203 
1204   // Builds a description of the arguments.
1205   std::vector<XlaCompiler::Argument> args(1);
1206   args[0].kind = XlaCompiler::Argument::kParameter;
1207   args[0].type = DT_INT32;
1208   args[0].shape = TensorShape({2, 3});
1209 
1210   auto options = DefaultOptions();
1211   // Sets the representation function to return a non-default layout.
1212   XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns;
1213   shape_determination_fns.shape_representation_fn =
1214       [](const TensorShape& shape, DataType type, bool use_fast_memory,
1215          XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
1216     xla::Shape xla_shape;
1217     TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
1218     *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
1219     return xla_shape;
1220   };
1221   shape_determination_fns.layout_preference_fn = UseNoPreferenceLayoutFn();
1222   options.shape_determination_fns = shape_determination_fns;
1223 
1224   // Compiles the graph.
1225   XlaCompiler compiler(options);
1226 
1227   XlaCompiler::CompilationResult result;
1228   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "id",
1229                                      std::move(graph), args, &result));
1230   xla::Shape result_shape =
1231       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
1232 
1233   EXPECT_TRUE(xla::ShapeUtil::Equal(
1234       result.xla_output_shape,
1235       xla::ShapeUtil::MakeTupleShape({result_shape, result_shape})));
1236   EXPECT_EQ(result.computation->GetProgramShape().value().result(),
1237             result.xla_output_shape);
1238 }
1239 
1240 // Tests a graph that returns a resource handle without reading or writing it.
1241 TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
1242   Scope scope = Scope::NewRootScope().ExitOnError();
1243   auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 0);
1244   auto d = ops::_Retval(scope.WithOpName("D"), var, 0);
1245   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1246   TF_ASSERT_OK(scope.ToGraph(graph.get()));
1247 
1248   // Builds a description of the arguments.
1249   std::vector<XlaCompiler::Argument> args(1);
1250   args[0].kind = XlaCompiler::Argument::kResource;
1251   args[0].resource_kind = XlaResource::kVariable;
1252   args[0].initialized = true;
1253   args[0].type = DT_INT32;
1254   args[0].shape = TensorShape({2});
1255 
1256   // Compiles the graph.
1257   XlaCompiler compiler(DefaultOptions());
1258 
1259   XlaCompiler::CompilationResult result;
1260   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
1261                                      std::move(graph), args, &result));
1262 
1263   // Tests that the generated computation works.
1264   xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
1265   std::unique_ptr<xla::GlobalData> param1_data =
1266       client_->TransferToServer(param1_literal).value();
1267 
1268   std::unique_ptr<xla::GlobalData> actual =
1269       client_->Execute(*result.computation, {param1_data.get()}).value();
1270   xla::Literal actual_literal = client_->Transfer(*actual).value();
1271 
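  // The only fetch is the resource handle itself, which produces no XLA
  // output, so the computation should return an empty tuple.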
1272   xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({});
1273   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
1274 }
1275 
1276 TEST_F(XlaCompilerTest, ReturnResourceHandle) {
1277   Scope scope = Scope::NewRootScope().ExitOnError();
1278   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
1279   auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
1280   // Adds an identity op around the resource to make sure identity ops propagate
1281   // resources correctly.
1282   auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
1283   auto write = ops::AssignAddVariableOp(scope, identity, a);
1284   auto read = ops::ReadVariableOp(
1285       scope.WithControlDependencies(std::vector<Operation>{write}), var,
1286       DT_INT32);
1287   auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
1288   auto r = ops::_Retval(scope.WithOpName("R"), var, 0);
1289   auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 1);
1290 
1291   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1292   TF_ASSERT_OK(scope.ToGraph(graph.get()));
1293 
1294   // Builds a description of the arguments.
1295   std::vector<XlaCompiler::Argument> args(2);
1296   args[0].kind = XlaCompiler::Argument::kParameter;
1297   args[0].type = DT_INT32;
1298   args[0].shape = TensorShape({2});
1299   args[1].kind = XlaCompiler::Argument::kResource;
1300   args[1].resource_kind = XlaResource::kVariable;
1301   args[1].initialized = true;
1302   args[1].type = DT_INT32;
1303   args[1].shape = TensorShape({2});
1304 
1305   // Compiles the graph.
1306   XlaCompiler compiler(DefaultOptions());
1307 
1308   XlaCompiler::CompilationResult result;
1309   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
1310                                      std::move(graph), args, &result));
1311   RunAndCheckVariablesComputation(client_, result);
1312 }
1313 
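// Builds a graph that adds arg "A" into variable "V" and returns the updated
// value plus one; used by the shape-representation tests below.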
1314 StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
1315   Scope scope = Scope::NewRootScope().ExitOnError();
1316   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
1317   auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
1318   auto write = ops::AssignAddVariableOp(scope, var, a);
1319   auto read = ops::ReadVariableOp(
1320       scope.WithControlDependencies(std::vector<Operation>{write}), var,
1321       DT_INT32);
1322   auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
1323   auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
1324   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1325   TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
1326   return std::move(graph);
1327 }
1328 
1329 // Tests a simple graph that reads and writes a variable, with a
1330 // shape_representation_fn passed to the compiler that flattens all
1331 // variable tensors to vectors.
1332 TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
1333   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
1334 
1335   // Builds a description of the arguments.
1336   std::vector<XlaCompiler::Argument> args(2);
1337   args[0].kind = XlaCompiler::Argument::kParameter;
1338   args[0].type = DT_INT32;
1339   args[0].shape = TensorShape({2, 2});
1340   args[1].kind = XlaCompiler::Argument::kResource;
1341   args[1].resource_kind = XlaResource::kVariable;
1342   args[1].initialized = true;
1343   args[1].type = DT_INT32;
1344   args[1].shape = TensorShape({2, 2});
1345 
1346   // Compiles the graph.
1347   XlaCompiler::Options options = DefaultOptions();
1348   XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns;
1349   shape_determination_fns.shape_representation_fn =
1350       [](const TensorShape& shape, DataType type, bool use_fast_memory,
1351          XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
1352     xla::PrimitiveType ptype;
1353     TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype));
1354     return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()});
1355   };
1356   options.shape_determination_fns = shape_determination_fns;
1357   XlaCompiler compiler(options);
1358 
1359   XlaCompiler::CompileOptions compile_options;
1360   compile_options.is_entry_computation = false;  // Only reshape variables.
1361 
1362   XlaCompiler::CompilationResult result;
1363   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
1364                                      args, &result));
1365 
1366   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
1367                           client_->GetComputationShape(*result.computation));
1368 
1369   ASSERT_EQ(program_shape->parameters_size(), 2);
1370   EXPECT_TRUE(
1371       xla::ShapeUtil::Compatible(program_shape->parameters(0),
1372                                  xla::ShapeUtil::MakeShape(xla::S32, {2, 2})));
1373   EXPECT_TRUE(xla::ShapeUtil::Compatible(
1374       program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
1375   EXPECT_TRUE(xla::ShapeUtil::Compatible(
1376       program_shape->result(),
1377       xla::ShapeUtil::MakeTupleShape(
1378           {xla::ShapeUtil::MakeShape(xla::S32, {2, 2}),
1379            xla::ShapeUtil::MakeShape(xla::S32, {4})})));
1380 
1381   // Tests that the generated computation works.
1382   xla::Literal param0_literal =
1383       xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
1384   xla::Literal param1_literal =
1385       xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
1386   std::unique_ptr<xla::GlobalData> param0_data =
1387       client_->TransferToServer(param0_literal).value();
1388   std::unique_ptr<xla::GlobalData> param1_data =
1389       client_->TransferToServer(param1_literal).value();
1390 
1391   std::unique_ptr<xla::GlobalData> actual =
1392       client_
1393           ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
1394           .value();
1395   xla::Literal actual_literal = client_->Transfer(*actual).value();
1396 
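  // var += a gives {26, 66, 34, 401} (the updated variable, still in its
  // flattened representation); the retval is that value plus one, reshaped to
  // the original 2x2 variable shape.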
1397   xla::Literal expected0 =
1398       xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
1399   xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
1400   xla::Literal expected_literal =
1401       xla::LiteralUtil::MakeTuple({&expected0, &expected1});
1402   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
1403 }
1404 
1405 TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
1406   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
1407 
1408   // Builds a description of the arguments.
1409   std::vector<XlaCompiler::Argument> args(2);
1410   args[0].kind = XlaCompiler::Argument::kParameter;
1411   args[0].type = DT_INT32;
1412   args[0].shape = TensorShape({2, 2});
1413   args[1].kind = XlaCompiler::Argument::kResource;
1414   args[1].resource_kind = XlaResource::kVariable;
1415   args[1].initialized = true;
1416   args[1].type = DT_INT32;
1417   args[1].shape = TensorShape({2, 2});
1418 
1419   // Compiles the graph.
1420   XlaCompiler::Options options = DefaultOptions();
1421   XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns;
1422   shape_determination_fns.shape_representation_fn =
1423       [](const TensorShape& shape, DataType type, bool use_fast_memory,
1424          XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
1425     xla::PrimitiveType ptype;
1426     TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype));
1427     return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()});
1428   };
1429   options.shape_determination_fns = shape_determination_fns;
1430   XlaCompiler compiler(options);
1431 
1432   XlaCompiler::CompileOptions compile_options;
1433   compile_options.is_entry_computation = true;  // Reshape args and retvals.
1434 
1435   XlaCompiler::CompilationResult result;
1436   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
1437                                      args, &result));
1438 
1439   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
1440                           client_->GetComputationShape(*result.computation));
1441 
1442   ASSERT_EQ(program_shape->parameters_size(), 2);
1443   EXPECT_TRUE(xla::ShapeUtil::Compatible(
1444       program_shape->parameters(0), xla::ShapeUtil::MakeShape(xla::S32, {4})));
1445   EXPECT_TRUE(xla::ShapeUtil::Compatible(
1446       program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
1447   EXPECT_TRUE(xla::ShapeUtil::Compatible(
1448       program_shape->result(),
1449       xla::ShapeUtil::MakeTupleShape(
1450           {xla::ShapeUtil::MakeShape(xla::S32, {4}),
1451            xla::ShapeUtil::MakeShape(xla::S32, {4})})));
1452 
1453   // Tests that the generated computation works.
1454   xla::Literal param0_literal =
1455       xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
1456   xla::Literal param1_literal =
1457       xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
1458   std::unique_ptr<xla::GlobalData> param0_data =
1459       client_->TransferToServer(param0_literal).value();
1460   std::unique_ptr<xla::GlobalData> param1_data =
1461       client_->TransferToServer(param1_literal).value();
1462 
1463   std::unique_ptr<xla::GlobalData> actual =
1464       client_
1465           ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
1466           .value();
1467   xla::Literal actual_literal = client_->Transfer(*actual).value();
1468 
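  // As an entry computation both args and retvals are flattened to vectors:
  // the retval is (var + a) + 1 and the updated variable is var + a.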
1469   xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
1470   xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
1471   xla::Literal expected_literal =
1472       xla::LiteralUtil::MakeTuple({&expected0, &expected1});
1473   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
1474 }
1475 
1476 // Tests a graph which has a function with an invalid op.
1477 TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
1478   XlaCompiler compiler(DefaultOptions());
1479 
1480   FunctionDefLibrary flib;
1481   FunctionDef fn = FillFn();
1482   NodeDef* node = fn.add_node_def();
1483   node->set_name("Invalid");
1484   node->set_op("InvalidOp"); /* unsupported op */
1485   node = fn.add_node_def();
1486   node->set_name("Switch");
1487   node->set_op("Switch"); /* control flow node */
1488   *flib.add_function() = fn;
1489 
1490   TF_ASSERT_OK(flib_def_->AddFunctionDef(fn));
1491 
1492   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1493 
1494   Scope scope = Scope::NewRootScope().ExitOnError();
1495   auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
1496   auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
1497   TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib));
1498 
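  // Adds the FillFn call node and its input edges to the graph by hand.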
1499   NodeDef def;
1500   TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get())
1501                    .Input(value.name(), 0, DT_INT32)
1502                    .Input(shape.name(), 1, DT_INT32)
1503                    .Finalize(&def));
1504   Status status;
1505   Node* fill = scope.graph()->AddNode(def, &status);
1506   TF_ASSERT_OK(status);
1507   TF_ASSERT_OK(scope.DoShapeInference(fill));
1508   scope.graph()->AddEdge(value.node(), 0, fill, 0);
1509   scope.graph()->AddEdge(shape.node(), 0, fill, 1);
1510 
1511   auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0);
1512 
1513   TF_ASSERT_OK(scope.ToGraph(graph.get()));
1514 
1515   std::vector<XlaCompiler::Argument> args;
1516   XlaCompiler::CompilationResult result;
1517   status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
1518                                  std::move(graph), args, &result);
1519   ASSERT_FALSE(status.ok());
1520   EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp"))
1521       << status.error_message();
1522   EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node fill_fn}}"))
1523       << status.error_message();
1524 }
1525 
1526 // Tests a graph which has a node with invalid data type.
1527 TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
1528   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1529   NodeDef shape;
1530   shape.set_name("Shape");
1531   shape.set_op("Shape");
1532   (*shape.mutable_attr())["T"].set_type(DT_INT32);
1533   (*shape.mutable_attr())["out_type"].set_type(DT_BOOL); /* invalid type */
1534   Status status;
1535   Node* shape_node = graph->AddNode(shape, &status);
1536   TF_ASSERT_OK(status);
1537   graph->AddControlEdge(graph->source_node(), shape_node);
1538 
1539   std::vector<XlaCompiler::Argument> args;
1540   XlaCompiler::CompilationResult result;
1541   XlaCompiler compiler(DefaultOptions());
1542   status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type",
1543                                  std::move(graph), args, &result);
1544   ASSERT_FALSE(status.ok());
1545   EXPECT_TRUE(absl::StrContains(status.error_message(),
1546                                 "is not in the list of allowed values"))
1547       << status.error_message();
1548   EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node Shape}}"))
1549       << status.error_message();
1550 }
1551 
1552 TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
1553   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1554   NodeDef no_op;
1555   no_op.set_name("NoOp");
1556   no_op.set_op("NoOp");
1557   Status status;
1558   graph->AddNode(no_op, &status);
1559   TF_ASSERT_OK(status);
1560 
1561   std::vector<XlaCompiler::Argument> args;
1562   XlaCompiler compiler(DefaultOptions());
1563   // No control edge linking NoOp with source/sink.
1564   {
1565     std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
1566     CopyGraph(*graph, graph_copy.get());
1567     XlaCompiler::CompilationResult result;
1568     TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
1569                                        std::move(graph_copy), args, &result));
1570   }
1571 }
1572 
1573 class DummySideEffectingOp : public XlaOpKernel {
1574  public:
1575   explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
1576   void Compile(XlaOpKernelContext* ctx) override {
1577     OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken(
1578                             name(), xla::CreateToken(ctx->builder())));
1579   }
1580 };
1581 
1582 REGISTER_OP("DummySideEffectingOp");
1583 
1584 REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp);
1585 
1586 TEST_F(XlaCompilerTest, TokenInputAndOutput) {
1587   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1588   NodeDef side_effecting_op;
1589   side_effecting_op.set_name("DummySideEffectingOp");
1590   side_effecting_op.set_op("DummySideEffectingOp");
1591   AddNodeAttr(kXlaTokenInputNodesAttrName,
1592               std::vector<string>{kXlaTokenArgNodeName}, &side_effecting_op);
1593   AddNodeAttr(kXlaOriginalOutsideCompilationNodeName, side_effecting_op.name(),
1594               &side_effecting_op);
1595   Status status;
1596   graph->AddNode(side_effecting_op, &status);
1597   TF_ASSERT_OK(status);
1598   EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get()));
1599 
1600   std::vector<XlaCompiler::Argument> args(1);
1601   args[0].kind = XlaCompiler::Argument::kResource;
1602   args[0].resource_kind = XlaResource::kVariable;
1603   args[0].initialized = true;
1604   args[0].type = DT_INT32;
1605   args[0].shape = TensorShape({2, 2});
1606 
1607   {
1608     // The case for entry computation: we don't add token input/output. Instead,
1609     // we use CreateToken HLO to create the entry token.
1610     XlaCompiler::CompileOptions options;
1611     options.is_entry_computation = true;
1612     options.add_token_input_output = false;
1613     options.return_updated_values_for_all_resources = true;
1614     XlaCompiler compiler(DefaultOptions());
1615 
1616     std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
1617     CopyGraph(*graph, graph_copy.get());
1618     XlaCompiler::CompilationResult result;
1619     TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
1620                                        args, &result));
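    // Without a token input/output, the only parameter is the variable and
    // the result tuple holds just the updated variable value.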
1621     EXPECT_EQ(result.xla_input_shapes.size(), 1);
1622     EXPECT_TRUE(result.xla_output_shape.IsTuple());
1623     EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1);
1624   }
1625   {
1626     // The case for non-entry computation (e.g. while loop body). We add token
1627     // input/output.
1628     XlaCompiler::CompileOptions options;
1629     options.is_entry_computation = false;
1630     options.add_token_input_output = true;
1631     options.return_updated_values_for_all_resources = true;
1632     XlaCompiler compiler(DefaultOptions());
1633 
1634     std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
1635     CopyGraph(*graph, graph_copy.get());
1636     XlaCompiler::CompilationResult result;
1637     TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
1638                                        args, &result));
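    // With add_token_input_output set, a token parameter is appended after
    // the variable and a token element is appended to the result tuple.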
1639     EXPECT_EQ(result.xla_input_shapes.size(), 2);
1640     EXPECT_TRUE(result.xla_input_shapes[1].IsToken());
1641     EXPECT_TRUE(result.xla_output_shape.IsTuple());
1642     EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 2);
1643     EXPECT_TRUE(xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 1)
1644                     .IsToken());
1645   }
1646 }
1647 
1648 TEST_F(XlaCompilerTest, OpsWithTensorListInput) {
1649   FunctionDefLibrary fdef_lib;
1650   FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
1651   // Build cond fn for While.
1652   {
1653     Scope scope = Scope::NewRootScope().ExitOnError();
1654     std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1655     ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
1656     auto result = ops::Const<bool>(scope, {true}, {});
1657     ops::_Retval(scope.WithOpName("ret"), result, 0);
1658     TF_ASSERT_OK(scope.ToGraph(graph.get()));
1659     FunctionDef fdef;
1660     TF_ASSERT_OK(GraphToFunctionDef(*graph, "cond", &fdef));
1661     TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
1662   }
1663   // Build body fn for While.
1664   {
1665     Scope scope = Scope::NewRootScope().ExitOnError();
1666     std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1667     auto arg = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
1668     ops::_Retval(scope.WithOpName("ret"), arg, 0);
1669     TF_ASSERT_OK(scope.ToGraph(graph.get()));
1670     FunctionDef fdef;
1671     TF_ASSERT_OK(GraphToFunctionDef(*graph, "body", &fdef));
1672     TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
1673   }
1674 
1675   Scope scope = Scope::NewRootScope().ExitOnError();
1676   auto element_shape = ops::Const<int32>(scope, {1}, {1});
1677   auto max_elements = ops::Const<int32>(scope, {10}, {});
1678   auto arg = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
1679   std::initializer_list<Output> out = {arg, arg};
1680   auto add_n = ops::AddN(scope, out);
1681   NameAttrList cond_fn, body_fn;
1682   cond_fn.set_name("cond");
1683   body_fn.set_name("body");
1684   auto while_op =
1685       ops::While(scope, std::initializer_list<Input>{arg}, cond_fn, body_fn);
1686   auto ret0 = ops::_Retval(scope.WithOpName("ret0"), add_n, 0);
1687   auto ret1 = ops::_Retval(scope.WithOpName("ret1"), while_op.output[0], 1);
1688   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1689   TF_ASSERT_OK(scope.ToGraph(graph.get()));
1690 
1691   // Builds a description of the arguments.
1692   std::vector<XlaCompiler::Argument> args(1);
1693   args[0].kind = XlaCompiler::Argument::kTensorList;
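  // A kTensorList argument carries its XLA shape directly; build it as a
  // tuple of the list element shape and a scalar index.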
1694   xla::Shape tensor_list_element_shape;
1695   TF_ASSERT_OK(TensorShapeToXLAShape(DT_INT32, TensorShape{1},
1696                                      &tensor_list_element_shape));
1697   xla::Shape index_shape;
1698   TF_ASSERT_OK(TensorShapeToXLAShape(DT_INT32, TensorShape{}, &index_shape));
1699   std::vector<xla::Shape> shapes{tensor_list_element_shape, index_shape};
1700   xla::Shape arg_shape = xla::ShapeUtil::MakeTupleShape(shapes);
1701   args[0].shape = arg_shape;
1702 
1703   // Compiles the graph.
1704   XlaCompiler::Options options = DefaultOptions();
1705   options.flib_def = &flib_def;
1706   XlaCompiler compiler(options);
1707 
1708   XlaCompiler::CompilationResult result;
1709   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
1710                                      std::move(graph), args, &result));
1711   ASSERT_EQ(result.outputs.size(), 2);
1712   const XlaCompiler::OutputDescription& output0 = result.outputs[0];
1713   ASSERT_TRUE(output0.is_tensor_list);
1714   const XlaCompiler::OutputDescription& output1 = result.outputs[1];
1715   ASSERT_TRUE(output1.is_tensor_list);
1716 }
1717 
1718 // Test the compiler supports WhileOp with a loop body where DT_RESOURCE
1719 // variables are both inputs and outputs.
1720 TEST_F(XlaCompilerTest, WhileWithResources) {
1721   FunctionDefLibrary fdef_lib;
1722   FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
1723   // Build cond fn for While.
1724   {
1725     Scope scope = Scope::NewRootScope().ExitOnError();
1726     std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1727     auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1728     auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1);
1729     auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_RESOURCE, 2);
1730     auto less = ops::Less(scope, arg0, ops::Const<int32>(scope, 10));
1731     (void)ops::_Retval(scope.WithOpName("ret"), less, 0);
1732     TF_ASSERT_OK(scope.ToGraph(graph.get()));
1733     FunctionDef fdef;
1734     TF_ASSERT_OK(GraphToFunctionDef(*graph, "cond", &fdef));
1735     TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
1736   }
1737   // Build body fn for While.
1738   {
1739     Scope scope = Scope::NewRootScope().ExitOnError();
1740     std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1741     auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1742     auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1);
1743     auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_RESOURCE, 2);
1744     auto read1 = ops::ReadVariableOp(scope.WithOpName("read1"), arg1, DT_INT32);
1745     auto plus_read1 = ops::Add(scope, arg0, read1);
1746     auto read2 = ops::ReadVariableOp(scope.WithOpName("read2"), arg2, DT_INT32);
1747     auto minus_read2 = ops::Sub(scope, plus_read1, read2);
1748     (void)ops::_Retval(scope.WithOpName("ret0"), minus_read2, 0);
1749     (void)ops::_Retval(scope.WithOpName("ret1"), arg1, 1);
1750     (void)ops::_Retval(scope.WithOpName("ret2"), arg2, 2);
1751     TF_ASSERT_OK(scope.ToGraph(graph.get()));
1752     FunctionDef fdef;
1753     TF_ASSERT_OK(GraphToFunctionDef(*graph, "body", &fdef));
1754     TF_ASSERT_OK(flib_def.AddFunctionDef(fdef));
1755   }
1756 
1757   Scope scope = Scope::NewRootScope().ExitOnError();
1758   auto arg0 = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
1759   auto arg1 = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1);
1760   auto arg2 = ops::_Arg(scope.WithOpName("arg2"), DT_RESOURCE, 2);
1761 
1762   NameAttrList cond_fn, body_fn;
1763   cond_fn.set_name("cond");
1764   body_fn.set_name("body");
1765   auto while_op = ops::While(
1766       scope, std::initializer_list<Input>{arg0, arg1, arg2}, cond_fn, body_fn);
1767 
1768   (void)ops::_Retval(scope.WithOpName("ret0"), while_op.output[0], 0);
1769   (void)ops::_Retval(scope.WithOpName("ret1"), while_op.output[1], 1);
1770   (void)ops::_Retval(scope.WithOpName("ret2"), while_op.output[2], 2);
1771   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1772   TF_ASSERT_OK(scope.ToGraph(graph.get()));
1773 
1774   // Builds a description of the arguments.
1775   std::vector<XlaCompiler::Argument> args(3);
1776   args[0].kind = XlaCompiler::Argument::kParameter;
1777   args[0].type = DT_INT32;
1778   args[0].shape = TensorShape({});
1779   args[1].kind = XlaCompiler::Argument::kResource;
1780   args[1].resource_kind = XlaResource::kVariable;
1781   args[1].initialized = true;
1782   args[1].type = DT_INT32;
1783   args[1].shape = TensorShape({});
1784   args[2].kind = XlaCompiler::Argument::kResource;
1785   args[2].resource_kind = XlaResource::kVariable;
1786   args[2].initialized = true;
1787   args[2].type = DT_INT32;
1788   args[2].shape = TensorShape({});
1789 
1790   // Compiles the graph.
1791   XlaCompiler::Options options = DefaultOptions();
1792   options.flib_def = &flib_def;
1793   XlaCompiler compiler(options);
1794 
1795   XlaCompiler::CompileOptions compile_options = XlaCompiler::CompileOptions();
1796   compile_options.return_updated_values_for_all_resources = true;
1797   XlaCompiler::CompilationResult result;
1798   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "tested_while_with_vars",
1799                                      std::move(graph), args, &result));
1800   ASSERT_EQ(result.outputs.size(), 3);
1801   const XlaCompiler::OutputDescription& output1 = result.outputs[1];
1802   ASSERT_EQ(output1.input_index, 1);
1803   const XlaCompiler::OutputDescription& output2 = result.outputs[2];
1804   ASSERT_EQ(output2.input_index, 2);
1805 
1806   // Tests that the generated computation works.
1807   xla::Literal literal0 = xla::LiteralUtil::CreateR0<int32>(0);
1808   xla::Literal literal1 = xla::LiteralUtil::CreateR0<int32>(2);
1809   xla::Literal literal2 = xla::LiteralUtil::CreateR0<int32>(1);
1810   std::unique_ptr<xla::GlobalData> data0 =
1811       client_->TransferToServer(literal0).value();
1812   std::unique_ptr<xla::GlobalData> data1 =
1813       client_->TransferToServer(literal1).value();
1814   std::unique_ptr<xla::GlobalData> data2 =
1815       client_->TransferToServer(literal2).value();
1816 
1817   std::unique_ptr<xla::GlobalData> actual =
1818       client_
1819           ->Execute(*result.computation,
1820                     {data0.get(), data1.get(), data2.get()})
1821           .value();
1822   xla::Literal actual_literal = client_->Transfer(*actual).value();
1823 
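  // Each iteration adds read1 (2) and subtracts read2 (1), so the counter
  // advances by one per step until arg0 < 10 fails; the variables themselves
  // are never written, so they keep their initial values.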
1824   xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(10);
1825   xla::Literal expected1 = xla::LiteralUtil::CreateR0<int32>(2);
1826   xla::Literal expected2 = xla::LiteralUtil::CreateR0<int32>(1);
1827   xla::Literal expected_literal =
1828       xla::LiteralUtil::MakeTuple({&expected0, &expected1, &expected2});
1829   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
1830 }
1831 
1832 TEST_F(XlaCompilerTest, SetShardingForReturnedTuple) {
1833   // Builds a graph that returns its only argument.
1834   Scope scope = Scope::NewRootScope().ExitOnError();
1835   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
1836   auto b = ops::_Retval(scope.WithOpName("B"), a, 0);
1837   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1838   TF_ASSERT_OK(scope.ToGraph(graph.get()));
1839 
1840   // Sets _XlaSharding attribute for the _Retval node.
1841   auto node_name_index = graph->BuildNodeNameIndex();
1842   Node* ret_node = node_name_index["B"];
1843   ASSERT_NE(ret_node, nullptr);
1844   xla::Array<int64_t> tile_assignment({2});
1845   tile_assignment.FillIota(0);
1846   xla::HloSharding sharding = xla::HloSharding::Tile(tile_assignment);
1847   ret_node->AddAttr("_XlaSharding", sharding.ToProto().SerializeAsString());
1848 
1849   // Builds a description of the arguments.
1850   std::vector<XlaCompiler::Argument> args(1);
1851   args[0].kind = XlaCompiler::Argument::kParameter;
1852   args[0].type = DT_INT32;
1853   args[0].shape = TensorShape({2});
1854 
1855   // Compiles the graph.
1856   XlaCompiler compiler(DefaultOptions());
1857 
1858   XlaCompiler::CompilationResult result;
1859   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "test",
1860                                      std::move(graph), args, &result));
1861 
1862   // Tests that we set sharding on the root TUPLE instruction.
1863   const auto& hlo_module_proto = result.computation->proto();
1864   ASSERT_EQ(hlo_module_proto.computations_size(), 1);
1865   const auto& hlo_computation_proto = hlo_module_proto.computations(0);
1866   std::optional<xla::HloInstructionProto> root_instruction_proto;
1867   for (const auto& inst : hlo_computation_proto.instructions()) {
1868     if (inst.id() == hlo_computation_proto.root_id()) {
1869       root_instruction_proto = inst;
1870       break;
1871     }
1872   }
1873   ASSERT_TRUE(root_instruction_proto);
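  // The _Retval sharding should appear on the root tuple as a single-element
  // tuple sharding.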
1874   xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(
1875       {xla::ShapeUtil::MakeShape(xla::S32, {2})});
1876   xla::HloSharding tuple_sharding = xla::HloSharding::Tuple(
1877       tuple_shape, std::vector<xla::HloSharding>{sharding});
1878   EXPECT_EQ(root_instruction_proto->sharding().SerializeAsString(),
1879             tuple_sharding.ToProto().SerializeAsString());
1880 }
1881 
1882 TEST_F(XlaCompilerTest, AliasResourceUpdates) {
1883   Scope scope = Scope::NewRootScope().ExitOnError();
1884   auto a = ops::Const<int32>(scope.WithOpName("A"), {1, 2});
1885   auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
1886   auto write = ops::AssignAddVariableOp(scope, var, a);
1887   auto read = ops::ReadVariableOp(
1888       scope.WithControlDependencies(std::vector<Operation>{write}), var,
1889       DT_INT32);
1890   auto d = ops::_Retval(scope.WithOpName("D"), read, 0);
1891   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1892   TF_ASSERT_OK(scope.ToGraph(graph.get()));
1893 
1894   // Builds a description of the arguments.
1895   std::vector<XlaCompiler::Argument> args(2);
1896   args[0].kind = XlaCompiler::Argument::kConstant;
1897   args[0].type = DT_INT32;
1898   args[0].shape = TensorShape({2});
1899   args[0].constant_value = Tensor(DT_INT32, {1, 1});
1900   args[0].initialized = true;
1901 
1902   args[1].kind = XlaCompiler::Argument::kResource;
1903   args[1].resource_kind = XlaResource::kVariable;
1904   args[1].initialized = true;
1905   args[1].type = DT_INT32;
1906   args[1].shape = TensorShape({2});
1907 
1908   XlaCompiler compiler(DefaultOptions());
1909 
1910   XlaCompiler::CompileOptions compile_options;
1911   compile_options.alias_resource_update = true;
1912 
1913   XlaCompiler::CompilationResult result;
1914   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
1915                                      args, &result));
1916 
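  // The constant arg is not a runtime XLA parameter, so the variable is
  // parameter 0; with alias_resource_update its updated value should be
  // aliased to that parameter.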
1917   const xla::HloInputOutputAliasProto& alias =
1918       result.computation->proto().input_output_alias();
1919   EXPECT_EQ(alias.entries_size(), 1);
1920   EXPECT_EQ(alias.entries(0).parameter_number(), 0);
1921 }
1922 
1923 // Tests that passing in an exact duplicate input to SetDeviceToHostMetadata
1924 // is not an error.
1925 TEST_F(XlaCompilerTest, SetDeviceToHostMetadataExactDuplicate) {
1926   XlaCompiler compiler(DefaultOptions());
1927 
1928   const string& key = "comm_key";
1929   std::vector<DataType> types{DT_INT32};
1930   std::vector<TensorShape> shapes{TensorShape({2})};
1931 
1932   TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes));
1933   TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes));
1934 }
1935 
1936 // Tests that passing in a mismatched duplicate input to
1937 // SetDeviceToHostMetadata is an error.
1938 TEST_F(XlaCompilerTest, SetDeviceToHostMetadataMismatchedDuplicate) {
1939   XlaCompiler compiler(DefaultOptions());
1940 
1941   const string& key = "comm_key";
1942   std::vector<DataType> types{DT_INT32};
1943   std::vector<TensorShape> shapes{TensorShape({2})};
1944   std::vector<DataType> types2{DT_FLOAT};
1945   std::vector<TensorShape> shapes2{TensorShape({1})};
1946 
1947   TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes));
1948   Status status = compiler.SetDeviceToHostMetadata(key, types2, shapes2);
1949   EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
1950 }
1951 
1952 // Tests that passing in an exact duplicate input to SetHostToDeviceMetadata
1953 // is not an error.
1954 TEST_F(XlaCompilerTest, SetHostToDeviceMetadataExactDuplicate) {
1955   XlaCompiler compiler(DefaultOptions());
1956 
1957   const string& key = "comm_key";
1958   std::vector<DataType> types{DT_INT32};
1959   std::vector<TensorShape> shapes{TensorShape({2})};
1960 
1961   TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes));
1962   TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes));
1963 }
1964 
1965 // Tests that passing in a mismatched duplicate input to
1966 // SetHostToDeviceMetadata is an error.
1967 TEST_F(XlaCompilerTest, SetHostToDeviceMetadataMismatchedDuplicate) {
1968   XlaCompiler compiler(DefaultOptions());
1969 
1970   const string& key = "comm_key";
1971   std::vector<DataType> types{DT_INT32};
1972   std::vector<TensorShape> shapes{TensorShape({2})};
1973   std::vector<DataType> types2{DT_FLOAT};
1974   std::vector<TensorShape> shapes2{TensorShape({1})};
1975 
1976   TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes));
1977   Status status = compiler.SetHostToDeviceMetadata(key, types2, shapes2);
1978   EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
1979 }
1980 
1981 }  // namespace
1982 }  // namespace tensorflow
1983