1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
17
18 #include "absl/strings/match.h"
19 #include "tensorflow/cc/framework/ops.h"
20 #include "tensorflow/cc/ops/const_op.h"
21 #include "tensorflow/cc/ops/data_flow_ops.h"
22 #include "tensorflow/cc/ops/function_ops.h"
23 #include "tensorflow/cc/ops/functional_ops.h"
24 #include "tensorflow/cc/ops/list_ops.h"
25 #include "tensorflow/cc/ops/math_ops.h"
26 #include "tensorflow/cc/ops/resource_variable_ops.h"
27 #include "tensorflow/cc/ops/standard_ops.h"
28 #include "tensorflow/compiler/tf2xla/shape_util.h"
29 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
30 #include "tensorflow/compiler/tf2xla/type_util.h"
31 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
32 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
33 #include "tensorflow/compiler/xla/client/client_library.h"
34 #include "tensorflow/compiler/xla/client/local_client.h"
35 #include "tensorflow/compiler/xla/client/xla_builder.h"
36 #include "tensorflow/compiler/xla/literal.h"
37 #include "tensorflow/compiler/xla/service/hlo.pb.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/status_macros.h"
40 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
41 #include "tensorflow/core/common_runtime/function.h"
42 #include "tensorflow/core/common_runtime/graph_constructor.h"
43 #include "tensorflow/core/framework/common_shape_fns.h"
44 #include "tensorflow/core/framework/function.h"
45 #include "tensorflow/core/framework/function.pb.h"
46 #include "tensorflow/core/framework/function_testlib.h"
47 #include "tensorflow/core/framework/graph_to_functiondef.h"
48 #include "tensorflow/core/framework/node_def_util.h"
49 #include "tensorflow/core/framework/resource_mgr.h"
50 #include "tensorflow/core/framework/tensor.h"
51 #include "tensorflow/core/framework/tensor_testutil.h"
52 #include "tensorflow/core/framework/types.pb.h"
53 #include "tensorflow/core/graph/algorithm.h"
54 #include "tensorflow/core/graph/graph.h"
55 #include "tensorflow/core/lib/core/status_test_util.h"
56 #include "tensorflow/core/platform/test.h"
57 #include "tensorflow/core/public/version.h"
58
59 namespace tensorflow {
60
61 class XlaCompilerTest : public ::testing::Test {
62 protected:
SetUp()63 void SetUp() override {
64 client_ = xla::ClientLibrary::LocalClientOrDie();
65
66 XlaOpRegistry::RegisterCompilationKernels();
67
68 FunctionDefLibrary flib;
69 flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
70 }
71
DefaultOptions()72 XlaCompiler::Options DefaultOptions() {
73 XlaCompiler::Options options;
74 options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
75 options.client = client_;
76 options.flib_def = flib_def_.get();
77 return options;
78 }
79
LocalFlibDef(XlaCompiler * compiler)80 FunctionLibraryDefinition* LocalFlibDef(XlaCompiler* compiler) {
81 return compiler->local_flib_def_.get();
82 }
83
84 xla::Client* client_;
85 std::unique_ptr<FunctionLibraryDefinition> flib_def_;
86 };
87
88 namespace {
89
90 // Helper class to test the ability to pass resources through to XLA
91 // compiled kernels.
92 class DummyResourceForTest : public ResourceBase {
93 public:
DebugString() const94 string DebugString() const override { return "dummy"; }
Increment()95 void Increment() { ++value_; }
Get()96 int Get() { return value_; }
97
98 private:
99 int value_ = 0;
100 };
101
102 class DummyReadResourceOp : public XlaOpKernel {
103 public:
DummyReadResourceOp(OpKernelConstruction * ctx)104 explicit DummyReadResourceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
Compile(XlaOpKernelContext * ctx)105 void Compile(XlaOpKernelContext* ctx) override {
106 ResourceMgr* rm = ctx->op_kernel_context()->resource_manager();
107 OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
108 DummyResourceForTest* dummy;
109 OP_REQUIRES_OK(ctx, rm->Lookup<DummyResourceForTest>(
110 rm->default_container(), "dummy", &dummy));
111 dummy->Increment();
112 dummy->Unref();
113
114 ctx->SetOutput(0, ctx->Input(0));
115 ctx->SetOutput(1, ctx->Input(0));
116 }
117 };
118
119 class DummyReadResourceCC {
120 public:
DummyReadResourceCC(const Scope & scope,const Input & value)121 DummyReadResourceCC(const Scope& scope, const Input& value) {
122 if (!scope.ok()) return;
123 auto _value = ops::AsNodeOut(scope, value);
124 if (!scope.ok()) return;
125 Node* ret;
126 const auto unique_name = scope.GetUniqueNameForOp("DummyReadResource");
127 auto builder = NodeBuilder(unique_name, "DummyReadResource").Input(_value);
128 scope.UpdateBuilder(&builder);
129 scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
130 if (!scope.ok()) return;
131 scope.UpdateStatus(scope.DoShapeInference(ret));
132 if (!scope.ok()) return;
133 this->output1_ = Output(ret, 0);
134 this->output2_ = Output(ret, 1);
135 }
136
137 Output output1_;
138 Output output2_;
139 };
140
141 REGISTER_OP("DummyReadResource")
142 .Input("input: int32")
143 .Output("output1: int32")
144 .Output("output2: int32")
145 .SetShapeFn(shape_inference::UnknownShape)
146 .Doc(R"doc(
147 A dummy Op.
148
149 input: dummy input.
150 output1: dummy output.
151 output2: dummy output.
152 )doc");
153
154 REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp);
155
156 // DummyDuplicateOp is present purely to test multiple REGISTER_XLA_OP calls
157 // on the same Op name below.
158 class DummyDuplicateOp : public XlaOpKernel {
159 public:
DummyDuplicateOp(OpKernelConstruction * ctx)160 explicit DummyDuplicateOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
Compile(XlaOpKernelContext * ctx)161 void Compile(XlaOpKernelContext* ctx) override {
162 ctx->SetOutput(0, ctx->Input(0));
163 }
164 };
165
166 REGISTER_OP("DummyDuplicateOp")
167 .Input("input: int32")
168 .Output("output: int32")
169 .Doc(R"doc(
170 A dummy Op.
171
172 input: dummy input.
173 output: dummy output.
174 )doc");
175
176 REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT),
177 DummyDuplicateOp);
178 REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT),
179 DummyDuplicateOp);
180
181 // Tests compilation and execution of an empty graph.
TEST_F(XlaCompilerTest,EmptyReturnValues)182 TEST_F(XlaCompilerTest, EmptyReturnValues) {
183 XlaCompiler compiler(DefaultOptions());
184
185 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
186 XlaCompiler::CompilationResult result;
187 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
188 std::move(graph),
189 /*args=*/{}, &result));
190
191 TF_ASSERT_OK(client_->Execute(*result.computation, {}).status());
192 }
193
194 // Tests compilation and execution of a graph that adds two tensors.
TEST_F(XlaCompilerTest,Simple)195 TEST_F(XlaCompilerTest, Simple) {
196 // Builds a graph that adds two Tensors.
197 Scope scope = Scope::NewRootScope().ExitOnError();
198 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
199 auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
200 auto c = ops::Add(scope.WithOpName("C"), a, b);
201 auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
202 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
203 TF_ASSERT_OK(scope.ToGraph(graph.get()));
204
205 // Builds a description of the arguments.
206 std::vector<XlaCompiler::Argument> args(2);
207 args[0].kind = XlaCompiler::Argument::kParameter;
208 args[0].type = DT_INT32;
209 args[0].shape = TensorShape({2});
210 args[1].kind = XlaCompiler::Argument::kParameter;
211 args[1].type = DT_INT32;
212 args[1].shape = TensorShape({2});
213
214 // Compiles the graph.
215 XlaCompiler compiler(DefaultOptions());
216
217 XlaCompiler::CompilationResult result;
218 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
219 std::move(graph), args, &result));
220
221 // Tests that the generated computation works.
222 xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
223 xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
224 std::unique_ptr<xla::GlobalData> param0_data =
225 client_->TransferToServer(param0_literal).value();
226 std::unique_ptr<xla::GlobalData> param1_data =
227 client_->TransferToServer(param1_literal).value();
228
229 std::unique_ptr<xla::GlobalData> actual =
230 client_
231 ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
232 .value();
233 xla::Literal actual_literal = client_->Transfer(*actual).value();
234
235 xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({4, 143});
236 xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
237 EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
238 }
239
240 // Tests compilation of a graph where the _Retval node is not necessarily last
241 // amongst the graph nodes in construction order, and always_return_tuple is
242 // false. Regression test for bug where the wrong value was returned.
TEST_F(XlaCompilerTest,OutOfOrderGraph)243 TEST_F(XlaCompilerTest, OutOfOrderGraph) {
244 Scope scope = Scope::NewRootScope().ExitOnError();
245 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
246 auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
247 // The _Retval node is not last in construction order.
248 auto d = ops::_Retval(scope.WithOpName("D"), a, 0);
249 auto c = ops::Add(scope.WithOpName("C"), a, b);
250
251 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
252 TF_ASSERT_OK(scope.ToGraph(graph.get()));
253
254 // Builds a description of the arguments.
255 std::vector<XlaCompiler::Argument> args(2);
256 args[0].kind = XlaCompiler::Argument::kParameter;
257 args[0].type = DT_INT32;
258 args[0].shape = TensorShape({2});
259 args[1].kind = XlaCompiler::Argument::kParameter;
260 args[1].type = DT_INT32;
261 args[1].shape = TensorShape({2});
262
263 // Compiles the graph.
264 XlaCompiler compiler(DefaultOptions());
265
266 XlaCompiler::CompileOptions compile_options;
267 compile_options.always_return_tuple = false;
268 XlaCompiler::CompilationResult result;
269 TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
270 args, &result));
271
272 // Tests that the generated computation works.
273 xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
274 xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
275 std::unique_ptr<xla::GlobalData> param0_data =
276 client_->TransferToServer(param0_literal).value();
277 std::unique_ptr<xla::GlobalData> param1_data =
278 client_->TransferToServer(param1_literal).value();
279
280 std::unique_ptr<xla::GlobalData> actual =
281 client_
282 ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
283 .value();
284 xla::Literal actual_literal = client_->Transfer(*actual).value();
285
286 EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal));
287 }
288
289 // Tests that the compiler can correctly propagate the layout assigned by
290 // shape_representation_fn_ to resource returns that have not been written to.
TEST_F(XlaCompilerTest,HonorShapeRepresentationFnForUnwrittenResource)291 TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForUnwrittenResource) {
292 Scope scope = Scope::NewRootScope().ExitOnError();
293 auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 0);
294 auto d = ops::_Retval(scope.WithOpName("D"), var, 0);
295 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
296 TF_ASSERT_OK(scope.ToGraph(graph.get()));
297
298 // Builds a description of the arguments.
299 std::vector<XlaCompiler::Argument> args(1);
300 args[0].kind = XlaCompiler::Argument::kResource;
301 args[0].resource_kind = XlaResource::kVariable;
302 args[0].initialized = true;
303 args[0].type = DT_INT32;
304 args[0].shape = TensorShape({2, 3});
305
306 auto options = DefaultOptions();
307 XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns;
308 shape_determination_fns.shape_representation_fn =
309 [](const TensorShape& shape, DataType dt, bool use_fast_memory,
310 XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
311 xla::Shape xla_shape;
312 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
313 *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
314 return xla_shape;
315 };
316 options.shape_determination_fns = shape_determination_fns;
317 // Compiles the graph.
318 XlaCompiler compiler(options);
319
320 XlaCompiler::CompilationResult result;
321 XlaCompiler::CompileOptions compile_options;
322 compile_options.return_updated_values_for_all_resources = true;
323 TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
324 args, &result));
325 xla::Shape transposed =
326 xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
327 // Check that the return shapes are correctly tranposed.
328 EXPECT_EQ(result.xla_output_shape,
329 xla::ShapeUtil::MakeTupleShape({transposed}));
330 }
331
332 // Tests that the compiler can correctly propagate fast mem attribute for input
333 // resource variable.
TEST_F(XlaCompilerTest,HonorShapeRepresentationFnForFastMemVar)334 TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForFastMemVar) {
335 Scope scope = Scope::NewRootScope().ExitOnError();
336 auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 0);
337 auto d = ops::_Retval(scope.WithOpName("D"), var, 0);
338 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
339 TF_ASSERT_OK(scope.ToGraph(graph.get()));
340
341 // Builds a description of the arguments.
342 std::vector<XlaCompiler::Argument> args(1);
343 args[0].kind = XlaCompiler::Argument::kResource;
344 args[0].resource_kind = XlaResource::kVariable;
345 args[0].initialized = true;
346 args[0].type = DT_INT32;
347 args[0].shape = TensorShape({2, 3});
348 args[0].fast_mem = true;
349
350 auto options = DefaultOptions();
351 int fast_mem_arg_count = 0;
352 XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns;
353 shape_determination_fns.shape_representation_fn =
354 [&fast_mem_arg_count](
355 const TensorShape& shape, DataType dt, bool use_fast_memory,
356 XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
357 xla::Shape xla_shape;
358 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
359 *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
360 if (use_fast_memory) {
361 fast_mem_arg_count++;
362 }
363 return xla_shape;
364 };
365 options.shape_determination_fns = shape_determination_fns;
366 // Compiles the graph.
367 XlaCompiler compiler(options);
368
369 XlaCompiler::CompilationResult result;
370 XlaCompiler::CompileOptions compile_options;
371 compile_options.return_updated_values_for_all_resources = true;
372 TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
373 args, &result));
374 // Count 2: one for argument, one for the return value.
375 EXPECT_EQ(fast_mem_arg_count, 2);
376 }
377
378 // Tests that the compiler can correctly propagate the layout assigned by
379 // shape_representation_fn_ to return types.
TEST_F(XlaCompilerTest,HonorShapeRepresentationFnForRetVal)380 TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
381 Scope scope = Scope::NewRootScope().ExitOnError();
382 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
383 auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
384 // Adds an identity op around the resource to make sure identity ops propagate
385 // resources correctly.
386 auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
387 auto write = ops::AssignAddVariableOp(scope, identity, a);
388 auto read = ops::ReadVariableOp(
389 scope.WithControlDependencies(std::vector<Operation>{write}), var,
390 DT_INT32);
391 auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
392 auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
393 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
394 TF_ASSERT_OK(scope.ToGraph(graph.get()));
395
396 // Builds a description of the arguments.
397 std::vector<XlaCompiler::Argument> args(2);
398 args[0].kind = XlaCompiler::Argument::kParameter;
399 args[0].type = DT_INT32;
400 args[0].shape = TensorShape({2, 3});
401 args[1].kind = XlaCompiler::Argument::kResource;
402 args[1].resource_kind = XlaResource::kVariable;
403 args[1].initialized = true;
404 args[1].type = DT_INT32;
405 args[1].shape = TensorShape({2, 3});
406
407 auto options = DefaultOptions();
408 XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns;
409 shape_determination_fns.shape_representation_fn =
410 [](const TensorShape& shape, DataType dt, bool use_fast_memory,
411 XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
412 xla::Shape xla_shape;
413 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
414 *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
415 return xla_shape;
416 };
417 options.shape_determination_fns = shape_determination_fns;
418 // Compiles the graph.
419 XlaCompiler compiler(options);
420
421 XlaCompiler::CompilationResult result;
422 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
423 std::move(graph), args, &result));
424 xla::Shape transposed =
425 xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
426 // Check that the return shapes are correctly tranposed.
427 EXPECT_EQ(result.xla_output_shape,
428 xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
429 EXPECT_EQ(result.computation->GetProgramShape().value().result(),
430 xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
431 }
432
433 // The layout of resource variable shouldn't change after transpose
TEST_F(XlaCompilerTest,TransposeVariables)434 TEST_F(XlaCompilerTest, TransposeVariables) {
435 Scope scope = Scope::NewRootScope().ExitOnError();
436 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
437 auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
438 // Adds an identity op around the resource to make sure identity ops propagate
439 // resources correctly.
440 auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
441 auto write = ops::AssignAddVariableOp(scope, identity, a);
442 auto read = ops::ReadVariableOp(
443 scope.WithControlDependencies(std::vector<Operation>{write}), var,
444 DT_INT32);
445 auto transposed_read = ops::Transpose(scope, read, {1, 0});
446 auto reshape = ops::Reshape(scope, transposed_read, {2, 3});
447 auto d = ops::_Retval(scope.WithOpName("D"), reshape, 0);
448 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
449 TF_ASSERT_OK(scope.ToGraph(graph.get()));
450
451 // Builds a description of the arguments.
452 std::vector<XlaCompiler::Argument> args(2);
453 args[0].kind = XlaCompiler::Argument::kParameter;
454 args[0].type = DT_INT32;
455 args[0].shape = TensorShape({2, 3});
456 args[1].kind = XlaCompiler::Argument::kResource;
457 args[1].resource_kind = XlaResource::kVariable;
458 args[1].initialized = true;
459 args[1].type = DT_INT32;
460 args[1].shape = TensorShape({2, 3});
461 // Compiles the graph.
462 XlaCompiler compiler(DefaultOptions());
463
464 XlaCompiler::CompilationResult result;
465 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose",
466 std::move(graph), args, &result));
467 xla::Shape transposed =
468 xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0});
469 // Check that the return shapes are correctly tranposed.
470 EXPECT_EQ(result.xla_output_shape,
471 xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
472 }
473
474 // Unranked fake param returns a 0 shaped tensor.
TEST_F(XlaCompilerTest,UnrankedFakeParam)475 TEST_F(XlaCompilerTest, UnrankedFakeParam) {
476 Scope scope = Scope::NewRootScope().ExitOnError();
477 PartialTensorShape shape;
478 auto a = ops::FakeParam(scope, DT_INT32, shape);
479 auto ret = ops::_Retval(scope.WithOpName("D"), a, 0);
480 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
481 TF_ASSERT_OK(scope.ToGraph(graph.get()));
482
483 // Compiles the graph.
484 XlaCompiler compiler(DefaultOptions());
485
486 XlaCompiler::CompilationResult result;
487 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "compile",
488 std::move(graph), {}, &result));
489 // Check that the return shapes are correctly tranposed.
490 EXPECT_EQ(result.xla_output_shape,
491 xla::ShapeUtil::MakeTupleShape(
492 {xla::ShapeUtil::MakeShape(xla::S32, {0})}));
493 }
494
495 // Tests that the compiler doesn't reorder the parameters.
TEST_F(XlaCompilerTest,MixedOrderArguments)496 TEST_F(XlaCompilerTest, MixedOrderArguments) {
497 for (bool swap_order : {false, true}) {
498 Scope scope = Scope::NewRootScope().ExitOnError();
499 auto var =
500 ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, swap_order ? 0 : 1);
501 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, swap_order ? 1 : 0);
502 // Adds an identity op around the resource to make sure identity ops
503 // propagate resources correctly.
504 auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
505 auto write = ops::AssignAddVariableOp(scope, identity, a);
506 auto read = ops::ReadVariableOp(
507 scope.WithControlDependencies(std::vector<Operation>{write}), var,
508 DT_INT32);
509 auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
510 auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
511 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
512 TF_ASSERT_OK(scope.ToGraph(graph.get()));
513
514 // Builds a description of the arguments.
515 std::vector<XlaCompiler::Argument> args(2);
516 args[0].kind = XlaCompiler::Argument::kParameter;
517 args[0].type = DT_INT32;
518 args[0].shape = TensorShape({2});
519 args[1].kind = XlaCompiler::Argument::kResource;
520 args[1].resource_kind = XlaResource::kVariable;
521 args[1].initialized = true;
522 args[1].type = DT_INT32;
523 args[1].shape = TensorShape({2});
524
525 if (swap_order) {
526 // Even after swapping arguments, the compiler should maintain the new
527 // ordering of parameters.
528 std::swap(args[0], args[1]);
529 }
530 // Compiles the graph.
531 XlaCompiler compiler(DefaultOptions());
532
533 XlaCompiler::CompileOptions compile_options;
534 compile_options.always_return_tuple = false;
535 XlaCompiler::CompilationResult result;
536 TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
537 args, &result));
538
539 EXPECT_THAT(result.input_mapping, ::testing::ElementsAre(0, 1));
540 }
541 }
542
// Tests that a Reshape whose target shape comes from a parameter (and so is
// not a compile-time constant) fails with an error that names the offending
// node.
TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
  // Builds a graph that reshapes a tensor, but with the target shape not
  // statically known.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
  auto c = ops::Reshape(scope.WithOpName("C"), a, b);
  auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments.
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});
  args[1].kind = XlaCompiler::Argument::kParameter;
  args[1].type = DT_INT32;
  args[1].shape = TensorShape({2});

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompilationResult result;
  Status status =
      compiler.CompileGraph(XlaCompiler::CompileOptions(), "reshape",
                            std::move(graph), args, &result);
  // The message checks below are meaningless if compilation succeeded.
  ASSERT_FALSE(status.ok());
  EXPECT_TRUE(
      absl::StrContains(status.error_message(), "depends on a parameter"))
      << status.error_message();
  EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node C}}"))
      << status.error_message();
  EXPECT_TRUE(absl::StrContains(status.error_message(),
                                "must be a compile-time constant"))
      << status.error_message();
}
580
581 // Tests handling of compile-time constant outputs.
TEST_F(XlaCompilerTest,ConstantOutputs)582 TEST_F(XlaCompilerTest, ConstantOutputs) {
583 // Builds a graph with one compile-time constant output and one data-dependent
584 // output, i.e.,
585 // func(a) { b=7; c=-a; return b, c; }
586 Scope scope = Scope::NewRootScope().ExitOnError();
587 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
588 auto b = ops::Const<int32>(scope.WithOpName("B"), 7);
589 auto c = ops::Neg(scope.WithOpName("C"), a);
590 auto d = ops::_Retval(scope.WithOpName("D"), b, 0);
591 auto e = ops::_Retval(scope.WithOpName("E"), c, 1);
592 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
593 TF_ASSERT_OK(scope.ToGraph(graph.get()));
594
595 // Builds a description of the arguments.
596 std::vector<XlaCompiler::Argument> args(1);
597 args[0].kind = XlaCompiler::Argument::kParameter;
598 args[0].type = DT_INT32;
599 args[0].shape = TensorShape({2});
600
601 XlaCompiler::Options options = DefaultOptions();
602 XlaCompiler compiler(options);
603
604 {
605 std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
606 CopyGraph(*graph, graph_copy.get());
607
608 XlaCompiler::CompileOptions compile_options;
609 XlaCompiler::CompilationResult result;
610 TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
611 std::move(graph_copy), args, &result));
612
613 ASSERT_EQ(2, result.outputs.size());
614 EXPECT_FALSE(result.outputs[0].is_constant);
615 EXPECT_FALSE(result.outputs[1].is_constant);
616
617 // Tests that the generated computation works.
618 xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
619 std::unique_ptr<xla::GlobalData> param0_data =
620 client_->TransferToServer(param0_literal).value();
621
622 std::unique_ptr<xla::GlobalData> actual =
623 client_->Execute(*result.computation, {param0_data.get()}).value();
624 xla::Literal actual_literal = client_->Transfer(*actual).value();
625
626 xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(7);
627 xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
628 xla::Literal expected =
629 xla::LiteralUtil::MakeTuple({&expected0, &expected1});
630 EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal));
631 }
632 }
633
// Tests that, for a call to a non-inlined function, the output that forwards a
// function argument is not reported as a compile-time constant.
TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) {
  // Define a function with one compile-time constant output and one
  // data-dependent output.
  // @function.Defun(noinline=True)
  // foo(a) {b=7; return b, a; }
  const Tensor seven = test::AsScalar<int>(7);
  FunctionDef fdef = FunctionDefHelper::Create(
      "foo", {"a_0:int32"}, {"const:int32", "a:int32"}, {},
      {
          {{"Const"}, "Const", {}, {{"dtype", DT_INT32}, {"value", seven}}},
      },
      {{"a", "a_0"}, {"const", "Const:output:0"}});
  (*fdef.mutable_attr())["_noinline"].set_b(true);
  FunctionDefLibrary fdef_lib;
  *(fdef_lib.add_function()) = fdef;

  // Builds the NodeDef of an int32 _Retval named `name` that reads `input` and
  // is returned at position `index`.
  auto make_int32_retval = [](const string& name, const string& input,
                              int index) {
    NodeDef retval;
    retval.set_name(name);
    retval.set_op(FunctionLibraryDefinition::kRetOp);
    *retval.add_input() = input;
    (*retval.mutable_attr())["T"].set_type(DT_INT32);
    (*retval.mutable_attr())["index"].set_i(index);
    return retval;
  };

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    // Nothing below can succeed without the function library, so stop here on
    // failure instead of continuing.
    TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
    auto arg = ops::_Arg(scope.WithOpName("input_arg"), DT_INT32, 0);
    NodeDef foo;
    foo.set_name("foo");
    foo.set_op("foo");
    *foo.add_input() = "input_arg";
    Status status;
    scope.graph()->AddNode(foo, &status);
    TF_ASSERT_OK(status);
    // retval_0 <- foo:0 (the constant), retval_1 <- foo:1 (the forwarded arg).
    scope.graph()->AddNode(make_int32_retval("retval_0", "foo", 0), &status);
    TF_ASSERT_OK(status);
    scope.graph()->AddNode(make_int32_retval("retval_1", "foo:1", 1), &status);
    TF_ASSERT_OK(status);
    TF_ASSERT_OK(scope.ToGraph(graph.get()));
  }

  // Builds a description of the arguments.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({1});

  // The compiler must see the library that defines "foo".
  XlaCompiler::Options options = DefaultOptions();
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
  options.flib_def = &flib_def;
  XlaCompiler compiler(options);

  XlaCompiler::CompileOptions compile_options;
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
                                     std::move(graph), args, &result));

  ASSERT_EQ(2, result.outputs.size());
  EXPECT_FALSE(result.outputs[1].is_constant);
}
699
700 // Tests compilation and execution of a graph that adds two tensors.
TEST_F(XlaCompilerTest,ResourceManager)701 TEST_F(XlaCompilerTest, ResourceManager) {
702 // Builds a graph that calls the dummy resource Op.
703 Scope scope = Scope::NewRootScope().ExitOnError();
704 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
705 auto b = DummyReadResourceCC(scope.WithOpName("B"), a);
706 auto c = ops::Add(scope.WithOpName("C"), b.output2_, b.output1_);
707 auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
708 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
709 TF_ASSERT_OK(scope.ToGraph(graph.get()));
710
711 // Builds a description of the argument.
712 std::vector<XlaCompiler::Argument> args(1);
713 args[0].kind = XlaCompiler::Argument::kParameter;
714 args[0].type = DT_INT32;
715 args[0].shape = TensorShape({2});
716
717 DummyResourceForTest* resource = new DummyResourceForTest();
718
719 // Compiles the graph.
720 auto options = DefaultOptions();
721 std::function<Status(ResourceMgr*)> populate_function =
722 [resource](ResourceMgr* rm) {
723 resource->Ref();
724 return rm->Create(rm->default_container(), "dummy", resource);
725 };
726 options.populate_resource_manager = &populate_function;
727 XlaCompiler compiler(options);
728
729 EXPECT_EQ(0, resource->Get());
730
731 XlaCompiler::CompilationResult result;
732 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
733 std::move(graph), args, &result));
734
735 EXPECT_EQ(1, resource->Get());
736
737 resource->Unref();
738 }
739
740 // Tests compilation and execution of a graph that adds two tensors.
TEST_F(XlaCompilerTest,DeterministicCompilation)741 TEST_F(XlaCompilerTest, DeterministicCompilation) {
742 // Builds a graph that contains a node with two output edges. The compiler
743 // should always traverse them in the same order.
744 const int64_t test_count = 2;
745
746 std::vector<XlaCompiler::CompilationResult> results(test_count);
747
748 for (int64_t i = 0; i < test_count; ++i) {
749 Scope scope = Scope::NewRootScope().ExitOnError();
750 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
751 auto b = ops::Neg(scope.WithOpName("B"), a);
752 auto c = ops::Neg(scope.WithOpName("C"), a);
753 auto d = ops::Add(scope.WithOpName("D"), b, c);
754 auto e = ops::_Retval(scope.WithOpName("E"), d, 0);
755 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
756 TF_ASSERT_OK(scope.ToGraph(graph.get()));
757
758 // Builds a description of the argument.
759 std::vector<XlaCompiler::Argument> args(1);
760 args[0].kind = XlaCompiler::Argument::kParameter;
761 args[0].type = DT_INT32;
762 args[0].shape = TensorShape({2});
763
764 // Compiles the graph.
765 auto options = DefaultOptions();
766 XlaCompiler compiler(options);
767
768 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
769 std::move(graph), args, &results[i]));
770 }
771
772 for (int64_t i = 1; i < test_count; ++i) {
773 const auto& m1 = results[i - 1].computation->proto();
774 const auto& m2 = results[i].computation->proto();
775 ASSERT_EQ(m1.computations_size(), m2.computations_size());
776 // Check if every hlo computation is the same.
777 for (int k = 0; k < m1.computations_size(); k++) {
778 const auto& c1 = m1.computations(k);
779 const auto& c2 = m2.computations(k);
780 ASSERT_EQ(c1.instructions_size(), c2.instructions_size());
781 for (int j = 0; j < c1.instructions_size(); j++) {
782 auto instr1 = c1.instructions(j);
783 auto instr2 = c2.instructions(j);
784 instr1.clear_name();
785 instr1.clear_id();
786 instr1.clear_operand_ids();
787 instr2.clear_name();
788 instr2.clear_id();
789 instr2.clear_operand_ids();
790 // The names of instructions were uniquified by the XlaBuilder and the
791 // unique ids may be different, the rest of the fields should be
792 // identical.
793 string str1, str2;
794 LOG(INFO) << "instr1 = " << instr1.DebugString();
795 LOG(INFO) << "instr2 = " << instr2.DebugString();
796 instr1.AppendPartialToString(&str1);
797 instr2.AppendPartialToString(&str2);
798 EXPECT_EQ(str1, str2);
799 }
800 }
801 }
802 }
803
804 // Tests a computation that receives a TensorArray resource as input and
805 // updates it.
TEST_F(XlaCompilerTest,CanPassTensorArraysToAndFromComputation)806 TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
807 Scope scope = Scope::NewRootScope().ExitOnError();
808 auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
809 auto flow = ops::Const<float>(scope, {});
810 auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
811 auto grad2 = ops::TensorArrayGrad(scope, arg, grad1.flow_out, "grad2");
812 auto index = ops::Const<int32>(scope, 1);
813 auto write = ops::TensorArrayWrite(scope, grad1.grad_handle, index, index,
814 grad2.flow_out);
815 auto read = ops::TensorArrayRead(scope, arg, index, write.flow_out, DT_INT32);
816 auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
817 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
818 TF_ASSERT_OK(scope.ToGraph(graph.get()));
819
820 // Builds a description of the arguments.
821 std::vector<XlaCompiler::Argument> args(1);
822 args[0].kind = XlaCompiler::Argument::kResource;
823 args[0].resource_kind = XlaResource::kTensorArray;
824 args[0].initialized = true;
825 args[0].type = DT_INT32;
826 args[0].shape = TensorShape({});
827 args[0].max_array_size = 2;
828 args[0].tensor_array_gradients = {"grad2"};
829
830 // Compiles the graph.
831 XlaCompiler compiler(DefaultOptions());
832
833 XlaCompiler::CompilationResult result;
834 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
835 std::move(graph), args, &result));
836
837 ASSERT_EQ(1, result.resource_updates.size());
838 const XlaCompiler::ResourceUpdate& update = result.resource_updates[0];
839 EXPECT_EQ(0, update.input_index);
840 EXPECT_EQ(DT_INT32, update.type);
841 EXPECT_EQ((std::set<string>{"grad1", "grad2"}),
842 update.tensor_array_gradients_accessed);
843
844 // Tests that the generated computation works.
845 xla::Literal input_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
846 xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
847 xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2});
848 std::unique_ptr<xla::GlobalData> param0_data =
849 client_->TransferToServer(input).value();
850
851 std::unique_ptr<xla::GlobalData> actual =
852 client_->Execute(*result.computation, {param0_data.get()}).value();
853 xla::Literal actual_literal = client_->Transfer(*actual).value();
854
855 xla::Literal output_read = xla::LiteralUtil::CreateR0<int32>(42);
856 xla::Literal output_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
857 xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32>({0, 1});
858 xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
859 xla::Literal output_resource =
860 xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2});
861 xla::Literal expected_literal =
862 xla::LiteralUtil::MakeTuple({&output_read, &output_resource});
863 EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
864 }
865
866 // Tests compilation and execution of a graph that adds two tensors.
TEST_F(XlaCompilerTest,UnwrittenTensorArrayGradientsAreNotComputationOutputs)867 TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
868 Scope scope = Scope::NewRootScope().ExitOnError();
869 auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
870 auto flow = ops::Const<float>(scope, {});
871 auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
872 auto index = ops::Const<int32>(scope, 1);
873 auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32);
874 auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
875 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
876 TF_ASSERT_OK(scope.ToGraph(graph.get()));
877
878 // Builds a description of the arguments.
879 std::vector<XlaCompiler::Argument> args(1);
880 args[0].kind = XlaCompiler::Argument::kResource;
881 args[0].resource_kind = XlaResource::kTensorArray;
882 args[0].initialized = true;
883 args[0].type = DT_INT32;
884 args[0].shape = TensorShape({});
885 args[0].max_array_size = 2;
886 args[0].tensor_array_gradients = {"grad1"};
887
888 // Compiles the graph.
889 XlaCompiler compiler(DefaultOptions());
890
891 XlaCompiler::CompilationResult result;
892 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
893 std::move(graph), args, &result));
894
895 EXPECT_EQ(0, result.resource_updates.size());
896 }
897
898 // Tests compilation and execution of a graph that adds two tensors.
TEST_F(XlaCompilerTest,NewTensorArrayGradientsAreComputationOutputs)899 TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
900 Scope scope = Scope::NewRootScope().ExitOnError();
901 auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
902 auto flow = ops::Const<float>(scope, {});
903 auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad2");
904 auto index = ops::Const<int32>(scope, 1);
905 auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32);
906 auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
907 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
908 TF_ASSERT_OK(scope.ToGraph(graph.get()));
909
910 // Builds a description of the arguments.
911 std::vector<XlaCompiler::Argument> args(1);
912 args[0].kind = XlaCompiler::Argument::kResource;
913 args[0].resource_kind = XlaResource::kTensorArray;
914 args[0].initialized = true;
915 args[0].type = DT_INT32;
916 args[0].shape = TensorShape({});
917 args[0].max_array_size = 2;
918 args[0].tensor_array_gradients = {"grad1"};
919
920 // Compiles the graph.
921 XlaCompiler compiler(DefaultOptions());
922
923 XlaCompiler::CompilationResult result;
924 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
925 std::move(graph), args, &result));
926
927 EXPECT_EQ(1, result.resource_updates.size());
928 }
929
930 // Tests CompileFunction with undefined function fails.
TEST_F(XlaCompilerTest,UndefinedFunctionFails)931 TEST_F(XlaCompilerTest, UndefinedFunctionFails) {
932 XlaCompiler compiler(DefaultOptions());
933
934 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
935 XlaCompiler::CompilationResult result;
936 NameAttrList name_attr;
937 name_attr.set_name("Function_NotDefined_");
938 Status status =
939 compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
940 /*args=*/{}, &result);
941 EXPECT_FALSE(status.ok());
942 EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
943 << status.error_message();
944 }
945
FillFn()946 FunctionDef FillFn() {
947 return FunctionDefHelper::Define(
948 // Name
949 "FillFn",
950 // Args
951 {"x: T", "dims: int32"},
952 // Return values
953 {"y: T"},
954 // Attr def
955 {"T: {float, double, int32, int64}"},
956 // Nodes
957 {{{"y"}, "Fill", {"dims", "x"}, {{"T", "$T"}}}});
958 }
959
TEST_F(XlaCompilerTest, FunctionCallWithConstants) {
  // Certain operations in a function, "Fill" for example, require the
  // operator's argument to be a compile-time constant instead of a parameter.
  // This test checks that XlaCompiler can handle such operators inside
  // function calls.
  XlaCompiler compiler(DefaultOptions());

  FunctionDefLibrary flib;
  *flib.add_function() = FillFn();

  TF_ASSERT_OK(flib_def_->AddFunctionDef(FillFn()));

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  Scope scope = Scope::NewRootScope().ExitOnError();
  auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
  auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
  // The "FillFn" call node added below depends on this library, so stop here
  // if registering it fails instead of reporting follow-on errors.
  TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib));

  NodeDef def;
  TF_ASSERT_OK(NodeDefBuilder("fill", "FillFn", flib_def_.get())
                   .Input(value.name(), 0, DT_INT32)
                   .Input(shape.name(), 1, DT_INT32)
                   .Finalize(&def));
  Status status;
  Node* fill = scope.graph()->AddNode(def, &status);
  TF_ASSERT_OK(status);
  TF_ASSERT_OK(scope.DoShapeInference(fill));
  scope.graph()->AddEdge(value.node(), 0, fill, 0);
  scope.graph()->AddEdge(shape.node(), 0, fill, 1);

  auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0);

  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // The graph has no _Arg nodes: both inputs of the call are constants, so
  // no argument descriptions are needed.
  std::vector<XlaCompiler::Argument> args;

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
                                     std::move(graph), args, &result));
}
1002
1003 // Tests CompileFunction with a local function lookup failing, fails with
1004 // informative error about both lookups.
TEST_F(XlaCompilerTest,LocalFunctionWithWrongArgumentsFail)1005 TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
1006 XlaCompiler compiler(DefaultOptions());
1007
1008 auto local_flib_def = LocalFlibDef(&compiler);
1009 TF_ASSERT_OK(local_flib_def->AddFunctionDef(test::function::XTimesTwo()));
1010
1011 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1012 XlaCompiler::CompilationResult result;
1013 NameAttrList name_attr;
1014 name_attr.set_name("XTimesTwo");
1015 Status status =
1016 compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
1017 /*args=*/{}, &result);
1018
1019 ASSERT_FALSE(status.ok());
1020 // Flib lookup failure.
1021 EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
1022 << status.error_message();
1023 // Local flib lookup failure.
1024 EXPECT_TRUE(absl::StrContains(status.error_message(), "Attr T is not found"))
1025 << status.error_message();
1026 }
1027
SliceFn()1028 FunctionDef SliceFn() {
1029 return FunctionDefHelper::Define(
1030 // Name
1031 "SliceFn",
1032 // Args
1033 {"x: T", "begin: Index", "size: Index"},
1034 // Return values
1035 {"y: T"},
1036 // Attr def
1037 {"T: {float, double, int32, int64}", "Index: {int32,int64}"},
1038 // Nodes
1039 {{{"y"},
1040 "Slice",
1041 {"x", "begin", "size"},
1042 {{"T", "$T"}, {"Index", "$Index"}}}});
1043 }
1044
TEST_F(XlaCompilerTest, SliceWithDynamicBegins) {
  // Certain operations in a function, "Slice" for example, support both dynamic
  // inputs and static inputs. This test checks that dynamic inputs can also
  // be supported in a function call.
  XlaCompiler compiler(DefaultOptions());

  FunctionDefLibrary flib;
  *flib.add_function() = SliceFn();

  TF_ASSERT_OK(flib_def_->AddFunctionDef(SliceFn()));

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  Scope scope = Scope::NewRootScope().ExitOnError();
  // NOTE: the op names do not describe the nodes' roles and are kept as-is.
  // `input` is the data being sliced, `begin` is a runtime parameter (the
  // dynamic input), and `size` is a compile-time constant.
  auto input = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
  auto begin = ops::_Arg(scope.WithOpName("arg"), DT_INT32, 0);
  auto size = ops::Const<int32>(scope.WithOpName("value"), {1}, {1});

  // The "SliceFn" call node added below depends on this library.
  TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib));

  NodeDef def;
  TF_ASSERT_OK(NodeDefBuilder("slice", "SliceFn", flib_def_.get())
                   .Input(input.name(), 0, DT_INT32)
                   .Input(begin.node()->name(), 1, DT_INT32)
                   .Input(size.name(), 2, DT_INT32)
                   .Finalize(&def));
  Status status;
  Node* slice = scope.graph()->AddNode(def, &status);
  TF_ASSERT_OK(status);
  TF_ASSERT_OK(scope.DoShapeInference(slice));
  scope.graph()->AddEdge(input.node(), 0, slice, 0);
  scope.graph()->AddEdge(begin.node(), 0, slice, 1);
  scope.graph()->AddEdge(size.node(), 0, slice, 2);

  auto retval = ops::_Retval(scope.WithOpName("retval"), Output(slice), 0);

  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Describes the single runtime argument: the int32[1] `begin` vector.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({1});

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "slice",
                                     std::move(graph), args, &result));
}
1093
RunAndCheckVariablesComputation(xla::Client * client,const XlaCompiler::CompilationResult & result)1094 void RunAndCheckVariablesComputation(
1095 xla::Client* client, const XlaCompiler::CompilationResult& result) {
1096 xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
1097 xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
1098 std::unique_ptr<xla::GlobalData> param0_data =
1099 client->TransferToServer(param0_literal).value();
1100 std::unique_ptr<xla::GlobalData> param1_data =
1101 client->TransferToServer(param1_literal).value();
1102
1103 std::unique_ptr<xla::GlobalData> actual =
1104 client
1105 ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
1106 .value();
1107 xla::Literal actual_literal = client->Transfer(*actual).value();
1108
1109 xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({5, 144});
1110 xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({4, 143});
1111 xla::Literal expected_literal =
1112 xla::LiteralUtil::MakeTuple({&expected0, &expected1});
1113 EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
1114 }
1115
1116 // Tests a simple graph that reads and writes a variable.
TEST_F(XlaCompilerTest,Variables)1117 TEST_F(XlaCompilerTest, Variables) {
1118 Scope scope = Scope::NewRootScope().ExitOnError();
1119 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
1120 auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
1121 // Adds an identity op around the resource to make sure identity ops propagate
1122 // resources correctly.
1123 auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
1124 auto write = ops::AssignAddVariableOp(scope, identity, a);
1125 auto read = ops::ReadVariableOp(
1126 scope.WithControlDependencies(std::vector<Operation>{write}), var,
1127 DT_INT32);
1128 auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
1129 auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
1130 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1131 TF_ASSERT_OK(scope.ToGraph(graph.get()));
1132
1133 // Builds a description of the arguments.
1134 std::vector<XlaCompiler::Argument> args(2);
1135 args[0].kind = XlaCompiler::Argument::kParameter;
1136 args[0].type = DT_INT32;
1137 args[0].shape = TensorShape({2});
1138 args[1].kind = XlaCompiler::Argument::kResource;
1139 args[1].resource_kind = XlaResource::kVariable;
1140 args[1].initialized = true;
1141 args[1].type = DT_INT32;
1142 args[1].shape = TensorShape({2});
1143
1144 // Compiles the graph.
1145 XlaCompiler compiler(DefaultOptions());
1146
1147 XlaCompiler::CompilationResult result;
1148 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
1149 std::move(graph), args, &result));
1150 RunAndCheckVariablesComputation(client_, result);
1151 }
1152
// Tests that a non-default layout chosen by the shape representation function
// shows up on a single, non-tuple result.
TEST_F(XlaCompilerTest, ResultLayoutSingle) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto input = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto output = ops::_Retval(root.WithOpName("RET"), input, 0);

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // A single int32[2, 3] parameter.
  std::vector<XlaCompiler::Argument> args(1);
  XlaCompiler::Argument& input_arg = args.front();
  input_arg.kind = XlaCompiler::Argument::kParameter;
  input_arg.type = DT_INT32;
  input_arg.shape = TensorShape({2, 3});

  XlaCompiler::Options options = DefaultOptions();
  // Keeps each tensor's logical shape but forces the minor-to-major layout
  // {0, 1} instead of the default one.
  auto force_layout_0_1 = [](const TensorShape& shape, DataType type,
                             bool use_fast_memory,
                             XlaLayoutPreference layout_preference)
      -> StatusOr<xla::Shape> {
    xla::Shape forced;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &forced));
    *forced.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
    return forced;
  };
  XlaShapeLayoutHelpers::ShapeDeterminationFns fns;
  fns.shape_representation_fn = force_layout_0_1;
  options.shape_determination_fns = fns;

  XlaCompiler compiler(options);

  XlaCompiler::CompilationResult result;
  XlaCompiler::CompileOptions compile_options{};
  // Ask for the bare result rather than a one-element tuple.
  compile_options.always_return_tuple = false;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "id", std::move(graph),
                                     args, &result));

  const xla::Shape expected_shape =
      xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
  EXPECT_TRUE(xla::ShapeUtil::Equal(result.xla_output_shape, expected_shape));
  EXPECT_EQ(result.computation->GetProgramShape().value().result(),
            result.xla_output_shape);
}
1194
// Tests that a non-default layout chosen by the shape representation function
// is applied to every element of a tuple result.
TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto input = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  // The same input is returned twice, as retvals 0 and 1.
  auto first_retval = ops::_Retval(root.WithOpName("RET1"), input, 0);
  auto second_retval = ops::_Retval(root.WithOpName("RET2"), input, 1);

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // A single int32[2, 3] parameter.
  std::vector<XlaCompiler::Argument> args(1);
  XlaCompiler::Argument& input_arg = args.front();
  input_arg.kind = XlaCompiler::Argument::kParameter;
  input_arg.type = DT_INT32;
  input_arg.shape = TensorShape({2, 3});

  XlaCompiler::Options options = DefaultOptions();
  // Keeps each tensor's logical shape but forces the minor-to-major layout
  // {0, 1}; no layout preference is expressed.
  auto force_layout_0_1 = [](const TensorShape& shape, DataType type,
                             bool use_fast_memory,
                             XlaLayoutPreference layout_preference)
      -> StatusOr<xla::Shape> {
    xla::Shape forced;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &forced));
    *forced.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
    return forced;
  };
  XlaShapeLayoutHelpers::ShapeDeterminationFns fns;
  fns.shape_representation_fn = force_layout_0_1;
  fns.layout_preference_fn = UseNoPreferenceLayoutFn();
  options.shape_determination_fns = fns;

  XlaCompiler compiler(options);

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "id",
                                     std::move(graph), args, &result));

  // The result is a 2-tuple whose elements both carry the forced layout.
  const xla::Shape element_shape =
      xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
  const xla::Shape expected_shape =
      xla::ShapeUtil::MakeTupleShape({element_shape, element_shape});
  EXPECT_TRUE(xla::ShapeUtil::Equal(result.xla_output_shape, expected_shape));
  EXPECT_EQ(result.computation->GetProgramShape().value().result(),
            result.xla_output_shape);
}
1239
1240 // Tests a simple graph that reads and writes a variable.
TEST_F(XlaCompilerTest,ReturnResourceHandleOnly)1241 TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
1242 Scope scope = Scope::NewRootScope().ExitOnError();
1243 auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 0);
1244 auto d = ops::_Retval(scope.WithOpName("D"), var, 0);
1245 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1246 TF_ASSERT_OK(scope.ToGraph(graph.get()));
1247
1248 // Builds a description of the arguments.
1249 std::vector<XlaCompiler::Argument> args(1);
1250 args[0].kind = XlaCompiler::Argument::kResource;
1251 args[0].resource_kind = XlaResource::kVariable;
1252 args[0].initialized = true;
1253 args[0].type = DT_INT32;
1254 args[0].shape = TensorShape({2});
1255
1256 // Compiles the graph.
1257 XlaCompiler compiler(DefaultOptions());
1258
1259 XlaCompiler::CompilationResult result;
1260 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
1261 std::move(graph), args, &result));
1262
1263 // Tests that the generated computation works.
1264 xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
1265 std::unique_ptr<xla::GlobalData> param1_data =
1266 client_->TransferToServer(param1_literal).value();
1267
1268 std::unique_ptr<xla::GlobalData> actual =
1269 client_->Execute(*result.computation, {param1_data.get()}).value();
1270 xla::Literal actual_literal = client_->Transfer(*actual).value();
1271
1272 xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({});
1273 EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
1274 }
1275
// Tests a graph that updates a variable and returns both the variable's
// resource handle and a value computed from the variable.
TEST_F(XlaCompilerTest, ReturnResourceHandle) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto addend = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto var_handle = ops::_Arg(root.WithOpName("V"), DT_RESOURCE, 1);
  // The update goes through an Identity of the handle to check that Identity
  // ops propagate resources correctly.
  auto handle_alias = ops::Identity(root.WithOpName("VIdentity"), var_handle);
  auto assign_add = ops::AssignAddVariableOp(root, handle_alias, addend);
  // Reads the variable only after the AssignAdd has run.
  auto value = ops::ReadVariableOp(
      root.WithControlDependencies(std::vector<Operation>{assign_add}),
      var_handle, DT_INT32);
  auto one = ops::Const<int32>(root, 1);
  auto value_plus_one = ops::Add(root, value, one);
  // Retval 0 is the handle itself; retval 1 is the computed value.
  auto handle_retval = ops::_Retval(root.WithOpName("R"), var_handle, 0);
  auto value_retval = ops::_Retval(root.WithOpName("D"), value_plus_one, 1);

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Argument 0: int32[2] parameter; argument 1: initialized int32[2] variable.
  std::vector<XlaCompiler::Argument> args(2);
  XlaCompiler::Argument& addend_arg = args[0];
  addend_arg.kind = XlaCompiler::Argument::kParameter;
  addend_arg.type = DT_INT32;
  addend_arg.shape = TensorShape({2});
  XlaCompiler::Argument& var_arg = args[1];
  var_arg.kind = XlaCompiler::Argument::kResource;
  var_arg.resource_kind = XlaResource::kVariable;
  var_arg.initialized = true;
  var_arg.type = DT_INT32;
  var_arg.shape = TensorShape({2});

  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args, &result));
  // The checker expects exactly two outputs (value + 1 and the updated
  // variable), so the returned handle must not add an output.
  RunAndCheckVariablesComputation(client_, result);
}
1313
BuildTestGraph()1314 StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
1315 Scope scope = Scope::NewRootScope().ExitOnError();
1316 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
1317 auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
1318 auto write = ops::AssignAddVariableOp(scope, var, a);
1319 auto read = ops::ReadVariableOp(
1320 scope.WithControlDependencies(std::vector<Operation>{write}), var,
1321 DT_INT32);
1322 auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
1323 auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
1324 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1325 TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
1326 return std::move(graph);
1327 }
1328
1329 // Tests a simple graph that reads and writes a variable, with a
1330 // shape_representation_fn passed to the compiler that flattens all
1331 // variable tensors to vectors.
TEST_F(XlaCompilerTest,VariableRepresentationShapeFunction)1332 TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
1333 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
1334
1335 // Builds a description of the arguments.
1336 std::vector<XlaCompiler::Argument> args(2);
1337 args[0].kind = XlaCompiler::Argument::kParameter;
1338 args[0].type = DT_INT32;
1339 args[0].shape = TensorShape({2, 2});
1340 args[1].kind = XlaCompiler::Argument::kResource;
1341 args[1].resource_kind = XlaResource::kVariable;
1342 args[1].initialized = true;
1343 args[1].type = DT_INT32;
1344 args[1].shape = TensorShape({2, 2});
1345
1346 // Compiles the graph.
1347 XlaCompiler::Options options = DefaultOptions();
1348 XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns;
1349 shape_determination_fns.shape_representation_fn =
1350 [](const TensorShape& shape, DataType type, bool use_fast_memory,
1351 XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
1352 xla::PrimitiveType ptype;
1353 TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype));
1354 return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()});
1355 };
1356 options.shape_determination_fns = shape_determination_fns;
1357 XlaCompiler compiler(options);
1358
1359 XlaCompiler::CompileOptions compile_options;
1360 compile_options.is_entry_computation = false; // Only reshape variables.
1361
1362 XlaCompiler::CompilationResult result;
1363 TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
1364 args, &result));
1365
1366 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
1367 client_->GetComputationShape(*result.computation));
1368
1369 ASSERT_EQ(program_shape->parameters_size(), 2);
1370 EXPECT_TRUE(
1371 xla::ShapeUtil::Compatible(program_shape->parameters(0),
1372 xla::ShapeUtil::MakeShape(xla::S32, {2, 2})));
1373 EXPECT_TRUE(xla::ShapeUtil::Compatible(
1374 program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
1375 EXPECT_TRUE(xla::ShapeUtil::Compatible(
1376 program_shape->result(),
1377 xla::ShapeUtil::MakeTupleShape(
1378 {xla::ShapeUtil::MakeShape(xla::S32, {2, 2}),
1379 xla::ShapeUtil::MakeShape(xla::S32, {4})})));
1380
1381 // Tests that the generated computation works.
1382 xla::Literal param0_literal =
1383 xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
1384 xla::Literal param1_literal =
1385 xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
1386 std::unique_ptr<xla::GlobalData> param0_data =
1387 client_->TransferToServer(param0_literal).value();
1388 std::unique_ptr<xla::GlobalData> param1_data =
1389 client_->TransferToServer(param1_literal).value();
1390
1391 std::unique_ptr<xla::GlobalData> actual =
1392 client_
1393 ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
1394 .value();
1395 xla::Literal actual_literal = client_->Transfer(*actual).value();
1396
1397 xla::Literal expected0 =
1398 xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
1399 xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
1400 xla::Literal expected_literal =
1401 xla::LiteralUtil::MakeTuple({&expected0, &expected1});
1402 EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
1403 }
1404
// Tests that, for an entry computation, a shape_representation_fn that
// flattens tensors to vectors is applied to ordinary arguments and return
// values as well as to variables.
TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());

  // Two int32[2, 2] arguments: a plain parameter and an initialized variable.
  std::vector<XlaCompiler::Argument> args(2);
  XlaCompiler::Argument& addend_arg = args[0];
  addend_arg.kind = XlaCompiler::Argument::kParameter;
  addend_arg.type = DT_INT32;
  addend_arg.shape = TensorShape({2, 2});
  XlaCompiler::Argument& var_arg = args[1];
  var_arg.kind = XlaCompiler::Argument::kResource;
  var_arg.resource_kind = XlaResource::kVariable;
  var_arg.initialized = true;
  var_arg.type = DT_INT32;
  var_arg.shape = TensorShape({2, 2});

  // Representation function that flattens every tensor to a rank-1 shape with
  // the same number of elements.
  XlaCompiler::Options options = DefaultOptions();
  XlaShapeLayoutHelpers::ShapeDeterminationFns fns;
  fns.shape_representation_fn =
      [](const TensorShape& shape, DataType type, bool use_fast_memory,
         XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
    xla::PrimitiveType element_type;
    TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &element_type));
    return xla::ShapeUtil::MakeShape(element_type, {shape.num_elements()});
  };
  options.shape_determination_fns = fns;
  XlaCompiler compiler(options);

  // An entry computation reshapes arguments and return values too.
  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = true;

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                     args, &result));

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
                          client_->GetComputationShape(*result.computation));

  ASSERT_EQ(program_shape->parameters_size(), 2);
  const xla::Shape s32_4 = xla::ShapeUtil::MakeShape(xla::S32, {4});
  EXPECT_TRUE(xla::ShapeUtil::Compatible(program_shape->parameters(0), s32_4));
  EXPECT_TRUE(xla::ShapeUtil::Compatible(program_shape->parameters(1), s32_4));
  EXPECT_TRUE(xla::ShapeUtil::Compatible(
      program_shape->result(), xla::ShapeUtil::MakeTupleShape({s32_4, s32_4})));

  // Tests that the generated computation works. Both parameters are passed in
  // their flattened [4] representation.
  xla::Literal addend_literal =
      xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
  xla::Literal var_literal =
      xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
  std::unique_ptr<xla::GlobalData> addend_data =
      client_->TransferToServer(addend_literal).value();
  std::unique_ptr<xla::GlobalData> var_data =
      client_->TransferToServer(var_literal).value();

  std::unique_ptr<xla::GlobalData> output_data =
      client_
          ->Execute(*result.computation, {addend_data.get(), var_data.get()})
          .value();
  xla::Literal output = client_->Transfer(*output_data).value();

  // Outputs: A + V + 1, then the updated variable A + V, both flattened.
  xla::Literal read_plus_one =
      xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
  xla::Literal updated_var =
      xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
  xla::Literal expected =
      xla::LiteralUtil::MakeTuple({&read_plus_one, &updated_var});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, output));
}
1475
1476 // Tests a graph which has a function with an invalid op.
TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
  // Compiling a graph that calls a function containing an unsupported op
  // ("InvalidOp") must fail, and the error must name both the bad op and the
  // calling node ("fill_fn").
  XlaCompiler compiler(DefaultOptions());

  // Extend the FillFn test function with an unsupported op and a control-flow
  // node so it cannot be lowered to XLA.
  FunctionDefLibrary flib;
  FunctionDef fn = FillFn();
  NodeDef* node = fn.add_node_def();
  node->set_name("Invalid");
  node->set_op("InvalidOp"); /* unsupported op */
  node = fn.add_node_def();
  node->set_name("Switch");
  node->set_op("Switch"); /* control flow node */
  *flib.add_function() = fn;

  // Register the modified function in the fixture library too (presumably the
  // one DefaultOptions() hands to the compiler -- confirm).
  TF_ASSERT_OK(flib_def_->AddFunctionDef(fn));

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  Scope scope = Scope::NewRootScope().ExitOnError();
  auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
  auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
  TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib));

  // Build the call node by hand. Both Const producers have a single output,
  // so each NodeDef input refers to output 0, matching the edges added below.
  NodeDef def;
  TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get())
                   .Input(value.name(), 0, DT_INT32)
                   .Input(shape.name(), 0, DT_INT32)
                   .Finalize(&def));
  Status status;
  Node* fill = scope.graph()->AddNode(def, &status);
  TF_ASSERT_OK(status);
  TF_ASSERT_OK(scope.DoShapeInference(fill));
  // The data edges are wired explicitly; AddNode is only given the NodeDef.
  scope.graph()->AddEdge(value.node(), 0, fill, 0);
  scope.graph()->AddEdge(shape.node(), 0, fill, 1);

  auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0);

  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  std::vector<XlaCompiler::Argument> args;
  XlaCompiler::CompilationResult result;
  status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
                                 std::move(graph), args, &result);
  ASSERT_FALSE(status.ok());
  EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp"))
      << status.error_message();
  EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node fill_fn}}"))
      << status.error_message();
}
1525
1526 // Tests a graph which has a node with invalid data type.
TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
  // A standalone Shape node whose out_type attr is DT_BOOL. Compilation is
  // expected to reject that attr value and name the offending node.
  NodeDef shape_def;
  shape_def.set_name("Shape");
  shape_def.set_op("Shape");
  auto& shape_attrs = *shape_def.mutable_attr();
  shape_attrs["T"].set_type(DT_INT32);
  shape_attrs["out_type"].set_type(DT_BOOL);  // Not an allowed out_type.

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  Status add_status;
  Node* shape_node = graph->AddNode(shape_def, &add_status);
  TF_ASSERT_OK(add_status);
  graph->AddControlEdge(graph->source_node(), shape_node);

  XlaCompiler compiler(DefaultOptions());
  std::vector<XlaCompiler::Argument> args;
  XlaCompiler::CompilationResult result;
  const Status compile_status =
      compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type",
                            std::move(graph), args, &result);

  ASSERT_FALSE(compile_status.ok());
  EXPECT_TRUE(absl::StrContains(compile_status.error_message(),
                                "is not in the list of allowed values"))
      << compile_status.error_message();
  EXPECT_TRUE(
      absl::StrContains(compile_status.error_message(), "{{node Shape}}"))
      << compile_status.error_message();
}
1551
TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
  // A graph holding a single NoOp must compile even though nothing connects
  // that NoOp to the source or sink node.
  NodeDef no_op_def;
  no_op_def.set_name("NoOp");
  no_op_def.set_op("NoOp");

  // Only AddNode is used: no control edge links NoOp with source/sink.
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Status add_status;
  graph->AddNode(no_op_def, &add_status);
  TF_ASSERT_OK(add_status);

  XlaCompiler compiler(DefaultOptions());
  std::vector<XlaCompiler::Argument> no_args;

  // Compile a copy so the original graph is left untouched.
  std::unique_ptr<Graph> copied(new Graph(OpRegistry::Global()));
  CopyGraph(*graph, copied.get());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
                                     std::move(copied), no_args, &result));
}
1572
1573 class DummySideEffectingOp : public XlaOpKernel {
1574 public:
DummySideEffectingOp(OpKernelConstruction * ctx)1575 explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
Compile(XlaOpKernelContext * ctx)1576 void Compile(XlaOpKernelContext* ctx) override {
1577 OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken(
1578 name(), xla::CreateToken(ctx->builder())));
1579 }
1580 };
1581
// Registers the "DummySideEffectingOp" op; no inputs, outputs or attrs are
// declared for it.
REGISTER_OP("DummySideEffectingOp");

