• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include <initializer_list>
17 #include <memory>
18 #include <vector>
19 
20 #include "tensorflow/compiler/xla/client/client_library.h"
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/layout_util.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
26 #include "tensorflow/compiler/xla/service/local_service.h"
27 #include "tensorflow/compiler/xla/service/platform_util.h"
28 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
29 #include "tensorflow/compiler/xla/service/transfer_manager.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/test.h"
33 #include "tensorflow/compiler/xla/test_helpers.h"
34 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
35 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
36 #include "tensorflow/compiler/xla/tests/test_macros.h"
37 #include "tensorflow/compiler/xla/tests/test_utils.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39 #include "tensorflow/core/platform/env.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
42 #include "tensorflow/core/platform/test.h"
43 #include "tensorflow/core/platform/test_benchmark.h"
44 
45 namespace xla {
46 namespace {
47 
48 using ::testing::ContainsRegex;
49 
50 class LocalClientExecuteTest : public LocalClientTestBase {
51  protected:
52   ErrorSpec error_spec_{0.0001};
53 };
54 
// Executes a computation consisting of a single scalar constant, with no
// arguments, and checks that the constant value is returned.
XLA_TEST_F(LocalClientExecuteTest, Constant) {
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 123.0f);

  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
  LiteralTestUtil::ExpectR0Near<float>(123.f, ShapedBufferToLiteral(result),
                                       error_spec_);
}
64 
// Adds a scalar parameter (42) to a scalar constant (123) and expects 165.
XLA_TEST_F(LocalClientExecuteTest, AddScalars) {
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
  auto y = ConstantR0<float>(&builder, 123.0f);
  Add(x, y);

  auto x_value = LiteralToShapedBuffer(LiteralUtil::CreateR0<float>(42.0f));
  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value});
  LiteralTestUtil::ExpectR0Near<float>(165.f, ShapedBufferToLiteral(result),
                                       error_spec_);
}
77 
// Adds two zero-element vectors (one parameter, one constant) and expects an
// empty vector back.
XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) {
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "x");
  auto y = ConstantR1<float>(&builder, {});
  Add(x, y);

  auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({}));
  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});
  LiteralTestUtil::ExpectR1Near<float>({}, ShapedBufferToLiteral(result),
                                       error_spec_);
}
90 
// Adds a three-element vector parameter to a three-element vector constant and
// checks the element-wise sum.
XLA_TEST_F(LocalClientExecuteTest, AddVectors) {
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  Add(x, y);

  auto x_array =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});
  LiteralTestUtil::ExpectR1Near<float>(
      {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
104 
// Same computation as a plain vector add, but runs with an ExecutionProfile
// attached to the run options and checks that a non-zero
// compute-and-transfer time was recorded in it.
XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  Add(x, y);

  auto x_array =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
  ExecutionProfile profile;
  ScopedShapedBuffer result = ExecuteLocallyOrDie(
      builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(),
      DefaultExecutableRunOptions().set_execution_profile(&profile));

  LiteralTestUtil::ExpectR1Near<float>(
      {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
  EXPECT_GT(profile.compute_and_transfer_time_ns(), 0);
}
122 
// Adds two 2x2 arrays whose device buffers use different layouts (x is
// col-major, y is row-major) and checks that the sum is the same regardless of
// which buffer is bound to which parameter.
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  Add(x, y);
  auto computation = builder.Build().ConsumeValueOrDie();

  // Create x as a col-major array.
  auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})));
  EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(),
                                LayoutUtil::MakeLayout({0, 1})));

  // Create y as a row-major array.
  auto y_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
      {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0})));
  EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(),
                                LayoutUtil::MakeLayout({1, 0})));

  // Run with the parameters in declaration order. No result layout is
  // requested here, so the result name makes no claim about its layout.
  ScopedShapedBuffer result_param_order =
      ExecuteLocallyOrDie(computation, {&x_array, &y_array});
  LiteralTestUtil::ExpectR2Near<float>(
      {{11.0f, 22.0f}, {33.0f, 44.0f}},
      ShapedBufferToLiteral(result_param_order), error_spec_);

  // Run with the parameter values in a different order.
  ScopedShapedBuffer result_param_swap =
      ExecuteLocallyOrDie(computation, {&y_array, &x_array});
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       ShapedBufferToLiteral(result_param_swap),
                                       error_spec_);
}
155 
// Adds two 2x2 arrays twice, requesting a col-major result layout and then a
// row-major one through the build options. Each run checks both the layout of
// the returned buffer and the values.
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  Add(x, y);
  auto computation = builder.Build().ConsumeValueOrDie();

  auto x_array = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto y_array = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  // Run with col-major result layout.
  ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie(
      computation, {&x_array, &y_array},
      DefaultExecutableBuildOptions().set_result_layout(
          ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {0, 1})),
      DefaultExecutableRunOptions());
  EXPECT_TRUE(LayoutUtil::Equal(result_colmaj.on_device_shape().layout(),
                                LayoutUtil::MakeLayout({0, 1})));
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       ShapedBufferToLiteral(result_colmaj),
                                       error_spec_);

  // Run with row-major result layout.
  ScopedShapedBuffer result_rowmaj = ExecuteLocallyOrDie(
      computation, {&x_array, &y_array},
      DefaultExecutableBuildOptions().set_result_layout(
          ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {1, 0})),
      DefaultExecutableRunOptions());
  EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj.on_device_shape().layout(),
                                LayoutUtil::MakeLayout({1, 0})));
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       ShapedBufferToLiteral(result_rowmaj),
                                       error_spec_);
}
192 
// Returns the tuple (x, y, x) built from two 2x2 parameters and checks the
// tuple arity and each element of the result.
XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  Tuple(&builder, {x, y, x});
  auto computation = builder.Build().ConsumeValueOrDie();

  auto x_array = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto y_array = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(computation, {&x_array, &y_array});

  EXPECT_TRUE(result.on_host_shape().IsTuple());
  EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape()));

  Literal result_literal = ShapedBufferToLiteral(result);
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
                                        LiteralSlice(result_literal, {1}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {2}));
}
219 
// Returns the nested tuple ((x, y, x), x) and checks the outer arity plus
// every leaf, addressing inner leaves by two-component shape indices.
XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  auto inner_tuple = Tuple(&builder, {x, y, x});
  Tuple(&builder, {inner_tuple, x});
  auto computation = builder.Build().ConsumeValueOrDie();

  auto x_array = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto y_array = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(computation, {&x_array, &y_array});

  EXPECT_TRUE(result.on_host_shape().IsTuple());
  EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));

  Literal result_literal = ShapedBufferToLiteral(result);
  // Outer element 1 is x itself.
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {1}));
  // Outer element 0 is the inner tuple (x, y, x).
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {0, 0}));
  LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
                                        LiteralSlice(result_literal, {0, 1}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {0, 2}));
}
249 
// Executes a tuple-returning computation with a per-element result layout
// (element 0 col-major, element 1 row-major) set in the build options.
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
  // Verify setting the result layout of a computation with a tuple output.
  // NOTE(review): only the element values are asserted below; the layouts of
  // the returned buffers are not checked against the requested ones.
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  Tuple(&builder, {x, y});

  // The same buffer is passed for both parameters.
  auto array = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));

  ExecutableBuildOptions options = DefaultExecutableBuildOptions();
  Shape shape_with_layout = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2},
                                      /*minor_to_major=*/{0, 1}),
       ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2},
                                      /*minor_to_major=*/{1, 0})});
  options.set_result_layout(shape_with_layout);
  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array},
                          options, DefaultExecutableRunOptions());

  Literal result_literal = ShapedBufferToLiteral(result);
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {1}));
}
277 
// Passes two tuple arguments, (array, vector) and (vector, array), and checks
// a tuple result holding the array sum and the vector difference.
XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
  const Shape array_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {3});

  const Shape tuple_shape0 =
      ShapeUtil::MakeTupleShape({array_shape, vector_shape});
  const Shape tuple_shape1 =
      ShapeUtil::MakeTupleShape({vector_shape, array_shape});

  // Computation adds the array elements and subtracts the vector elements of
  // the two tuple arguments, and returns the results as a tuple.
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, tuple_shape0, "x");
  auto y = Parameter(&builder, 1, tuple_shape1, "y");
  auto x_0 = GetTupleElement(x, 0);
  auto x_1 = GetTupleElement(x, 1);
  auto y_0 = GetTupleElement(y, 0);
  auto y_1 = GetTupleElement(y, 1);
  auto array_sum = Add(x_0, y_1);
  auto vector_diff = Sub(x_1, y_0);
  Tuple(&builder, {array_sum, vector_diff});
  auto computation = builder.Build().ConsumeValueOrDie();

  auto x_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
       LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})});
  auto y_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}),
       LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}})});

  auto x_buffer = LiteralToShapedBuffer(x_literal);
  auto y_buffer = LiteralToShapedBuffer(y_literal);

  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer});

  EXPECT_TRUE(result.on_host_shape().IsTuple());
  EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));

  Literal result_literal = ShapedBufferToLiteral(result);
  LiteralTestUtil::ExpectR2Equal<float>({{56.0f, 46.0f}, {36.0f, 26.0f}},
                                        LiteralSlice(result_literal, {0}));
  LiteralTestUtil::ExpectR1Equal<float>({40.0f, 71.0f, 117.0f},
                                        LiteralSlice(result_literal, {1}));
}
323 
// Passes a single nested tuple argument ((array, vector), vector) and checks
// a flat tuple result built from its leaves.
XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
  const Shape array_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {3});

  const Shape inner_tuple_shape =
      ShapeUtil::MakeTupleShape({array_shape, vector_shape});
  const Shape nested_tuple_shape =
      ShapeUtil::MakeTupleShape({inner_tuple_shape, vector_shape});

  // Computation negates the array element and sums the two vector elements in
  // the nested tuple. The resulting array and vector are returned as a tuple.
  XlaBuilder builder(TestName());
  auto param = Parameter(&builder, 0, nested_tuple_shape, "param");
  auto inner_tuple = GetTupleElement(param, 0);
  auto inner_array = GetTupleElement(inner_tuple, 0);
  auto inner_vector = GetTupleElement(inner_tuple, 1);
  auto outer_vector = GetTupleElement(param, 1);

  auto negate_array = Neg(inner_array);
  auto vector_sum = Add(inner_vector, outer_vector);
  Tuple(&builder, {negate_array, vector_sum});
  auto computation = builder.Build().ConsumeValueOrDie();

  auto arg_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
            LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})}),
       LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0})});
  auto arg_buffer = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});

  Literal result_literal = ShapedBufferToLiteral(result);
  LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4}},
                                        LiteralSlice(result_literal, {0}));
  LiteralTestUtil::ExpectR1Equal<float>({264.0, 73.0, 133.0},
                                        LiteralSlice(result_literal, {1}));
}
362 
// Runs a tuple-in/tuple-out computation twice, feeding the first result
// buffer directly back in as the argument of the second run.
XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
  // Construct a computation which takes and returns the same shape (a
  // tuple). Feed the result of the computation back into the input. This
  // provides additional verification that the returned tuple is properly
  // constructed.
  const Shape array_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({array_shape, array_shape});

  // Element 0 is negated and element 1 is doubled on every run.
  XlaBuilder builder(TestName());
  auto param = Parameter(&builder, 0, tuple_shape, "param");
  auto element_0 = GetTupleElement(param, 0);
  auto element_1 = GetTupleElement(param, 1);
  Tuple(&builder, {Neg(element_0), Add(element_1, element_1)});
  auto computation = builder.Build().ConsumeValueOrDie();

  auto arg_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
       LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}})});
  auto arg_buffer = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer});
  Literal result_0_literal = ShapedBufferToLiteral(result_0);
  LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4.0}},
                                        LiteralSlice(result_0_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{22.0, 6.0}, {8.0, 10}},
                                        LiteralSlice(result_0_literal, {1}));

  // Second run: negating twice restores element 0; element 1 is doubled again.
  ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0});
  Literal result_1_literal = ShapedBufferToLiteral(result_1);
  LiteralTestUtil::ExpectR2Equal<float>({{1.0, 2.0}, {3.0, 4.0}},
                                        LiteralSlice(result_1_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{44.0, 12.0}, {16.0, 20}},
                                        LiteralSlice(result_1_literal, {1}));
}
398 
// Executes a computation whose parameter and result are both flat tuples of
// kElementCount two-element vectors.
XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
  // Construct a computation which takes a tuple parameter with a very large
  // number of elements.

  // A larger number of elements would make for a better, more strenuous test,
  // but:
  // TODO(b/66959878): On cpu a large number of elements results in long
  //   compilation time.
  // TODO(b/66954197): On gpu a large number of elements OOMs.
  const int kElementCount = 100;

  // Each element is a 2-element vector.
  const Shape element_shape = ShapeUtil::MakeShape(F32, {2});
  std::vector<Shape> element_shapes(kElementCount, element_shape);
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);

  XlaBuilder builder(TestName());
  auto param = Parameter(&builder, 0, tuple_shape, "param");

  // Add each element's tuple index value to every element.
  std::vector<XlaOp> result_elements;
  for (int i = 0; i < kElementCount; ++i) {
    auto element = GetTupleElement(param, i);
    result_elements.push_back(Add(element, ConstantR0<float>(&builder, i)));
  }
  Tuple(&builder, result_elements);
  auto computation = builder.Build().ConsumeValueOrDie();

  // Feed in a tuple where each two-element vector element is {tuple_index,
  // -tuple_index}.
  std::vector<Literal> arg_elements;
  for (int i = 0; i < kElementCount; ++i) {
    arg_elements.push_back(LiteralUtil::CreateR1<float>({1.0f * i, -1.0f * i}));
  }
  Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_elements));
  auto arg_buffer = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
  Literal result_literal = ShapedBufferToLiteral(result);

  // {i, -i} + i == {2i, 0} for every tuple index i.
  for (int i = 0; i < kElementCount; ++i) {
    LiteralTestUtil::ExpectR1Near<float>(
        {2.0f * i, 0.0f}, LiteralSlice(result_literal, {i}), error_spec_);
  }
}
444 
// Executes a computation whose parameter and result are two-level tuples with
// kFanout children at each level and scalar leaves.
XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
  // Construct and run a computation which takes a two-level nested tuple
  // parameter with a large fanout.
  const int kFanout = 40;

  // Tuple shape is full two-level tree with the given fanout.
  const Shape element_shape = ShapeUtil::MakeShape(F32, {});
  std::vector<Shape> element_shapes(kFanout, element_shape);
  const Shape inner_tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);
  std::vector<Shape> inner_tuple_shapes(kFanout, inner_tuple_shape);
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(inner_tuple_shapes);

  XlaBuilder builder(TestName());
  auto param = Parameter(&builder, 0, tuple_shape, "param");

  // The computation increments each leaf value by an amount equal to the leaf's
  // ordinal position in a traversal of the tuple (i * kFanout + j).
  std::vector<XlaOp> result_elements;
  for (int i = 0; i < kFanout; ++i) {
    auto outer_element = GetTupleElement(param, i);
    std::vector<XlaOp> inner_result_elements;
    for (int j = 0; j < kFanout; ++j) {
      auto inner_element = GetTupleElement(outer_element, j);
      inner_result_elements.push_back(
          Add(inner_element, ConstantR0<float>(&builder, i * kFanout + j)));
    }
    result_elements.push_back(Tuple(&builder, inner_result_elements));
  }
  Tuple(&builder, result_elements);
  auto computation = builder.Build().ConsumeValueOrDie();

  // Construct the argument to pass to the computation: leaf (i, j) holds i + j.
  std::vector<Literal> outer_tuple_elements;
  for (int i = 0; i < kFanout; ++i) {
    std::vector<Literal> inner_tuple_elements;
    for (int j = 0; j < kFanout; ++j) {
      inner_tuple_elements.push_back(LiteralUtil::CreateR0<float>(i + j));
    }
    outer_tuple_elements.push_back(
        LiteralUtil::MakeTupleOwned(std::move(inner_tuple_elements)));
  }
  auto arg_literal =
      LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements));
  auto arg_buffer = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
  Literal result_literal = ShapedBufferToLiteral(result);

  // Expected leaf value: input (i + j) plus the increment (i * kFanout + j).
  for (int i = 0; i < kFanout; ++i) {
    for (int j = 0; j < kFanout; ++j) {
      LiteralTestUtil::ExpectR0Near<float>(i + j + i * kFanout + j,
                                           LiteralSlice(result_literal, {i, j}),
                                           error_spec_);
    }
  }
}
501 
// Executes a computation whose parameter and result are a scalar wrapped in
// kTupleDepth levels of single-element tuples.
XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
  // Construct and run a computation which takes a very deep tuple. The tuple
  // has no fan out and a single scalar element at the bottom.
  const int kTupleDepth = 100;

  // Tuple shape is a scalar nested inside kTupleDepth single-element tuples.
  Shape shape = ShapeUtil::MakeShape(F32, {});
  for (int i = 0; i < kTupleDepth; ++i) {
    shape = ShapeUtil::MakeTupleShape({shape});
  }

  // Unwrap the parameter down to the scalar at the bottom.
  XlaBuilder builder(TestName());
  auto element = Parameter(&builder, 0, shape, "param");
  for (int i = 0; i < kTupleDepth; ++i) {
    element = GetTupleElement(element, 0);
  }

  // Add 42 to the scalar and wrap the sum back up to the same depth.
  auto output = Add(element, ConstantR0<float>(&builder, 42.0));
  for (int i = 0; i < kTupleDepth; ++i) {
    output = Tuple(&builder, {output});
  }
  auto computation = builder.Build().ConsumeValueOrDie();

  // Construct the argument to pass to the computation.
  Literal arg_literal = LiteralUtil::CreateR0<float>(123.0);
  for (int i = 0; i < kTupleDepth; ++i) {
    std::vector<Literal> arg_vector;
    arg_vector.push_back(std::move(arg_literal));
    arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector));
  }
  auto arg_buffer = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
  Literal result_literal = ShapedBufferToLiteral(result);

  // The scalar sits at shape index {0, 0, ..., 0} with kTupleDepth zeros.
  ShapeIndex index;
  for (int i = 0; i < kTupleDepth; ++i) {
    index.push_back(0);
  }
  LiteralTestUtil::ExpectR0Equal<float>(165.0,
                                        LiteralSlice(result_literal, index));
}
544 
// Executing a two-parameter computation with only one argument must fail with
// an "Invalid number of arguments" error.
XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
  // Test passing in an invalid number of arguments.
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {3}), "y");
  Add(x, y);

  auto x_array =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
  auto execute_status =
      ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});

  EXPECT_FALSE(execute_status.ok());
  // Print the full status on a mismatch to aid debugging.
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("Invalid number of arguments"))
      << execute_status.status();
}
561 
// Executing a computation that expects an F32[3] parameter with a 2x2 argument
// must fail with an "Invalid argument shape" error.
XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) {
  // Test passing in an argument with the wrong shape.
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  Neg(x);

  auto x_array = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
  auto execute_status =
      ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});

  EXPECT_FALSE(execute_status.ok());
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("Invalid argument shape"))
      << execute_status.status();
}
578 
// Requesting a rank-4 result layout for a computation whose result is 2x2
// must fail with a "not compatible with result shape" error.
XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) {
  // Test passing in an invalid result layout parameter.
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  Neg(x);

  auto x_array = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
  auto execute_status = ExecuteLocally(
      builder.Build().ValueOrDie(), {&x_array},
      DefaultExecutableBuildOptions().set_result_layout(
          ShapeUtil::MakeShapeWithLayout(F32,
                                         /*dimensions=*/{1, 2, 3, 4},
                                         /*minor_to_major=*/{0, 1, 2, 3})),
      DefaultExecutableRunOptions());

  EXPECT_FALSE(execute_status.ok());
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("not compatible with result shape"))
      << execute_status.status();
}
600 
// Runs a constant computation on every device ordinal the local client
// reports, expecting success on supported ordinals and a "not supported"
// error on the rest.
XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) {
  // Try to run a trivial computation on every device on the system. If a
  // specific device is not supported, check that the right error is returned.
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);
  auto computation = builder.Build().ConsumeValueOrDie();
  for (int d = 0; d < local_client_->device_count(); ++d) {
    if (!local_client_->device_ordinal_supported(d)) {
      auto execute_status =
          ExecuteLocally(computation, {},
                         DefaultExecutableBuildOptions().set_device_ordinal(d),
                         DefaultExecutableRunOptions().set_device_ordinal(d));
      EXPECT_FALSE(execute_status.ok());
      // Print the full status on a mismatch to aid debugging.
      EXPECT_THAT(execute_status.status().error_message(),
                  ContainsRegex("device .* not supported"))
          << execute_status.status();
    } else {
      auto result = ExecuteLocallyOrDie(
          computation, {},
          DefaultExecutableBuildOptions().set_device_ordinal(d),
          DefaultExecutableRunOptions().set_device_ordinal(d));
      EXPECT_EQ(d, result.device_ordinal());
      LiteralTestUtil::ExpectR0Equal<float>(42.0f,
                                            ShapedBufferToLiteral(result));
    }
  }
}
627 
// Using device_count() as the device ordinal (one past the last valid
// ordinal) must fail with an "Invalid device ordinal value" error.
XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) {
  // Try running computations on devices with device ordinal values which do not
  // exist.
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);
  auto computation = builder.Build().ConsumeValueOrDie();

  auto execute_status =
      ExecuteLocally(computation, {},
                     DefaultExecutableBuildOptions().set_device_ordinal(
                         local_client_->device_count()),
                     DefaultExecutableRunOptions().set_device_ordinal(
                         local_client_->device_count()));
  EXPECT_FALSE(execute_status.ok());
  // Print the full status on a mismatch to aid debugging.
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("Invalid device ordinal value"))
      << execute_status.status();
}
645 
// Runs a constant computation on an explicitly created stream for each
// supported device and checks the result's device ordinal and value.
XLA_TEST_F(LocalClientExecuteTest, RunOnStream) {
  // Run a computation on a specific stream on each device on the system.
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);
  auto computation = builder.Build().ConsumeValueOrDie();

  for (int d = 0; d < local_client_->device_count(); ++d) {
    if (!local_client_->device_ordinal_supported(d)) {
      continue;
    }
    se::StreamExecutor* executor =
        local_client_->platform()->ExecutorForDevice(d).ValueOrDie();
    se::Stream stream(executor);
    stream.Init();

    auto result =
        ExecuteLocallyOrDie(computation, {}, DefaultExecutableBuildOptions(),
                            DefaultExecutableRunOptions().set_stream(&stream));
    // As a check to verify that the computation ran on the device associated
    // with the stream. This is a weak check, but stronger verification is hard.
    EXPECT_EQ(d, result.device_ordinal());
    LiteralTestUtil::ExpectR0Equal<float>(42.0f, ShapedBufferToLiteral(result));
  }
}
670 
671 // Disable this test on CPU because we're using the CPU as the platform
672 // which does not match the service platform.
XLA_TEST_F(LocalClientExecuteTest,DISABLED_ON_CPU (RunOnStreamForWrongPlatform))673 XLA_TEST_F(LocalClientExecuteTest,
674            DISABLED_ON_CPU(RunOnStreamForWrongPlatform)) {
675   // Try to run a computation on a stream for a platform (CPU) which does not
676   // match the platform of the service (!= CPU).
677   se::Platform* wrong_platform =
678       se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)
679           .ValueOrDie();
680   se::Stream wrong_stream(wrong_platform->ExecutorForDevice(0).ValueOrDie());
681   wrong_stream.Init();
682 
683   XlaBuilder builder(TestName());
684   ConstantR0<float>(&builder, 42.0f);
685   auto execute_status = ExecuteLocally(
686       builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(),
687       DefaultExecutableRunOptions().set_stream(&wrong_stream));
688   EXPECT_FALSE(execute_status.ok());
689   EXPECT_THAT(execute_status.status().error_message(),
690               ContainsRegex("stream is for platform .*, but service targets"));
691 }
692 
XLA_TEST_F(LocalClientExecuteTest,
           DISABLED_ON_CPU(AllocatorDoesNotMatchPlatform)) {
  // Hand the run options an allocator constructed for the host platform and
  // expect execution to be refused with an allocator/service mismatch error.
  se::Platform* host_platform =
      se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)
          .ValueOrDie();
  TestAllocator mismatched_allocator(host_platform);

  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 123.0f);
  auto computation = builder.Build().ConsumeValueOrDie();

  auto run_options = DefaultExecutableRunOptions();
  run_options.set_allocator(&mismatched_allocator);
  auto execute_status = ExecuteLocally(
      computation, {}, DefaultExecutableBuildOptions(), run_options);

  EXPECT_FALSE(execute_status.ok());
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("allocator platform .* does not match service"));
}
710 
XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) {
  // A stream that was constructed but never initialized must be rejected when
  // it is passed in the run options.
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);

  LOG(INFO) << "default device = " << local_client_->default_device_ordinal();
  const auto default_ordinal = local_client_->default_device_ordinal();
  se::StreamExecutor* default_executor =
      local_client_->platform()->ExecutorForDevice(default_ordinal).ValueOrDie();
  // Deliberately no Init() call on this stream.
  se::Stream uninitialized_stream(default_executor);

  auto computation = builder.Build().ConsumeValueOrDie();
  auto run_options = DefaultExecutableRunOptions();
  run_options.set_stream(&uninitialized_stream);
  auto execute_status = ExecuteLocally(
      computation, {}, DefaultExecutableBuildOptions(), run_options);

  EXPECT_FALSE(execute_status.ok());
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("stream is uninitialized or in an error state"));
}
731 
XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) {
  // Select between two tuples that hold the same pair of vectors in opposite
  // order. The predicate is false, so the second tuple is the expected result.
  XlaBuilder builder(TestName());

  std::initializer_list<float> ascending = {1.f, 2.f, 3.f};
  std::initializer_list<float> doubled = {2.f, 4.f, 6.f};

  // First operand: (ascending, doubled).
  auto first_lhs = ConstantR1<float>(&builder, ascending);
  auto first_rhs = ConstantR1<float>(&builder, doubled);
  auto on_true = Tuple(&builder, {first_lhs, first_rhs});

  // Second operand: (doubled, ascending).
  auto second_lhs = ConstantR1<float>(&builder, doubled);
  auto second_rhs = ConstantR1<float>(&builder, ascending);
  auto on_false = Tuple(&builder, {second_lhs, second_rhs});

  auto predicate = ConstantR0<bool>(&builder, false);
  Select(predicate, on_true, on_false);

  auto computation = builder.Build().ConsumeValueOrDie();
  ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {});
  Literal selected = ShapedBufferToLiteral(result);

  // Element 0 must be the doubled vector and element 1 the ascending one.
  LiteralTestUtil::ExpectR1Equal<float>({2.0f, 4.0f, 6.0f},
                                        LiteralSlice(selected, {0}));
  LiteralTestUtil::ExpectR1Equal<float>({1.0f, 2.0f, 3.0f},
                                        LiteralSlice(selected, {1}));
}
751 
XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
  // Compile "x + {2, 3, 4}" once through LocalClient::Compile, then run the
  // resulting executable on a device-resident argument.
  XlaBuilder builder(TestName());
  auto param = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto addend = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  Add(param, addend);

  Shape param_layout =
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0});
  auto computation = builder.Build().ConsumeValueOrDie();
  auto compile_result = local_client_->Compile(computation, {&param_layout},
                                               ExecutableBuildOptions());
  ASSERT_IS_OK(compile_result);
  std::unique_ptr<LocalExecutable> compiled =
      compile_result.ConsumeValueOrDie();

  auto argument =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
  ScopedShapedBuffer result =
      compiled->Run({&argument}, DefaultExecutableRunOptions())
          .ConsumeValueOrDie();

  // Wait on a stream borrowed for device 0 before reading the result back.
  ASSERT_IS_OK(local_client_->mutable_backend()
                   ->BorrowStream(0)
                   .ValueOrDie()
                   ->BlockHostUntilDone());

  LiteralTestUtil::ExpectR1Near<float>(
      {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
780 
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
  // Round-trip literals of various shapes: host literal -> ShapedBuffer on the
  // default device -> host literal, and require the copy to equal the input.
  auto round_trip = [this](const Literal& original) {
    const auto ordinal = local_client_->default_device_ordinal();
    TF_ASSERT_OK_AND_ASSIGN(
        auto on_device,
        local_client_->LiteralToShapedBuffer(original, ordinal, allocator_));
    TF_ASSERT_OK_AND_ASSIGN(auto read_back,
                            local_client_->ShapedBufferToLiteral(on_device));
    EXPECT_EQ(original, read_back);
  };

  // Scalars.
  round_trip(LiteralUtil::CreateR0<float>(42.0));
  round_trip(LiteralUtil::CreateR0<bool>(true));

  // Rank-1 and rank-2 arrays.
  round_trip(LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
  round_trip(
      LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
  round_trip(LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}}));

  // Empty tuple (null shape).
  round_trip(LiteralUtil::MakeTuple({}));

  // Flat tuples with one and two elements.
  round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(12223.0)}));
  Literal flat_pair = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>({1.0, -42.0}),
       LiteralUtil::CreateR0<float>(123456.0)});
  round_trip(flat_pair);

  // Tuple whose first element is itself the two-element tuple above.
  Literal nested = LiteralUtil::MakeTupleFromSlices(
      {flat_pair, LiteralUtil::CreateR0<bool>(false)});
  round_trip(nested);
}
820 
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
  // Same host -> device -> host round trip as the 32-bit variant, restricted
  // to literals with 64-bit element types (double, int64, uint64).
  auto round_trip = [this](const Literal& original) {
    const auto ordinal = local_client_->default_device_ordinal();
    TF_ASSERT_OK_AND_ASSIGN(
        auto on_device,
        local_client_->LiteralToShapedBuffer(original, ordinal, allocator_));
    TF_ASSERT_OK_AND_ASSIGN(auto read_back,
                            local_client_->ShapedBufferToLiteral(on_device));
    EXPECT_EQ(original, read_back);
  };

  // Rank-2 arrays of each 64-bit type.
  round_trip(
      LiteralUtil::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
  round_trip(LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}}));
  round_trip(
      LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));

  // Tuple mixing a double vector with an int64 scalar.
  Literal mixed_tuple = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<double>({1.0, -42.0}),
       LiteralUtil::CreateR0<int64>(123456789000LL)});
  round_trip(mixed_tuple);
}
844 
845 // Disabled on interpreter backend since infeed HLO is unsupported.
XLA_TEST_F(LocalClientExecuteTest,DISABLED_ON_INTERPRETER (InfeedTest))846 XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedTest)) {
847   XlaBuilder builder(TestName());
848   const Shape shape = ShapeUtil::MakeShape(F32, {3});
849   auto in = Infeed(&builder, shape);
850   auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
851   Add(in, constant);
852 
853   Literal result;
854   std::unique_ptr<tensorflow::Thread> thread(
855       tensorflow::Env::Default()->StartThread(
856           tensorflow::ThreadOptions(), "execute_thread", [&] {
857             result = ShapedBufferToLiteral(ExecuteLocallyOrDie(
858                 builder.Build().ValueOrDie(), /*arguments=*/{}));
859           }));
860 
861   ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
862       LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
863       local_client_->default_device_ordinal()));
864 
865   // Join the thread.
866   thread.reset();
867 
868   LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
869 }
870 
871 // Disabled on interpreter backend since infeed/outfeed HLOs are unsupported.
XLA_TEST_F(LocalClientExecuteTest,DISABLED_ON_INTERPRETER (InfeedOutfeedTest))872 XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) {
873   XlaBuilder builder(TestName());
874   const Shape shape = ShapeUtil::MakeShape(F32, {3});
875   auto in = Infeed(&builder, shape);
876   auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
877   auto sum = Add(in, constant);
878   Outfeed(sum, shape, /*outfeed_config=*/"");
879 
880   std::unique_ptr<tensorflow::Thread> thread(
881       tensorflow::Env::Default()->StartThread(
882           tensorflow::ThreadOptions(), "execute_thread",
883           [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); }));
884 
885   ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
886       LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
887       local_client_->default_device_ordinal()));
888 
889   TF_ASSERT_OK_AND_ASSIGN(Literal result,
890                           local_client_->TransferFromOutfeedLocal(
891                               shape, local_client_->default_device_ordinal()));
892 
893   LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
894 }
895 
896 // Benchmark that measures the overhead of the LocalClient API when running a
897 // trivial computation
BM_LocalClientOverhead(int num_iters)898 void BM_LocalClientOverhead(int num_iters) {
899   tensorflow::testing::StopTiming();
900 
901   se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
902   auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
903   StreamExecutorMemoryAllocator allocator(platform, executors);
904   LocalClient* client =
905       ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
906   auto* transfer_manager =
907       TransferManager::GetForPlatform(platform).ValueOrDie();
908   int device_ordinal = client->default_device_ordinal();
909 
910   // Use a tiny add operation as the computation.
911   XlaBuilder builder("Add");
912   auto shape = ShapeUtil::MakeShape(F32, {2, 3});
913   auto x = Parameter(&builder, 0, shape, "x");
914   Add(x, x);
915   auto computation = builder.Build().ConsumeValueOrDie();
916 
917   auto buffer =
918       transfer_manager
919           ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0)
920           .ConsumeValueOrDie();
921   auto literal = LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
922   auto stream =
923       client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
924   ASSERT_IS_OK(
925       transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer));
926 
927   const int kWarmups = 2;
928 
929   auto executable_status = client->Compile(
930       computation, {&buffer.on_host_shape()}, ExecutableBuildOptions());
931   ASSERT_IS_OK(executable_status);
932   std::unique_ptr<LocalExecutable> executable =
933       executable_status.ConsumeValueOrDie();
934 
935   ExecutableRunOptions run_options;
936   run_options.set_allocator(&allocator).set_stream(stream.get());
937 
938   for (int i = 0; i < kWarmups; ++i) {
939     auto result = executable->Run({&buffer}, run_options);
940     ASSERT_IS_OK(result);
941   }
942 
943   tensorflow::testing::StartTiming();
944   for (int i = 0; i < num_iters; ++i) {
945     auto result = executable->Run({&buffer}, run_options);
946     ASSERT_IS_OK(result);
947   }
948 }
949 
950 BENCHMARK(BM_LocalClientOverhead);
951 
952 }  // namespace
953 }  // namespace xla
954