• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <initializer_list>
17 #include <memory>
18 #include <vector>
19 
20 #include "tensorflow/compiler/xla/client/client_library.h"
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/compiler/xla/client/sharding_builder.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/local_service.h"
27 #include "tensorflow/compiler/xla/service/platform_util.h"
28 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
29 #include "tensorflow/compiler/xla/service/transfer_manager.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/test.h"
33 #include "tensorflow/compiler/xla/test_helpers.h"
34 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
35 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
36 #include "tensorflow/compiler/xla/tests/test_macros.h"
37 #include "tensorflow/compiler/xla/tests/test_utils.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39 #include "tensorflow/core/platform/env.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
42 #include "tensorflow/core/platform/test.h"
43 #include "tensorflow/core/platform/test_benchmark.h"
44 #include "tensorflow/stream_executor/device_memory_allocator.h"
45 
46 namespace xla {
47 namespace {
48 
49 using ::testing::ContainsRegex;
50 
51 class LocalClientExecuteTest : public LocalClientTestBase {
52  protected:
53   ErrorSpec error_spec_{0.0001};
54 };
55 
// A computation consisting of a single scalar constant returns that constant.
XLA_TEST_F(LocalClientExecuteTest, Constant) {
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 123.0f);
  auto computation = builder.Build().ConsumeValueOrDie();

  ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {});
  Literal result_literal = ShapedBufferToLiteral(result);
  LiteralTestUtil::ExpectR0Near<float>(123.f, result_literal, error_spec_);
}
65 
// Adds a scalar parameter to a scalar constant: 42 + 123 == 165.
XLA_TEST_F(LocalClientExecuteTest, AddScalars) {
  XlaBuilder builder(TestName());
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  XlaOp param = Parameter(&builder, 0, scalar_shape, "x");
  XlaOp constant = ConstantR0<float>(&builder, 123.0f);
  Add(param, constant);
  auto computation = builder.Build().ConsumeValueOrDie();

  auto arg = LiteralToShapedBuffer(LiteralUtil::CreateR0<float>(42.0f));
  ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg});
  LiteralTestUtil::ExpectR0Near<float>(165.f, ShapedBufferToLiteral(result),
                                       error_spec_);
}
78 
// Adding two zero-element F32 vectors yields an empty vector.
XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) {
  XlaBuilder builder(TestName());
  XlaOp lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "x");
  XlaOp rhs = ConstantR1<float>(&builder, {});
  Add(lhs, rhs);
  auto computation = builder.Build().ConsumeValueOrDie();

  auto empty_arg = LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({}));
  ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&empty_arg});
  LiteralTestUtil::ExpectR1Near<float>({}, ShapedBufferToLiteral(result),
                                       error_spec_);
}
91 
// Element-wise sum of an f32[3] parameter and an f32[3] constant.
XLA_TEST_F(LocalClientExecuteTest, AddVectors) {
  XlaBuilder builder(TestName());
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {3});
  XlaOp lhs = Parameter(&builder, 0, vector_shape, "x");
  XlaOp rhs = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  Add(lhs, rhs);
  auto computation = builder.Build().ConsumeValueOrDie();

  Literal arg_literal = LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f});
  auto arg = LiteralToShapedBuffer(arg_literal);
  ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg});
  LiteralTestUtil::ExpectR1Near<float>(
      {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
105 
// Same computation as AddVectors, but an ExecutionProfile is attached to the
// run options and must record a positive compute-and-transfer time.
XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
  XlaBuilder builder(TestName());
  XlaOp lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  XlaOp rhs = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  Add(lhs, rhs);
  auto computation = builder.Build().ConsumeValueOrDie();

  auto arg =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));

  ExecutionProfile profile;
  auto run_options = DefaultExecutableRunOptions();
  run_options.set_execution_profile(&profile);
  ScopedShapedBuffer result = ExecuteLocallyOrDie(
      computation, {&arg}, DefaultExecutableBuildOptions(), run_options);

  LiteralTestUtil::ExpectR1Near<float>(
      {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
  EXPECT_GT(profile.compute_and_transfer_time_ns(), 0);
}
123 
// The sum must not depend on the physical layout of each argument, nor on
// which parameter each argument is bound to (Add is commutative).
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
  XlaBuilder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  XlaOp lhs = Parameter(&builder, 0, matrix_shape, "x");
  XlaOp rhs = Parameter(&builder, 1, matrix_shape, "y");
  Add(lhs, rhs);
  auto computation = builder.Build().ConsumeValueOrDie();

  const Layout col_major = LayoutUtil::MakeLayout({0, 1});
  const Layout row_major = LayoutUtil::MakeLayout({1, 0});

  // x is placed on the device in column-major order.
  auto x_buffer = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
      {{1.0f, 2.0f}, {3.0f, 4.0f}}, col_major));
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      x_buffer.on_device_shape().layout(), col_major));

  // y is placed on the device in row-major order.
  auto y_buffer = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
      {{10.0f, 20.0f}, {30.0f, 40.0f}}, row_major));
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      y_buffer.on_device_shape().layout(), row_major));

  // Both runs below must produce the same element-wise sum.
  auto expect_sum = [&](const ScopedShapedBuffer& result) {
    LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                         ShapedBufferToLiteral(result),
                                         error_spec_);
  };

  ScopedShapedBuffer result_colmaj =
      ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer});
  expect_sum(result_colmaj);

  // Same buffers, bound to the opposite parameters.
  ScopedShapedBuffer result_param_swap =
      ExecuteLocallyOrDie(computation, {&y_buffer, &x_buffer});
  expect_sum(result_param_swap);
}
156 
// Runs the same Add computation twice, requesting a column-major and then a
// row-major result layout; each result must have the requested on-device
// layout and hold the correct sums.
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
  XlaBuilder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  XlaOp lhs = Parameter(&builder, 0, matrix_shape, "x");
  XlaOp rhs = Parameter(&builder, 1, matrix_shape, "y");
  Add(lhs, rhs);
  auto computation = builder.Build().ConsumeValueOrDie();

  auto x_buffer = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto y_buffer = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  auto expect_sum = [&](const ScopedShapedBuffer& result) {
    LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                         ShapedBufferToLiteral(result),
                                         error_spec_);
  };

  // Column-major result (minor_to_major = {0, 1}).
  ExecutableBuildOptions colmaj_options = DefaultExecutableBuildOptions();
  colmaj_options.set_result_layout(
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {0, 1}));
  ScopedShapedBuffer result_colmaj =
      ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer}, colmaj_options,
                          DefaultExecutableRunOptions());
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      result_colmaj.on_device_shape().layout(),
      LayoutUtil::MakeLayout({0, 1})));
  expect_sum(result_colmaj);

  // Row-major result (minor_to_major = {1, 0}).
  ExecutableBuildOptions rowmaj_options = DefaultExecutableBuildOptions();
  rowmaj_options.set_result_layout(
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {1, 0}));
  ScopedShapedBuffer result_rowmaj =
      ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer}, rowmaj_options,
                          DefaultExecutableRunOptions());
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      result_rowmaj.on_device_shape().layout(),
      LayoutUtil::MakeLayout({1, 0})));
  expect_sum(result_rowmaj);
}
195 
// Returns the tuple (x, y, x); the result buffer must be a three-element tuple
// whose elements hold the corresponding argument values.
XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
  XlaBuilder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  XlaOp x = Parameter(&builder, 0, matrix_shape, "x");
  XlaOp y = Parameter(&builder, 1, matrix_shape, "y");
  Tuple(&builder, {x, y, x});
  auto computation = builder.Build().ConsumeValueOrDie();

  auto x_buffer = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto y_buffer = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer});

  const Shape& host_shape = result.on_host_shape();
  EXPECT_TRUE(host_shape.IsTuple());
  EXPECT_EQ(3, ShapeUtil::TupleElementCount(host_shape));

  Literal result_literal = ShapedBufferToLiteral(result);
  // Element 0 is x.
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {0}));
  // Element 1 is y.
  LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
                                        LiteralSlice(result_literal, {1}));
  // Element 2 is x again.
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {2}));
}
222 
// Returns the nested tuple ((x, y, x), x) and checks every leaf of the result.
XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
  XlaBuilder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  XlaOp x = Parameter(&builder, 0, matrix_shape, "x");
  XlaOp y = Parameter(&builder, 1, matrix_shape, "y");
  XlaOp inner = Tuple(&builder, {x, y, x});
  Tuple(&builder, {inner, x});
  auto computation = builder.Build().ConsumeValueOrDie();

  auto x_buffer = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto y_buffer = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer});

  // The top level holds two elements: the inner tuple and x.
  const Shape& top_shape = result.on_host_shape();
  EXPECT_TRUE(top_shape.IsTuple());
  EXPECT_EQ(2, ShapeUtil::TupleElementCount(top_shape));

  Literal result_literal = ShapedBufferToLiteral(result);
  // Outer element 1: x.
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {1}));
  // Inner tuple (x, y, x).
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {0, 0}));
  LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
                                        LiteralSlice(result_literal, {0, 1}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {0, 2}));
}
252 
// Requests an explicit per-element result layout for a tuple-shaped output.
// Checks that the values come back intact and that each tuple element's
// on-device layout is the one requested.
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
  // Verify setting the result layout of a computation with a tuple output.
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  Tuple(&builder, {x, y});

  auto array = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));

  // Element 0 is requested column-major and element 1 row-major.
  ExecutableBuildOptions options = DefaultExecutableBuildOptions();
  Shape shape_with_layout = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2},
                                      /*minor_to_major=*/{0, 1}),
       ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2},
                                      /*minor_to_major=*/{1, 0})});
  options.set_result_layout(shape_with_layout);
  // The same buffer is deliberately passed for both parameters.
  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array},
                          options, DefaultExecutableRunOptions());

  Literal result_literal = ShapedBufferToLiteral(result);
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {1}));

  // Check the layouts too: the value comparisons alone would still pass if
  // the requested tuple result layout were silently ignored.
  const Shape& device_shape = result.on_device_shape();
  ASSERT_TRUE(device_shape.IsTuple());
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      ShapeUtil::GetTupleElementShape(device_shape, 0).layout(),
      LayoutUtil::MakeLayout({0, 1})));
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      ShapeUtil::GetTupleElementShape(device_shape, 1).layout(),
      LayoutUtil::MakeLayout({1, 0})));
}
280 
// Passes two tuple-shaped arguments whose array/vector members appear in
// opposite orders. The computation returns (array sum, vector difference).
XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
  const Shape array_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {3});

  const Shape array_then_vector =
      ShapeUtil::MakeTupleShape({array_shape, vector_shape});
  const Shape vector_then_array =
      ShapeUtil::MakeTupleShape({vector_shape, array_shape});

  XlaBuilder builder(TestName());
  XlaOp lhs_tuple = Parameter(&builder, 0, array_then_vector, "x");
  XlaOp rhs_tuple = Parameter(&builder, 1, vector_then_array, "y");
  XlaOp lhs_array = GetTupleElement(lhs_tuple, 0);
  XlaOp lhs_vector = GetTupleElement(lhs_tuple, 1);
  XlaOp rhs_vector = GetTupleElement(rhs_tuple, 0);
  XlaOp rhs_array = GetTupleElement(rhs_tuple, 1);
  XlaOp array_sum = Add(lhs_array, rhs_array);
  XlaOp vector_diff = Sub(lhs_vector, rhs_vector);
  Tuple(&builder, {array_sum, vector_diff});
  auto computation = builder.Build().ConsumeValueOrDie();

  Literal lhs_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
       LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})});
  Literal rhs_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}),
       LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}})});

  auto lhs_buffer = LiteralToShapedBuffer(lhs_literal);
  auto rhs_buffer = LiteralToShapedBuffer(rhs_literal);

  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(computation, {&lhs_buffer, &rhs_buffer});

  const Shape& host_shape = result.on_host_shape();
  EXPECT_TRUE(host_shape.IsTuple());
  EXPECT_EQ(2, ShapeUtil::TupleElementCount(host_shape));

  Literal result_literal = ShapedBufferToLiteral(result);
  // {{1,2},{3,4}} + {{55,44},{33,22}}
  LiteralTestUtil::ExpectR2Equal<float>({{56.0f, 46.0f}, {36.0f, 26.0f}},
                                        LiteralSlice(result_literal, {0}));
  // {42,75,123} - {2,4,6}
  LiteralTestUtil::ExpectR1Equal<float>({40.0f, 71.0f, 117.0f},
                                        LiteralSlice(result_literal, {1}));
}
326 
// The single parameter has shape ((f32[2,2], f32[3]), f32[3]). The computation
// returns (-inner_array, inner_vector + outer_vector).
XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
  const Shape array_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {3});
  const Shape inner_tuple_shape =
      ShapeUtil::MakeTupleShape({array_shape, vector_shape});
  const Shape nested_tuple_shape =
      ShapeUtil::MakeTupleShape({inner_tuple_shape, vector_shape});

  XlaBuilder builder(TestName());
  XlaOp param = Parameter(&builder, 0, nested_tuple_shape, "param");
  XlaOp inner = GetTupleElement(param, 0);
  XlaOp inner_array = GetTupleElement(inner, 0);
  XlaOp inner_vec = GetTupleElement(inner, 1);
  XlaOp outer_vec = GetTupleElement(param, 1);

  XlaOp negated = Neg(inner_array);
  XlaOp summed = Add(inner_vec, outer_vec);
  Tuple(&builder, {negated, summed});
  auto computation = builder.Build().ConsumeValueOrDie();

  Literal inner_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
       LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})});
  Literal arg_literal = LiteralUtil::MakeTupleFromSlices(
      {inner_literal, LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0})});
  auto arg_buffer = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});

  Literal result_literal = ShapedBufferToLiteral(result);
  LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4}},
                                        LiteralSlice(result_literal, {0}));
  // {42,75,123} + {222,-2,10}
  LiteralTestUtil::ExpectR1Equal<float>({264.0, 73.0, 133.0},
                                        LiteralSlice(result_literal, {1}));
}
365 
XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
  // The computation maps a tuple shape onto the same tuple shape, so its
  // output buffer can be fed straight back in as the next input. Doing so
  // gives extra evidence that the returned tuple is properly constructed.
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({matrix_shape, matrix_shape});

  XlaBuilder builder(TestName());
  XlaOp param = Parameter(&builder, 0, pair_shape, "param");
  XlaOp first = GetTupleElement(param, 0);
  XlaOp second = GetTupleElement(param, 1);
  XlaOp negated = Neg(first);
  XlaOp doubled = Add(second, second);
  Tuple(&builder, {negated, doubled});
  auto computation = builder.Build().ConsumeValueOrDie();

  Literal initial_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
       LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}})});
  auto initial_buffer = LiteralToShapedBuffer(initial_literal);

  // First pass: (-a, 2b).
  ScopedShapedBuffer first_pass =
      ExecuteLocallyOrDie(computation, {&initial_buffer});
  Literal first_literal = ShapedBufferToLiteral(first_pass);
  LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4.0}},
                                        LiteralSlice(first_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{22.0, 6.0}, {8.0, 10}},
                                        LiteralSlice(first_literal, {1}));

  // Second pass, fed the first pass's output buffer directly: (a, 4b).
  ScopedShapedBuffer second_pass =
      ExecuteLocallyOrDie(computation, {&first_pass});
  Literal second_literal = ShapedBufferToLiteral(second_pass);
  LiteralTestUtil::ExpectR2Equal<float>({{1.0, 2.0}, {3.0, 4.0}},
                                        LiteralSlice(second_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{44.0, 12.0}, {16.0, 20}},
                                        LiteralSlice(second_literal, {1}));
}
401 
XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
  // Runs a computation whose single parameter is a tuple with many elements.
  //
  // A larger number of elements would make for a better, more strenuous test,
  // but:
  // TODO(b/66959878): On cpu a large number of elements results in long
  //   compilation time.
  // TODO(b/66954197): On gpu a large number of elements OOMs.
  const int kElementCount = 100;

  // Every tuple element is an f32[2] vector.
  const std::vector<Shape> element_shapes(kElementCount,
                                          ShapeUtil::MakeShape(F32, {2}));
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);

  XlaBuilder builder(TestName());
  XlaOp param = Parameter(&builder, 0, tuple_shape, "param");

  // Output element `index` is input element `index` plus the scalar `index`.
  std::vector<XlaOp> outputs;
  outputs.reserve(kElementCount);
  for (int index = 0; index < kElementCount; ++index) {
    XlaOp element = GetTupleElement(param, index);
    XlaOp offset = ConstantR0<float>(&builder, index);
    outputs.push_back(Add(element, offset));
  }
  Tuple(&builder, outputs);
  auto computation = builder.Build().ConsumeValueOrDie();

  // Argument element `index` is {index, -index}.
  std::vector<Literal> arg_elements;
  arg_elements.reserve(kElementCount);
  for (int index = 0; index < kElementCount; ++index) {
    const float value = static_cast<float>(index);
    arg_elements.push_back(LiteralUtil::CreateR1<float>({value, -value}));
  }
  Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_elements));
  auto arg_buffer = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
  Literal result_literal = ShapedBufferToLiteral(result);

  // {index, -index} + index == {2 * index, 0}.
  for (int index = 0; index < kElementCount; ++index) {
    LiteralTestUtil::ExpectR1Near<float>({2.0f * index, 0.0f},
                                         LiteralSlice(result_literal, {index}),
                                         error_spec_);
  }
}
447 
XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
  // Builds and runs a computation whose parameter is a two-level nested tuple
  // with a large fanout at each level.
  const int kFanout = 40;

  // Full two-level tree: kFanout inner tuples, each of kFanout F32 scalars.
  const std::vector<Shape> leaf_shapes(kFanout, ShapeUtil::MakeShape(F32, {}));
  const Shape inner_tuple_shape = ShapeUtil::MakeTupleShape(leaf_shapes);
  const std::vector<Shape> inner_tuple_shapes(kFanout, inner_tuple_shape);
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(inner_tuple_shapes);

  XlaBuilder builder(TestName());
  XlaOp param = Parameter(&builder, 0, tuple_shape, "param");

  // Leaf (outer, inner) is incremented by its ordinal position in a traversal
  // of the tree, outer * kFanout + inner.
  std::vector<XlaOp> outer_outputs;
  for (int outer = 0; outer < kFanout; ++outer) {
    XlaOp inner_tuple = GetTupleElement(param, outer);
    std::vector<XlaOp> inner_outputs;
    for (int inner = 0; inner < kFanout; ++inner) {
      XlaOp leaf = GetTupleElement(inner_tuple, inner);
      XlaOp ordinal = ConstantR0<float>(&builder, outer * kFanout + inner);
      inner_outputs.push_back(Add(leaf, ordinal));
    }
    outer_outputs.push_back(Tuple(&builder, inner_outputs));
  }
  Tuple(&builder, outer_outputs);
  auto computation = builder.Build().ConsumeValueOrDie();

  // Argument leaf (outer, inner) holds outer + inner.
  std::vector<Literal> outer_literals;
  for (int outer = 0; outer < kFanout; ++outer) {
    std::vector<Literal> inner_literals;
    for (int inner = 0; inner < kFanout; ++inner) {
      inner_literals.push_back(LiteralUtil::CreateR0<float>(outer + inner));
    }
    outer_literals.push_back(
        LiteralUtil::MakeTupleOwned(std::move(inner_literals)));
  }
  Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(outer_literals));
  auto arg_buffer = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
  Literal result_literal = ShapedBufferToLiteral(result);

  for (int outer = 0; outer < kFanout; ++outer) {
    for (int inner = 0; inner < kFanout; ++inner) {
      // Input value plus the ordinal that was added to it.
      const int expected = (outer + inner) + (outer * kFanout + inner);
      LiteralTestUtil::ExpectR0Near<float>(
          expected, LiteralSlice(result_literal, {outer, inner}), error_spec_);
    }
  }
}
504 
XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
  // Construct and run a computation which takes a very deep tuple. The tuple
  // has no fan out and a single scalar element at the bottom.
  const int kTupleDepth = 100;

  // Tuple shape is a chain of kTupleDepth nested single-element tuples around
  // an F32 scalar, i.e. (((...(f32[])...))).
  Shape shape = ShapeUtil::MakeShape(F32, {});
  for (int i = 0; i < kTupleDepth; ++i) {
    shape = ShapeUtil::MakeTupleShape({shape});
  }

  // Unwrap the parameter down to its scalar leaf.
  XlaBuilder builder(TestName());
  auto element = Parameter(&builder, 0, shape, "param");
  for (int i = 0; i < kTupleDepth; ++i) {
    element = GetTupleElement(element, 0);
  }

  // Add 42 to the leaf and re-wrap it to the same depth.
  auto output = Add(element, ConstantR0<float>(&builder, 42.0));
  for (int i = 0; i < kTupleDepth; ++i) {
    output = Tuple(&builder, {output});
  }
  auto computation = builder.Build().ConsumeValueOrDie();

  // Construct the argument to pass to the computation: 123.0 wrapped in
  // kTupleDepth single-element tuples.
  Literal arg_literal = LiteralUtil::CreateR0<float>(123.0);
  for (int i = 0; i < kTupleDepth; ++i) {
    std::vector<Literal> arg_vector;
    arg_vector.push_back(std::move(arg_literal));
    arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector));
  }
  auto arg_buffer = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
  Literal result_literal = ShapedBufferToLiteral(result);

  // An index of kTupleDepth zeros addresses the scalar leaf, which should be
  // 123 + 42.
  ShapeIndex index;
  for (int i = 0; i < kTupleDepth; ++i) {
    index.push_back(0);
  }
  LiteralTestUtil::ExpectR0Equal<float>(165.0,
                                        LiteralSlice(result_literal, index));
}
547 
XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
  // The computation declares two parameters but receives only one argument;
  // execution must fail with an "Invalid number of arguments" error.
  XlaBuilder builder(TestName());
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {3});
  XlaOp lhs = Parameter(&builder, 0, vector_shape, "x");
  XlaOp rhs = Parameter(&builder, 1, vector_shape, "y");
  Add(lhs, rhs);
  auto computation = builder.Build().ConsumeValueOrDie();

  auto only_arg =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
  auto outcome = ExecuteLocally(computation, {&only_arg});

  EXPECT_FALSE(outcome.ok());
  EXPECT_THAT(outcome.status().error_message(),
              ContainsRegex("Invalid number of arguments"));
}
564 
XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) {
  // Supplying an f32[2,2] argument for an f32[3] parameter must be rejected
  // with an "Invalid argument shape" error.
  XlaBuilder builder(TestName());
  XlaOp param = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  Neg(param);
  auto computation = builder.Build().ConsumeValueOrDie();

  auto wrong_shape_arg = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
  auto outcome = ExecuteLocally(computation, {&wrong_shape_arg});

  EXPECT_FALSE(outcome.ok());
  EXPECT_THAT(outcome.status().error_message(),
              ContainsRegex("Invalid argument shape"))
      << outcome.status();
}
581 
XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) {
  // Requesting a result layout whose shape (f32[1,2,3,4]) does not match the
  // computation's f32[2,2] result must be rejected.
  XlaBuilder builder(TestName());
  XlaOp param = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  Neg(param);
  auto computation = builder.Build().ConsumeValueOrDie();

  auto arg = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));

  ExecutableBuildOptions build_options = DefaultExecutableBuildOptions();
  build_options.set_result_layout(
      ShapeUtil::MakeShapeWithLayout(F32,
                                     /*dimensions=*/{1, 2, 3, 4},
                                     /*minor_to_major=*/{0, 1, 2, 3}));
  auto outcome = ExecuteLocally(computation, {&arg}, build_options,
                                DefaultExecutableRunOptions());

  EXPECT_FALSE(outcome.ok());
  EXPECT_THAT(outcome.status().error_message(),
              ContainsRegex("not compatible with result shape"))
      << outcome.status();
}
603 
XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) {
  // Runs a trivial computation on every device ordinal in turn. Supported
  // ordinals must execute on that device and return the constant; unsupported
  // ordinals must fail with a "device ... not supported" error.
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);
  auto computation = builder.Build().ConsumeValueOrDie();

  for (int ordinal = 0; ordinal < local_client_->device_count(); ++ordinal) {
    ExecutableBuildOptions build_options = DefaultExecutableBuildOptions();
    build_options.set_device_ordinal(ordinal);
    auto run_options = DefaultExecutableRunOptions();
    run_options.set_device_ordinal(ordinal);

    if (local_client_->device_ordinal_supported(ordinal)) {
      auto result =
          ExecuteLocallyOrDie(computation, {}, build_options, run_options);
      EXPECT_EQ(ordinal, result.device_ordinal());
      LiteralTestUtil::ExpectR0Equal<float>(42.0f,
                                            ShapedBufferToLiteral(result));
    } else {
      auto outcome =
          ExecuteLocally(computation, {}, build_options, run_options);
      EXPECT_FALSE(outcome.ok());
      EXPECT_THAT(outcome.status().error_message(),
                  ContainsRegex("device .* not supported"));
    }
  }
}
630 
XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) {
  // Valid ordinals run from 0 to device_count() - 1, so device_count() itself
  // names a device that does not exist and must be rejected.
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);
  auto computation = builder.Build().ConsumeValueOrDie();

  const int out_of_range_ordinal = local_client_->device_count();
  ExecutableBuildOptions build_options = DefaultExecutableBuildOptions();
  build_options.set_device_ordinal(out_of_range_ordinal);
  auto run_options = DefaultExecutableRunOptions();
  run_options.set_device_ordinal(out_of_range_ordinal);

  auto outcome = ExecuteLocally(computation, {}, build_options, run_options);
  EXPECT_FALSE(outcome.ok());
  EXPECT_THAT(outcome.status().error_message(),
              ContainsRegex("Invalid device ordinal value"));
}
648 
XLA_TEST_F(LocalClientExecuteTest, RunOnStream) {
  // Run a computation on an explicitly created stream on each supported
  // device on the system.
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);
  auto computation = builder.Build().ConsumeValueOrDie();

  for (int d = 0; d < local_client_->device_count(); ++d) {
    if (!local_client_->device_ordinal_supported(d)) {
      continue;
    }
    se::StreamExecutor* executor =
        local_client_->platform()->ExecutorForDevice(d).ValueOrDie();
    se::Stream stream(executor);
    stream.Init();

    auto result =
        ExecuteLocallyOrDie(computation, {}, DefaultExecutableBuildOptions(),
                            DefaultExecutableRunOptions().set_stream(&stream));
    // Verify that the computation ran on the device associated with the
    // stream. This is a weak check, but stronger verification is hard.
    EXPECT_EQ(d, result.device_ordinal());
    LiteralTestUtil::ExpectR0Equal<float>(42.0f, ShapedBufferToLiteral(result));
  }
}
673 
674 // Disable this test on CPU because we're using the CPU as the platform
675 // which does not match the service platform.
XLA_TEST_F(LocalClientExecuteTest,DISABLED_ON_CPU (RunOnStreamForWrongPlatform))676 XLA_TEST_F(LocalClientExecuteTest,
677            DISABLED_ON_CPU(RunOnStreamForWrongPlatform)) {
678   // Try to run a computation on a stream for a platform (CPU) which does not
679   // match the platform of the service (!= CPU).
680   se::Platform* wrong_platform =
681       se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)
682           .ValueOrDie();
683   se::Stream wrong_stream(wrong_platform->ExecutorForDevice(0).ValueOrDie());
684   wrong_stream.Init();
685 
686   XlaBuilder builder(TestName());
687   ConstantR0<float>(&builder, 42.0f);
688   auto execute_status = ExecuteLocally(
689       builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(),
690       DefaultExecutableRunOptions().set_stream(&wrong_stream));
691   EXPECT_FALSE(execute_status.ok());
692   EXPECT_THAT(execute_status.status().error_message(),
693               ContainsRegex("stream is for platform .*, but service targets"));
694 }
695 
// An allocator created for the host (CPU) platform must be rejected by a
// service targeting a different platform. Disabled on CPU, where the host
// platform and the service platform would coincide.
XLA_TEST_F(LocalClientExecuteTest,
           DISABLED_ON_CPU(AllocatorDoesNotMatchPlatform)) {
  se::Platform* host_platform =
      se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)
          .ValueOrDie();
  TestAllocator host_allocator(host_platform);

  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 123.0f);

  ExecutableRunOptions run_options = DefaultExecutableRunOptions();
  run_options.set_allocator(&host_allocator);
  auto execute_status =
      ExecuteLocally(builder.Build().ValueOrDie(), {},
                     DefaultExecutableBuildOptions(), run_options);
  EXPECT_FALSE(execute_status.ok());
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("allocator platform .* does not match service"));
}
713 
// Handing the service a stream whose Init() was never called must produce an
// error rather than an execution.
XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) {
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);

  const int default_ordinal = local_client_->default_device_ordinal();
  LOG(INFO) << "default device = " << default_ordinal;
  se::StreamExecutor* executor =
      local_client_->platform()->ExecutorForDevice(default_ordinal).ValueOrDie();
  se::Stream never_initialized(executor);
  // Intentionally no never_initialized.Init() call.

  ExecutableRunOptions run_options = DefaultExecutableRunOptions();
  run_options.set_stream(&never_initialized);
  auto execute_status =
      ExecuteLocally(builder.Build().ValueOrDie(), {},
                     DefaultExecutableBuildOptions(), run_options);
  EXPECT_FALSE(execute_status.ok());
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("stream is uninitialized or in an error state"));
}
734 
// Select with a scalar `false` predicate over two tuple operands must yield
// the second tuple, element for element.
XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(SelectBetweenTuples)) {
  XlaBuilder builder(TestName());

  std::initializer_list<float> ascending = {1.f, 2.f, 3.f};
  std::initializer_list<float> doubled = {2.f, 4.f, 6.f};

  // Each tuple gets its own pair of constant operands.
  auto first_pair = Tuple(&builder, {ConstantR1<float>(&builder, ascending),
                                     ConstantR1<float>(&builder, doubled)});
  auto second_pair = Tuple(&builder, {ConstantR1<float>(&builder, doubled),
                                      ConstantR1<float>(&builder, ascending)});
  auto pick_first = ConstantR0<bool>(&builder, false);
  Select(pick_first, first_pair, second_pair);

  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
  Literal selected = ShapedBufferToLiteral(result);
  LiteralTestUtil::ExpectR1Equal<float>({2.0f, 4.0f, 6.0f},
                                        LiteralSlice(selected, {0}));
  LiteralTestUtil::ExpectR1Equal<float>({1.0f, 2.0f, 3.0f},
                                        LiteralSlice(selected, {1}));
}
754 
// Compiles x + [2, 3, 4] ahead of time with LocalClient::Compile, runs the
// resulting executable directly on a device buffer and checks the sum.
XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  Add(x, y);

  Shape argument_layout =
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0});
  TF_ASSERT_OK_AND_ASSIGN(
      auto executables,
      local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
                             ExecutableBuildOptions()));
  // ASSERT rather than EXPECT: executables[0] below would be out of bounds if
  // Compile returned an empty vector.
  ASSERT_EQ(1, executables.size());

  auto x_array =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
  ScopedShapedBuffer result =
      executables[0]
          ->Run({&x_array}, DefaultExecutableRunOptions())
          .ConsumeValueOrDie();
  // Borrow from the default device rather than a hard-coded ordinal 0, which
  // need not be a supported device on every configuration.
  ASSERT_IS_OK(local_client_->mutable_backend()
                   ->BorrowStream(local_client_->default_device_ordinal())
                   .ValueOrDie()
                   ->BlockHostUntilDone());

  LiteralTestUtil::ExpectR1Near<float>(
      {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
783 
// Compiling with num_partitions = 2 must yield two executables. Part of the
// graph is pinned to device 1, so the test needs at least two devices.
XLA_TEST_F(LocalClientExecuteTest, CompilePartitionedExecutable) {
  if (local_client_->device_count() < 2) {
    // GTEST_SKIP() is gtest's public skip macro; GTEST_SKIP_ is an internal
    // implementation detail.
    GTEST_SKIP() << "requires two devices";
  }

  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  auto z = ConstantR1<float>(&builder, {5.0f, 6.0f, 7.0f});
  auto r = Add(x, y);
  // Ops built while this sharding is set are assigned to device 1.
  builder.SetSharding(sharding_builder::AssignDevice(1));
  Add(r, z);
  builder.ClearSharding();

  Shape argument_layout =
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0});
  ExecutableBuildOptions build_options;
  build_options.set_num_partitions(2);
  TF_ASSERT_OK_AND_ASSIGN(
      auto executables,
      local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
                             build_options));
  EXPECT_EQ(2, executables.size());
}
808 
// A computation that embeds a 100000-element f32 constant must report a
// generated-code size of at least that many bytes. Disabled on the
// interpreter backend.
XLA_TEST_F(LocalClientExecuteTest,
           DISABLED_ON_INTERPRETER(SizeOfGeneratedCodeInBytes)) {
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
  constexpr int size = 100000;
  TF_ASSERT_OK_AND_ASSIGN(auto literal,
                          LiteralUtil::CreateRandomLiteral<F32>(
                              ShapeUtil::MakeShape(F32, {size}), 0.0, 1.0));
  auto y = ConstantLiteral(&builder, literal);
  Add(x, y);

  Shape argument_layout =
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{}, {});
  TF_ASSERT_OK_AND_ASSIGN(
      auto executables,
      local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
                             ExecutableBuildOptions()));
  // ASSERT rather than EXPECT: executables.front() below is undefined
  // behaviour on an empty vector.
  ASSERT_EQ(1, executables.size());
  // The executable should be at least as large as the constant it contains.
  EXPECT_GT(executables.front()->executable()->SizeOfGeneratedCodeInBytes(),
            int64{sizeof(float) * size});
}
831 
// Copies a variety of literals to the device as ShapedBuffers and back again,
// expecting every one of them to survive the round trip unchanged.
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
  auto expect_round_trip = [this](const Literal& original) {
    const int ordinal = local_client_->default_device_ordinal();
    TF_ASSERT_OK_AND_ASSIGN(
        auto device_buffer,
        local_client_->LiteralToShapedBuffer(original, ordinal, allocator_));
    TF_ASSERT_OK_AND_ASSIGN(auto copied_back,
                            local_client_->ShapedBufferToLiteral(device_buffer));
    EXPECT_EQ(original, copied_back);
  };

  // Arrays of several ranks and element types.
  expect_round_trip(LiteralUtil::CreateR0<float>(42.0));
  expect_round_trip(LiteralUtil::CreateR0<bool>(true));
  expect_round_trip(LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
  expect_round_trip(
      LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
  expect_round_trip(LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}}));

  // The empty tuple.
  expect_round_trip(LiteralUtil::MakeTuple({}));

  // Flat tuples with one and two elements.
  expect_round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(12223.0)}));
  expect_round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>({1.0, -42.0}),
       LiteralUtil::CreateR0<float>(123456.0)}));

  // A tuple whose first element is itself a tuple.
  expect_round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR1<float>({1.0, -42.0}),
            LiteralUtil::CreateR0<float>(123456.0)}),
       LiteralUtil::CreateR0<bool>(false)}));
}
871 
// Same device round trip as the 32-bit variant, but for literals whose
// elements are 64 bits wide (double, int64, uint64).
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
  auto expect_round_trip = [this](const Literal& original) {
    const int ordinal = local_client_->default_device_ordinal();
    TF_ASSERT_OK_AND_ASSIGN(
        auto device_buffer,
        local_client_->LiteralToShapedBuffer(original, ordinal, allocator_));
    TF_ASSERT_OK_AND_ASSIGN(auto copied_back,
                            local_client_->ShapedBufferToLiteral(device_buffer));
    EXPECT_EQ(original, copied_back);
  };

  expect_round_trip(LiteralUtil::CreateR2<double>(
      {{1.0, 2.0, 3.0}, {44.0, 0.099999999999999978, -3}}));
  expect_round_trip(LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}}));
  // 20000000000 does not fit in 32 bits.
  expect_round_trip(
      LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
  expect_round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<double>({1.0, -42.0}),
       LiteralUtil::CreateR0<int64>(123456789000LL)}));
}
895 
896 // Disabled on interpreter backend since infeed HLO is unsupported.
XLA_TEST_F(LocalClientExecuteTest,DISABLED_ON_INTERPRETER (InfeedTest))897 XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedTest)) {
898   XlaBuilder builder(TestName());
899   const Shape shape = ShapeUtil::MakeShape(F32, {3});
900   auto in = Infeed(&builder, shape);
901   auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
902   Add(in, constant);
903 
904   Literal result;
905   std::unique_ptr<tensorflow::Thread> thread(
906       tensorflow::Env::Default()->StartThread(
907           tensorflow::ThreadOptions(), "execute_thread", [&] {
908             result = ShapedBufferToLiteral(ExecuteLocallyOrDie(
909                 builder.Build().ValueOrDie(), /*arguments=*/{}));
910           }));
911 
912   ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
913       LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
914       local_client_->default_device_ordinal()));
915 
916   // Join the thread.
917   thread.reset();
918 
919   LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
920 }
921 
922 // Disabled on interpreter backend since infeed/outfeed HLOs are unsupported.
XLA_TEST_F(LocalClientExecuteTest,DISABLED_ON_INTERPRETER (InfeedOutfeedTest))923 XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) {
924   XlaBuilder builder(TestName());
925   const Shape shape = ShapeUtil::MakeShape(F32, {3});
926   auto in = Infeed(&builder, shape);
927   auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
928   auto sum = Add(in, constant);
929   Outfeed(sum, shape, /*outfeed_config=*/"");
930 
931   std::unique_ptr<tensorflow::Thread> thread(
932       tensorflow::Env::Default()->StartThread(
933           tensorflow::ThreadOptions(), "execute_thread",
934           [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); }));
935 
936   ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
937       LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
938       local_client_->default_device_ordinal()));
939 
940   Literal result(shape);
941   ASSERT_IS_OK(local_client_->TransferFromOutfeedLocal(
942       local_client_->default_device_ordinal(), &result));
943 
944   LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
945 }
946 
947 // Benchmark that measures the overhead of the LocalClient API when running a
948 // trivial computation
BM_LocalClientOverhead(::testing::benchmark::State & state)949 void BM_LocalClientOverhead(::testing::benchmark::State& state) {
950   se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
951   auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
952   se::StreamExecutorMemoryAllocator allocator(platform, executors);
953   LocalClient* client =
954       ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
955   auto* transfer_manager =
956       TransferManager::GetForPlatform(platform).ValueOrDie();
957   int device_ordinal = client->default_device_ordinal();
958 
959   // Use a tiny add operation as the computation.
960   XlaBuilder builder("Add");
961   auto shape = ShapeUtil::MakeShape(F32, {2, 3});
962   auto x = Parameter(&builder, 0, shape, "x");
963   Add(x, x);
964   auto computation = builder.Build().ConsumeValueOrDie();
965 
966   auto buffer =
967       transfer_manager
968           ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0)
969           .ConsumeValueOrDie();
970   auto literal = LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
971   auto stream =
972       client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
973   ASSERT_IS_OK(
974       transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer));
975 
976   const int kWarmups = 2;
977 
978   TF_ASSERT_OK_AND_ASSIGN(
979       auto executables, client->Compile(computation, {&buffer.on_host_shape()},
980                                         ExecutableBuildOptions()));
981   std::unique_ptr<LocalExecutable> executable = std::move(executables[0]);
982 
983   ExecutableRunOptions run_options;
984   run_options.set_allocator(&allocator).set_stream(stream.get());
985 
986   for (int i = 0; i < kWarmups; ++i) {
987     auto result = executable->Run({&buffer}, run_options);
988     ASSERT_IS_OK(result);
989   }
990 
991   for (auto s : state) {
992     auto result = executable->Run({&buffer}, run_options);
993     ASSERT_IS_OK(result);
994   }
995 }
996 
997 BENCHMARK(BM_LocalClientOverhead);
998 
999 }  // namespace
1000 }  // namespace xla
1001