// Binds the DummySideEffectingOp XlaOpKernel as the XLA compilation kernel
// for the op of the same name.
REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp);
1585
// Checks token plumbing around a side-effecting op. An entry computation gets
// no token parameter/result. A non-entry computation compiled with
// add_token_input_output gets a trailing token parameter and a trailing token
// element in its result tuple.
TEST_F(XlaCompilerTest, TokenInputAndOutput) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  NodeDef side_effecting_op;
  side_effecting_op.set_name("DummySideEffectingOp");
  side_effecting_op.set_op("DummySideEffectingOp");
  // The token-input attr lists kXlaTokenArgNodeName, presumably tying this op
  // to the computation's token argument.
  AddNodeAttr(kXlaTokenInputNodesAttrName,
              std::vector<string>{kXlaTokenArgNodeName}, &side_effecting_op);
  AddNodeAttr(kXlaOriginalOutsideCompilationNodeName, side_effecting_op.name(),
              &side_effecting_op);
  Status status;
  graph->AddNode(side_effecting_op, &status);
  TF_ASSERT_OK(status);
  EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get()));

  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kResource;
  args[0].resource_kind = XlaResource::kVariable;
  args[0].initialized = true;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2, 2});

  {
    // The case for entry computation: we don't add token input/output. Instead,
    // we use CreateToken HLO to create the entry token.
    XlaCompiler::CompileOptions options;
    options.is_entry_computation = true;
    options.add_token_input_output = false;
    options.return_updated_values_for_all_resources = true;
    XlaCompiler compiler(DefaultOptions());

    std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
    CopyGraph(*graph, graph_copy.get());
    XlaCompiler::CompilationResult result;
    TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
                                       args, &result));
    EXPECT_EQ(result.xla_input_shapes.size(), 1);
    // ASSERT: TupleElementCount below requires a tuple shape.
    ASSERT_TRUE(result.xla_output_shape.IsTuple());
    EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1);
  }
  {
    // The case for non-entry computation (e.g. while loop body). We add token
    // input/output.
    XlaCompiler::CompileOptions options;
    options.is_entry_computation = false;
    options.add_token_input_output = true;
    options.return_updated_values_for_all_resources = true;
    XlaCompiler compiler(DefaultOptions());

    std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
    CopyGraph(*graph, graph_copy.get());
    XlaCompiler::CompilationResult result;
    TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
                                       args, &result));
    // ASSERT (not EXPECT) so xla_input_shapes[1] is never read out of bounds.
    ASSERT_EQ(result.xla_input_shapes.size(), 2);
    EXPECT_TRUE(result.xla_input_shapes[1].IsToken());
    // ASSERT so the tuple accessors below only run on a 2-element tuple.
    ASSERT_TRUE(result.xla_output_shape.IsTuple());
    ASSERT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 2);
    EXPECT_TRUE(xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 1)
                    .IsToken());
  }
}
1647
// Checks that a kTensorList argument flowing through AddN and through a While
// loop is reported as a tensor-list output in both places.
TEST_F(XlaCompilerTest, OpsWithTensorListInput) {
  FunctionDefLibrary empty_fdef_lib;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), empty_fdef_lib);

  // Loop condition "cond": takes a DT_VARIANT argument and returns the
  // constant true.
  {
    Scope cond_scope = Scope::NewRootScope().ExitOnError();
    std::unique_ptr<Graph> cond_graph(new Graph(OpRegistry::Global()));
    ops::_Arg(cond_scope.WithOpName("arg"), DT_VARIANT, 0);
    auto always_true = ops::Const<bool>(cond_scope, {true}, {});
    ops::_Retval(cond_scope.WithOpName("ret"), always_true, 0);
    TF_ASSERT_OK(cond_scope.ToGraph(cond_graph.get()));
    FunctionDef cond_fdef;
    TF_ASSERT_OK(GraphToFunctionDef(*cond_graph, "cond", &cond_fdef));
    TF_ASSERT_OK(flib_def.AddFunctionDef(cond_fdef));
  }
  // Loop body "body": returns its DT_VARIANT argument unchanged.
  {
    Scope body_scope = Scope::NewRootScope().ExitOnError();
    std::unique_ptr<Graph> body_graph(new Graph(OpRegistry::Global()));
    auto list_in = ops::_Arg(body_scope.WithOpName("arg"), DT_VARIANT, 0);
    ops::_Retval(body_scope.WithOpName("ret"), list_in, 0);
    TF_ASSERT_OK(body_scope.ToGraph(body_graph.get()));
    FunctionDef body_fdef;
    TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "body", &body_fdef));
    TF_ASSERT_OK(flib_def.AddFunctionDef(body_fdef));
  }

  Scope scope = Scope::NewRootScope().ExitOnError();
  // NOTE(review): these two constants feed no other node in this graph; they
  // look like leftovers but are kept so the compiled graph is unchanged.
  auto element_shape = ops::Const<int32>(scope, {1}, {1});
  auto max_elements = ops::Const<int32>(scope, {10}, {});
  auto list_arg = ops::_Arg(scope.WithOpName("arg"), DT_VARIANT, 0);
  std::initializer_list<Output> add_n_inputs = {list_arg, list_arg};
  auto doubled = ops::AddN(scope, add_n_inputs);
  NameAttrList cond_fn;
  cond_fn.set_name("cond");
  NameAttrList body_fn;
  body_fn.set_name("body");
  auto while_op = ops::While(scope, std::initializer_list<Input>{list_arg},
                             cond_fn, body_fn);
  auto ret0 = ops::_Retval(scope.WithOpName("ret0"), doubled, 0);
  auto ret1 = ops::_Retval(scope.WithOpName("ret1"), while_op.output[0], 1);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // The single argument is a tensor list whose XLA shape is the tuple
  // (int32[1] element shape, int32 scalar index).
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kTensorList;
  xla::Shape element_xla_shape;
  TF_ASSERT_OK(
      TensorShapeToXLAShape(DT_INT32, TensorShape{1}, &element_xla_shape));
  xla::Shape index_xla_shape;
  TF_ASSERT_OK(
      TensorShapeToXLAShape(DT_INT32, TensorShape{}, &index_xla_shape));
  const std::vector<xla::Shape> tuple_parts{element_xla_shape,
                                            index_xla_shape};
  args[0].shape = xla::ShapeUtil::MakeTupleShape(tuple_parts);

  // Compile against the local function library holding cond/body.
  XlaCompiler::Options compiler_options = DefaultOptions();
  compiler_options.flib_def = &flib_def;
  XlaCompiler compiler(compiler_options);

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args, &result));
  ASSERT_EQ(result.outputs.size(), 2);
  for (const XlaCompiler::OutputDescription& output : result.outputs) {
    ASSERT_TRUE(output.is_tensor_list);
  }
}
1717
1718 // Test the compiler supports WhileOp with a loop body where DT_RESOURCE
1719 // variables are both inputs and outputs.
TEST_F(XlaCompilerTest, WhileWithResources) {
  FunctionDefLibrary empty_fdef_lib;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), empty_fdef_lib);

  // Loop condition "cond": continue while the int32 counter (arg0) is < 10.
  // The two resource arguments are accepted but not read.
  {
    Scope cond_scope = Scope::NewRootScope().ExitOnError();
    std::unique_ptr<Graph> cond_graph(new Graph(OpRegistry::Global()));
    auto counter = ops::_Arg(cond_scope.WithOpName("arg0"), DT_INT32, 0);
    auto var_a = ops::_Arg(cond_scope.WithOpName("arg1"), DT_RESOURCE, 1);
    auto var_b = ops::_Arg(cond_scope.WithOpName("arg2"), DT_RESOURCE, 2);
    auto keep_going =
        ops::Less(cond_scope, counter, ops::Const<int32>(cond_scope, 10));
    (void)ops::_Retval(cond_scope.WithOpName("ret"), keep_going, 0);
    TF_ASSERT_OK(cond_scope.ToGraph(cond_graph.get()));
    FunctionDef cond_fdef;
    TF_ASSERT_OK(GraphToFunctionDef(*cond_graph, "cond", &cond_fdef));
    TF_ASSERT_OK(flib_def.AddFunctionDef(cond_fdef));
  }
  // Loop body "body": counter + read(var_a) - read(var_b). Both resource
  // handles are returned unchanged.
  {
    Scope body_scope = Scope::NewRootScope().ExitOnError();
    std::unique_ptr<Graph> body_graph(new Graph(OpRegistry::Global()));
    auto counter = ops::_Arg(body_scope.WithOpName("arg0"), DT_INT32, 0);
    auto var_a = ops::_Arg(body_scope.WithOpName("arg1"), DT_RESOURCE, 1);
    auto var_b = ops::_Arg(body_scope.WithOpName("arg2"), DT_RESOURCE, 2);
    auto a_value =
        ops::ReadVariableOp(body_scope.WithOpName("read1"), var_a, DT_INT32);
    auto partial = ops::Add(body_scope, counter, a_value);
    auto b_value =
        ops::ReadVariableOp(body_scope.WithOpName("read2"), var_b, DT_INT32);
    auto next_counter = ops::Sub(body_scope, partial, b_value);
    (void)ops::_Retval(body_scope.WithOpName("ret0"), next_counter, 0);
    (void)ops::_Retval(body_scope.WithOpName("ret1"), var_a, 1);
    (void)ops::_Retval(body_scope.WithOpName("ret2"), var_b, 2);
    TF_ASSERT_OK(body_scope.ToGraph(body_graph.get()));
    FunctionDef body_fdef;
    TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "body", &body_fdef));
    TF_ASSERT_OK(flib_def.AddFunctionDef(body_fdef));
  }

  // Outer graph: feed (counter, var_a, var_b) through While and return all
  // three loop outputs.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto counter = ops::_Arg(scope.WithOpName("arg0"), DT_INT32, 0);
  auto var_a = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1);
  auto var_b = ops::_Arg(scope.WithOpName("arg2"), DT_RESOURCE, 2);

  NameAttrList cond_fn;
  cond_fn.set_name("cond");
  NameAttrList body_fn;
  body_fn.set_name("body");
  auto while_op = ops::While(
      scope, std::initializer_list<Input>{counter, var_a, var_b}, cond_fn,
      body_fn);

  (void)ops::_Retval(scope.WithOpName("ret0"), while_op.output[0], 0);
  (void)ops::_Retval(scope.WithOpName("ret1"), while_op.output[1], 1);
  (void)ops::_Retval(scope.WithOpName("ret2"), while_op.output[2], 2);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Argument 0 is a scalar int32 parameter; arguments 1 and 2 are initialized
  // scalar int32 resource variables.
  std::vector<XlaCompiler::Argument> args(3);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({});
  for (int i = 1; i <= 2; ++i) {
    XlaCompiler::Argument& variable = args[i];
    variable.kind = XlaCompiler::Argument::kResource;
    variable.resource_kind = XlaResource::kVariable;
    variable.initialized = true;
    variable.type = DT_INT32;
    variable.shape = TensorShape({});
  }

  XlaCompiler::Options compiler_options = DefaultOptions();
  compiler_options.flib_def = &flib_def;
  XlaCompiler compiler(compiler_options);

  XlaCompiler::CompileOptions compile_options{};
  compile_options.return_updated_values_for_all_resources = true;
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "tested_while_with_vars",
                                     std::move(graph), args, &result));
  ASSERT_EQ(result.outputs.size(), 3);
  // Outputs 1 and 2 must map back to the resource arguments 1 and 2.
  for (int i = 1; i <= 2; ++i) {
    ASSERT_EQ(result.outputs[i].input_index, i);
  }

  // Run with counter=0, var_a=2, var_b=1. Each iteration adds 2 - 1 = 1, so
  // the loop should end at 10 with both variables unchanged.
  std::vector<std::unique_ptr<xla::GlobalData>> owned_inputs;
  std::vector<xla::GlobalData*> inputs;
  for (const int32 initial : {0, 2, 1}) {
    xla::Literal literal = xla::LiteralUtil::CreateR0<int32>(initial);
    owned_inputs.push_back(client_->TransferToServer(literal).value());
    inputs.push_back(owned_inputs.back().get());
  }

  std::unique_ptr<xla::GlobalData> actual =
      client_->Execute(*result.computation, inputs).value();
  xla::Literal actual_literal = client_->Transfer(*actual).value();

  xla::Literal want_counter = xla::LiteralUtil::CreateR0<int32>(10);
  xla::Literal want_var_a = xla::LiteralUtil::CreateR0<int32>(2);
  xla::Literal want_var_b = xla::LiteralUtil::CreateR0<int32>(1);
  xla::Literal expected_literal = xla::LiteralUtil::MakeTuple(
      {&want_counter, &want_var_a, &want_var_b});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
1831
// Checks that an _XlaSharding attribute on the only _Retval node ends up on
// the root tuple instruction as a one-element tuple sharding.
TEST_F(XlaCompilerTest, SetShardingForReturnedTuple) {
  // Graph: A = _Arg(0); B = _Retval(A, 0), i.e. it returns its only argument.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto arg_a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto ret_b = ops::_Retval(scope.WithOpName("B"), arg_a, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Attach a tiled sharding over devices {0, 1} to the _Retval node.
  auto nodes_by_name = graph->BuildNodeNameIndex();
  Node* retval_node = nodes_by_name["B"];
  ASSERT_NE(retval_node, nullptr);
  xla::Array<int64_t> devices({2});
  devices.FillIota(0);
  const xla::HloSharding sharding = xla::HloSharding::Tile(devices);
  retval_node->AddAttr("_XlaSharding", sharding.ToProto().SerializeAsString());

  // One int32[2] parameter.
  std::vector<XlaCompiler::Argument> args(1);
  XlaCompiler::Argument& input = args[0];
  input.kind = XlaCompiler::Argument::kParameter;
  input.type = DT_INT32;
  input.shape = TensorShape({2});

  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "test",
                                     std::move(graph), args, &result));

  // The module must contain exactly one computation; find its root.
  const auto& module_proto = result.computation->proto();
  ASSERT_EQ(module_proto.computations_size(), 1);
  const auto& computation_proto = module_proto.computations(0);
  const xla::HloInstructionProto* root = nullptr;
  for (const auto& instruction : computation_proto.instructions()) {
    if (instruction.id() == computation_proto.root_id()) {
      root = &instruction;
      break;
    }
  }
  ASSERT_NE(root, nullptr);

  const xla::Shape result_tuple_shape = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::S32, {2})});
  const xla::HloSharding expected_sharding = xla::HloSharding::Tuple(
      result_tuple_shape, std::vector<xla::HloSharding>{sharding});
  EXPECT_EQ(root->sharding().SerializeAsString(),
            expected_sharding.ToProto().SerializeAsString());
}
1881
// Checks that with CompileOptions::alias_resource_update set, the updated
// resource variable produces exactly one input/output alias entry.
TEST_F(XlaCompilerTest, AliasResourceUpdates) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::Const<int32>(scope.WithOpName("A"), {1, 2});
  auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
  auto write = ops::AssignAddVariableOp(scope, var, a);
  // The read is ordered after the AssignAdd via a control dependency.
  auto read = ops::ReadVariableOp(
      scope.WithControlDependencies(std::vector<Operation>{write}), var,
      DT_INT32);
  auto d = ops::_Retval(scope.WithOpName("D"), read, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments. Argument 0 is a compile-time
  // constant that no _Arg node in the graph reads; argument 1 is the variable.
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kConstant;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});
  // Keep the constant's value consistent with its declared shape [2].
  args[0].constant_value = Tensor(DT_INT32, TensorShape({2}));
  args[0].initialized = true;

  args[1].kind = XlaCompiler::Argument::kResource;
  args[1].resource_kind = XlaResource::kVariable;
  args[1].initialized = true;
  args[1].type = DT_INT32;
  args[1].shape = TensorShape({2});

  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompileOptions compile_options;
  compile_options.alias_resource_update = true;

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                     args, &result));

  const xla::HloInputOutputAliasProto& alias =
      result.computation->proto().input_output_alias();
  // ASSERT (not EXPECT) so entries(0) is never read from an empty list.
  ASSERT_EQ(alias.entries_size(), 1);
  // The variable (args[1]) is expected at parameter 0, presumably because the
  // constant argument is not lowered to an XLA parameter.
  EXPECT_EQ(alias.entries(0).parameter_number(), 0);
}
1922
// Tests that passing in an exact duplicate input to SetDeviceToHostMetadata
// is not an error.
TEST_F(XlaCompilerTest, SetDeviceToHostMetadataExactDuplicate) {
  XlaCompiler compiler(DefaultOptions());

  // Hold the key by value rather than binding a reference to a temporary.
  const string key = "comm_key";
  std::vector<DataType> types{DT_INT32};
  std::vector<TensorShape> shapes{TensorShape({2})};

  // Registering identical metadata twice under the same key must succeed.
  TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes));
  TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes));
}
1935
// Tests that passing in a mismatched duplicate input to
// SetDeviceToHostMetadata is an error (INVALID_ARGUMENT).
TEST_F(XlaCompilerTest, SetDeviceToHostMetadataMismatchedDuplicate) {
  XlaCompiler compiler(DefaultOptions());

  // Hold the key by value rather than binding a reference to a temporary.
  const string key = "comm_key";
  std::vector<DataType> types{DT_INT32};
  std::vector<TensorShape> shapes{TensorShape({2})};
  std::vector<DataType> types2{DT_FLOAT};
  std::vector<TensorShape> shapes2{TensorShape({1})};

  // Re-registering the key with different types/shapes must be rejected.
  TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes));
  Status status = compiler.SetDeviceToHostMetadata(key, types2, shapes2);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT)
      << status.error_message();
}
1951
// Tests that passing in an exact duplicate input to SetHostToDeviceMetadata
// is not an error.
TEST_F(XlaCompilerTest, SetHostToDeviceMetadataExactDuplicate) {
  XlaCompiler compiler(DefaultOptions());

  // Hold the key by value rather than binding a reference to a temporary.
  const string key = "comm_key";
  std::vector<DataType> types{DT_INT32};
  std::vector<TensorShape> shapes{TensorShape({2})};

  // Registering identical metadata twice under the same key must succeed.
  TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes));
  TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes));
}
1964
// Tests that passing in a mismatched duplicate input to
// SetHostToDeviceMetadata is an error (INVALID_ARGUMENT).
TEST_F(XlaCompilerTest, SetHostToDeviceMetadataMismatchedDuplicate) {
  XlaCompiler compiler(DefaultOptions());

  // Hold the key by value rather than binding a reference to a temporary.
  const string key = "comm_key";
  std::vector<DataType> types{DT_INT32};
  std::vector<TensorShape> shapes{TensorShape({2})};
  std::vector<DataType> types2{DT_FLOAT};
  std::vector<TensorShape> shapes2{TensorShape({1})};

  // Re-registering the key with different types/shapes must be rejected.
  TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes));
  Status status = compiler.SetHostToDeviceMetadata(key, types2, shapes2);
  EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT)
      << status.error_message();
}
1980
1981 } // namespace
1982 } // namespace tensorflow
1983