1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <initializer_list>
17 #include <memory>
18 #include <vector>
19
20 #include "tensorflow/compiler/xla/client/client_library.h"
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/layout_util.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
26 #include "tensorflow/compiler/xla/service/local_service.h"
27 #include "tensorflow/compiler/xla/service/platform_util.h"
28 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
29 #include "tensorflow/compiler/xla/service/transfer_manager.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/test.h"
33 #include "tensorflow/compiler/xla/test_helpers.h"
34 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
35 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
36 #include "tensorflow/compiler/xla/tests/test_macros.h"
37 #include "tensorflow/compiler/xla/tests/test_utils.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39 #include "tensorflow/core/platform/env.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
42 #include "tensorflow/core/platform/test.h"
43 #include "tensorflow/core/platform/test_benchmark.h"
44
45 namespace xla {
46 namespace {
47
48 using ::testing::ContainsRegex;
49
50 class LocalClientExecuteTest : public LocalClientTestBase {
51 protected:
52 ErrorSpec error_spec_{0.0001};
53 };
54
XLA_TEST_F(LocalClientExecuteTest, Constant) {
  // A parameterless computation whose only instruction is the scalar 123.
  XlaBuilder b(TestName());
  ConstantR0<float>(&b, 123.0f);
  auto computation = b.Build().ConsumeValueOrDie();

  ScopedShapedBuffer device_result = ExecuteLocallyOrDie(computation, {});
  LiteralTestUtil::ExpectR0Near<float>(
      123.f, ShapedBufferToLiteral(device_result), error_spec_);
}
64
XLA_TEST_F(LocalClientExecuteTest, AddScalars) {
  // Computes x + 123 for a scalar f32 parameter x, then runs it with x = 42.
  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
  auto addend = ConstantR0<float>(&b, 123.0f);
  Add(param, addend);
  auto computation = b.Build().ConsumeValueOrDie();

  auto arg = LiteralToShapedBuffer(LiteralUtil::CreateR0<float>(42.0f));
  ScopedShapedBuffer sum = ExecuteLocallyOrDie(computation, {&arg});
  // 42 + 123 = 165.
  LiteralTestUtil::ExpectR0Near<float>(165.f, ShapedBufferToLiteral(sum),
                                       error_spec_);
}
77
XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) {
  // Adding two f32[0] vectors must run and produce an empty vector.
  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {0}), "x");
  auto empty_constant = ConstantR1<float>(&b, {});
  Add(param, empty_constant);
  auto computation = b.Build().ConsumeValueOrDie();

  auto empty_arg = LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({}));
  ScopedShapedBuffer sum = ExecuteLocallyOrDie(computation, {&empty_arg});
  LiteralTestUtil::ExpectR1Near<float>({}, ShapedBufferToLiteral(sum),
                                       error_spec_);
}
90
XLA_TEST_F(LocalClientExecuteTest, AddVectors) {
  // Element-wise sum of an f32[3] parameter and the constant {2, 3, 4}.
  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto addend = ConstantR1<float>(&b, {2.0f, 3.0f, 4.0f});
  Add(param, addend);
  auto computation = b.Build().ConsumeValueOrDie();

  auto arg =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
  ScopedShapedBuffer sum = ExecuteLocallyOrDie(computation, {&arg});
  // {0, 1, 2} + {2, 3, 4} = {2, 4, 6}.
  LiteralTestUtil::ExpectR1Near<float>({2.0f, 4.0f, 6.0f},
                                       ShapedBufferToLiteral(sum), error_spec_);
}
104
XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
  // Adds the constant {2, 3, 4} to an f32[3] parameter while an
  // ExecutionProfile is attached to the run options, then checks both the
  // result and that the profile recorded a positive time.
  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto addend = ConstantR1<float>(&b, {2.0f, 3.0f, 4.0f});
  Add(param, addend);
  auto computation = b.Build().ConsumeValueOrDie();

  auto arg =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
  ExecutionProfile profile;
  auto run_options = DefaultExecutableRunOptions();
  run_options.set_execution_profile(&profile);
  ScopedShapedBuffer sum = ExecuteLocallyOrDie(
      computation, {&arg}, DefaultExecutableBuildOptions(), run_options);

  LiteralTestUtil::ExpectR1Near<float>({2.0f, 4.0f, 6.0f},
                                       ShapedBufferToLiteral(sum), error_spec_);
  EXPECT_GT(profile.compute_and_transfer_time_ns(), 0);
}
122
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
  // Adds two 2x2 f32 parameters whose argument buffers use different layouts.
  // The sum must not depend on the argument layouts, nor on argument order.
  XlaBuilder b(TestName());
  auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  Add(lhs, rhs);
  auto computation = b.Build().ConsumeValueOrDie();

  const auto col_major = LayoutUtil::MakeLayout({0, 1});
  const auto row_major = LayoutUtil::MakeLayout({1, 0});

  // First argument in column-major order.
  auto col_major_buf = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
      {{1.0f, 2.0f}, {3.0f, 4.0f}}, col_major));
  EXPECT_TRUE(
      LayoutUtil::Equal(col_major_buf.on_device_shape().layout(), col_major));

  // Second argument in row-major order.
  auto row_major_buf = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
      {{10.0f, 20.0f}, {30.0f, 40.0f}}, row_major));
  EXPECT_TRUE(
      LayoutUtil::Equal(row_major_buf.on_device_shape().layout(), row_major));

  ScopedShapedBuffer in_order_sum =
      ExecuteLocallyOrDie(computation, {&col_major_buf, &row_major_buf});
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       ShapedBufferToLiteral(in_order_sum),
                                       error_spec_);

  // Swapping the two arguments must give the same sum since Add commutes.
  ScopedShapedBuffer swapped_sum =
      ExecuteLocallyOrDie(computation, {&row_major_buf, &col_major_buf});
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       ShapedBufferToLiteral(swapped_sum),
                                       error_spec_);
}
155
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
  // Runs one 2x2 addition twice, first requesting a column-major and then a
  // row-major result layout. Each run checks the result buffer's device layout
  // and its values.
  XlaBuilder b(TestName());
  auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  Add(lhs, rhs);
  auto computation = b.Build().ConsumeValueOrDie();

  auto lhs_buf = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto rhs_buf = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  // Column-major result (minor_to_major = {0, 1}).
  const auto col_major_shape =
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {0, 1});
  ScopedShapedBuffer col_major_sum = ExecuteLocallyOrDie(
      computation, {&lhs_buf, &rhs_buf},
      DefaultExecutableBuildOptions().set_result_layout(col_major_shape),
      DefaultExecutableRunOptions());
  EXPECT_TRUE(LayoutUtil::Equal(col_major_sum.on_device_shape().layout(),
                                LayoutUtil::MakeLayout({0, 1})));
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       ShapedBufferToLiteral(col_major_sum),
                                       error_spec_);

  // Row-major result (minor_to_major = {1, 0}).
  const auto row_major_shape =
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {1, 0});
  ScopedShapedBuffer row_major_sum = ExecuteLocallyOrDie(
      computation, {&lhs_buf, &rhs_buf},
      DefaultExecutableBuildOptions().set_result_layout(row_major_shape),
      DefaultExecutableRunOptions());
  EXPECT_TRUE(LayoutUtil::Equal(row_major_sum.on_device_shape().layout(),
                                LayoutUtil::MakeLayout({1, 0})));
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       ShapedBufferToLiteral(row_major_sum),
                                       error_spec_);
}
192
XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
  // Returns the three-element tuple (x, y, x) and checks its shape and every
  // element after transferring it back to the host.
  XlaBuilder b(TestName());
  auto first = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto second = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  Tuple(&b, {first, second, first});
  auto computation = b.Build().ConsumeValueOrDie();

  auto first_buf = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto second_buf = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  ScopedShapedBuffer out =
      ExecuteLocallyOrDie(computation, {&first_buf, &second_buf});

  const auto& host_shape = out.on_host_shape();
  EXPECT_TRUE(host_shape.IsTuple());
  EXPECT_EQ(3, ShapeUtil::TupleElementCount(host_shape));

  // Elements 0 and 2 both echo x; element 1 echoes y.
  Literal out_literal = ShapedBufferToLiteral(out);
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(out_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
                                        LiteralSlice(out_literal, {1}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(out_literal, {2}));
}
219
XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
  // Returns the nested tuple ((x, y, x), x) and checks every leaf.
  XlaBuilder b(TestName());
  auto first = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto second = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  auto inner = Tuple(&b, {first, second, first});
  Tuple(&b, {inner, first});
  auto computation = b.Build().ConsumeValueOrDie();

  auto first_buf = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto second_buf = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  ScopedShapedBuffer out =
      ExecuteLocallyOrDie(computation, {&first_buf, &second_buf});

  const auto& host_shape = out.on_host_shape();
  EXPECT_TRUE(host_shape.IsTuple());
  EXPECT_EQ(2, ShapeUtil::TupleElementCount(host_shape));

  Literal out_literal = ShapedBufferToLiteral(out);
  // Outer element 1 is x.
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(out_literal, {1}));
  // Outer element 0 is the inner tuple (x, y, x).
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(out_literal, {0, 0}));
  LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
                                        LiteralSlice(out_literal, {0, 1}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(out_literal, {0, 2}));
}
249
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
  // Verify setting the result layout of a computation with a tuple output.
  // Both the device layout of each tuple element and the element values are
  // checked.
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  Tuple(&builder, {x, y});

  auto array = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));

  // Request a column-major first element and a row-major second element.
  ExecutableBuildOptions options = DefaultExecutableBuildOptions();
  Shape shape_with_layout = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2},
                                      /*minor_to_major=*/{0, 1}),
       ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2},
                                      /*minor_to_major=*/{1, 0})});
  options.set_result_layout(shape_with_layout);
  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array},
                          options, DefaultExecutableRunOptions());

  // The requested per-element layouts must show up in the result's device
  // shape. Otherwise only the host-side values would be verified, and they do
  // not depend on layout. This mirrors the layout check done for non-tuple
  // results in AddArraysWithDifferentOutputLayouts.
  const Shape& device_shape = result.on_device_shape();
  ASSERT_TRUE(device_shape.IsTuple());
  EXPECT_TRUE(LayoutUtil::Equal(
      ShapeUtil::GetTupleElementShape(device_shape, 0).layout(),
      LayoutUtil::MakeLayout({0, 1})));
  EXPECT_TRUE(LayoutUtil::Equal(
      ShapeUtil::GetTupleElementShape(device_shape, 1).layout(),
      LayoutUtil::MakeLayout({1, 0})));

  Literal result_literal = ShapedBufferToLiteral(result);
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {1}));
}
277
XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
  const auto matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const auto vec_shape = ShapeUtil::MakeShape(F32, {3});

  // The two tuple parameters hold the same element shapes in opposite order:
  // x is (matrix, vector) and y is (vector, matrix).
  const auto x_shape = ShapeUtil::MakeTupleShape({matrix_shape, vec_shape});
  const auto y_shape = ShapeUtil::MakeTupleShape({vec_shape, matrix_shape});

  // Output is (x.matrix + y.matrix, x.vector - y.vector).
  XlaBuilder b(TestName());
  auto x = Parameter(&b, 0, x_shape, "x");
  auto y = Parameter(&b, 1, y_shape, "y");
  auto x_matrix = GetTupleElement(x, 0);
  auto x_vec = GetTupleElement(x, 1);
  auto y_vec = GetTupleElement(y, 0);
  auto y_matrix = GetTupleElement(y, 1);
  auto matrix_sum = Add(x_matrix, y_matrix);
  auto vec_diff = Sub(x_vec, y_vec);
  Tuple(&b, {matrix_sum, vec_diff});
  auto computation = b.Build().ConsumeValueOrDie();

  auto x_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
       LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})});
  auto y_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}),
       LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}})});

  auto x_buf = LiteralToShapedBuffer(x_literal);
  auto y_buf = LiteralToShapedBuffer(y_literal);

  ScopedShapedBuffer out = ExecuteLocallyOrDie(computation, {&x_buf, &y_buf});

  const auto& host_shape = out.on_host_shape();
  EXPECT_TRUE(host_shape.IsTuple());
  EXPECT_EQ(2, ShapeUtil::TupleElementCount(host_shape));

  Literal out_literal = ShapedBufferToLiteral(out);
  LiteralTestUtil::ExpectR2Equal<float>({{56.0f, 46.0f}, {36.0f, 26.0f}},
                                        LiteralSlice(out_literal, {0}));
  LiteralTestUtil::ExpectR1Equal<float>({40.0f, 71.0f, 117.0f},
                                        LiteralSlice(out_literal, {1}));
}
323
XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
  const auto matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const auto vec_shape = ShapeUtil::MakeShape(F32, {3});

  // Parameter shape: ((f32[2,2], f32[3]), f32[3]).
  const auto inner_shape = ShapeUtil::MakeTupleShape({matrix_shape, vec_shape});
  const auto outer_shape = ShapeUtil::MakeTupleShape({inner_shape, vec_shape});

  // Output is (-matrix, inner vector + outer vector).
  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, outer_shape, "param");
  auto inner = GetTupleElement(param, 0);
  auto matrix = GetTupleElement(inner, 0);
  auto inner_vec = GetTupleElement(inner, 1);
  auto outer_vec = GetTupleElement(param, 1);

  auto negated = Neg(matrix);
  auto vec_sum = Add(inner_vec, outer_vec);
  Tuple(&b, {negated, vec_sum});
  auto computation = b.Build().ConsumeValueOrDie();

  auto arg_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
            LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})}),
       LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0})});
  auto arg_buf = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer out = ExecuteLocallyOrDie(computation, {&arg_buf});

  Literal out_literal = ShapedBufferToLiteral(out);
  LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4.0}},
                                        LiteralSlice(out_literal, {0}));
  // {42, 75, 123} + {222, -2, 10} = {264, 73, 133}.
  LiteralTestUtil::ExpectR1Equal<float>({264.0, 73.0, 133.0},
                                        LiteralSlice(out_literal, {1}));
}
362
XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
  // The computation maps a pair of 2x2 matrices (a, b) to (-a, b + b), so its
  // output shape equals its input shape. Running it a second time on its own
  // result checks that the returned tuple is built correctly enough to be
  // consumed as an argument.
  const auto matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const auto pair_shape =
      ShapeUtil::MakeTupleShape({matrix_shape, matrix_shape});

  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, pair_shape, "param");
  auto first = GetTupleElement(param, 0);
  auto second = GetTupleElement(param, 1);
  auto negated = Neg(first);
  auto doubled = Add(second, second);
  Tuple(&b, {negated, doubled});
  auto computation = b.Build().ConsumeValueOrDie();

  auto arg_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
       LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}})});
  auto arg_buf = LiteralToShapedBuffer(arg_literal);

  // First pass: (a, b) -> (-a, 2b).
  ScopedShapedBuffer pass1 = ExecuteLocallyOrDie(computation, {&arg_buf});
  Literal pass1_literal = ShapedBufferToLiteral(pass1);
  LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4.0}},
                                        LiteralSlice(pass1_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{22.0, 6.0}, {8.0, 10.0}},
                                        LiteralSlice(pass1_literal, {1}));

  // Second pass on the first pass's output: (-a, 2b) -> (a, 4b).
  ScopedShapedBuffer pass2 = ExecuteLocallyOrDie(computation, {&pass1});
  Literal pass2_literal = ShapedBufferToLiteral(pass2);
  LiteralTestUtil::ExpectR2Equal<float>({{1.0, 2.0}, {3.0, 4.0}},
                                        LiteralSlice(pass2_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{44.0, 12.0}, {16.0, 20.0}},
                                        LiteralSlice(pass2_literal, {1}));
}
398
XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
  // Runs a computation whose single parameter is a tuple with many elements.
  //
  // A larger number of elements would make for a better, more strenuous test,
  // but:
  // TODO(b/66959878): On cpu a large number of elements results in long
  // compilation time.
  // TODO(b/66954197): On gpu a large number of elements OOMs.
  const int kElementCount = 100;

  // Every tuple element is an f32[2] vector.
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      std::vector<Shape>(kElementCount, ShapeUtil::MakeShape(F32, {2})));

  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, tuple_shape, "param");

  // Output element k is input element k plus the scalar k.
  std::vector<XlaOp> outputs;
  for (int k = 0; k < kElementCount; ++k) {
    auto element = GetTupleElement(param, k);
    outputs.push_back(Add(element, ConstantR0<float>(&b, k)));
  }
  Tuple(&b, outputs);
  auto computation = b.Build().ConsumeValueOrDie();

  // Input element k is {k, -k}, so output element k should be {2k, 0}.
  std::vector<Literal> inputs;
  for (int k = 0; k < kElementCount; ++k) {
    inputs.push_back(LiteralUtil::CreateR1<float>({1.0f * k, -1.0f * k}));
  }
  Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(inputs));
  auto arg_buf = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer out = ExecuteLocallyOrDie(computation, {&arg_buf});
  Literal out_literal = ShapedBufferToLiteral(out);

  for (int k = 0; k < kElementCount; ++k) {
    LiteralTestUtil::ExpectR1Near<float>(
        {2.0f * k, 0.0f}, LiteralSlice(out_literal, {k}), error_spec_);
  }
}
444
XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
  // Runs a computation whose parameter is a two-level nested tuple with a
  // large fanout at both levels.
  const int kFanout = 40;

  // Full two-level tree: kFanout inner tuples of kFanout f32 scalars each.
  const Shape leaf_shape = ShapeUtil::MakeShape(F32, {});
  const Shape inner_shape =
      ShapeUtil::MakeTupleShape(std::vector<Shape>(kFanout, leaf_shape));
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape(std::vector<Shape>(kFanout, inner_shape));

  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, tuple_shape, "param");

  // Leaf (row, col) is incremented by row * kFanout + col, which is its
  // ordinal position when the tree is traversed row by row.
  std::vector<XlaOp> outer_outputs;
  for (int row = 0; row < kFanout; ++row) {
    auto inner = GetTupleElement(param, row);
    std::vector<XlaOp> inner_outputs;
    for (int col = 0; col < kFanout; ++col) {
      auto leaf = GetTupleElement(inner, col);
      inner_outputs.push_back(
          Add(leaf, ConstantR0<float>(&b, row * kFanout + col)));
    }
    outer_outputs.push_back(Tuple(&b, inner_outputs));
  }
  Tuple(&b, outer_outputs);
  auto computation = b.Build().ConsumeValueOrDie();

  // Leaf (row, col) of the argument holds row + col.
  std::vector<Literal> outer_literals;
  for (int row = 0; row < kFanout; ++row) {
    std::vector<Literal> inner_literals;
    for (int col = 0; col < kFanout; ++col) {
      inner_literals.push_back(LiteralUtil::CreateR0<float>(row + col));
    }
    outer_literals.push_back(
        LiteralUtil::MakeTupleOwned(std::move(inner_literals)));
  }
  auto arg_literal = LiteralUtil::MakeTupleOwned(std::move(outer_literals));
  auto arg_buf = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer out = ExecuteLocallyOrDie(computation, {&arg_buf});
  Literal out_literal = ShapedBufferToLiteral(out);

  // Expected leaf: argument (row + col) plus increment (row * kFanout + col).
  for (int row = 0; row < kFanout; ++row) {
    for (int col = 0; col < kFanout; ++col) {
      LiteralTestUtil::ExpectR0Near<float>(row + col + row * kFanout + col,
                                           LiteralSlice(out_literal, {row, col}),
                                           error_spec_);
    }
  }
}
501
XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
  // Construct and run a computation which takes a very deep tuple. The tuple
  // has no fan out and a single scalar element at the bottom.
  const int kTupleDepth = 100;

  // The parameter shape is a chain of kTupleDepth nested one-element tuples
  // wrapping an f32 scalar, i.e. ((...(f32)...)).
  Shape shape = ShapeUtil::MakeShape(F32, {});
  for (int i = 0; i < kTupleDepth; ++i) {
    shape = ShapeUtil::MakeTupleShape({shape});
  }

  XlaBuilder builder(TestName());
  // Peel off every tuple level to reach the scalar leaf.
  auto element = Parameter(&builder, 0, shape, "param");
  for (int i = 0; i < kTupleDepth; ++i) {
    element = GetTupleElement(element, 0);
  }

  // Add 42 to the leaf and re-wrap the sum to the same depth, so the output
  // has the same nested shape as the parameter.
  auto output = Add(element, ConstantR0<float>(&builder, 42.0));
  for (int i = 0; i < kTupleDepth; ++i) {
    output = Tuple(&builder, {output});
  }
  auto computation = builder.Build().ConsumeValueOrDie();

  // Construct the argument: the scalar 123 wrapped kTupleDepth times.
  Literal arg_literal = LiteralUtil::CreateR0<float>(123.0);
  for (int i = 0; i < kTupleDepth; ++i) {
    std::vector<Literal> arg_vector;
    arg_vector.push_back(std::move(arg_literal));
    arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector));
  }
  auto arg_buffer = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
  Literal result_literal = ShapedBufferToLiteral(result);

  // The leaf sits at index {0, 0, ..., 0} (kTupleDepth zeros); 123 + 42 = 165.
  ShapeIndex index;
  for (int i = 0; i < kTupleDepth; ++i) {
    index.push_back(0);
  }
  LiteralTestUtil::ExpectR0Equal<float>(165.0,
                                        LiteralSlice(result_literal, index));
}
544
XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
  // Test passing in an invalid number of arguments: the computation declares
  // two parameters but only one argument is supplied.
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {3}), "y");
  Add(x, y);

  auto x_array =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
  auto execute_status =
      ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});

  EXPECT_FALSE(execute_status.ok());
  // Stream the full status so a mismatch also reports the error code.
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("Invalid number of arguments"))
      << execute_status.status();
}
561
XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) {
  // The computation expects an f32[3] argument but receives an f32[2,2] one,
  // so execution must be rejected with a shape error.
  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  Neg(param);
  auto computation = b.Build().ConsumeValueOrDie();

  auto mis_shaped_arg = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
  auto execute_status = ExecuteLocally(computation, {&mis_shaped_arg});

  EXPECT_FALSE(execute_status.ok());
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("Invalid argument shape"))
      << execute_status.status();
}
578
XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) {
  // Requests a rank-4 result layout for a computation whose result is
  // f32[2,2]. The mismatch must be reported as an incompatible result shape.
  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  Neg(param);
  auto computation = b.Build().ConsumeValueOrDie();

  auto arg = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
  const auto bogus_result_shape =
      ShapeUtil::MakeShapeWithLayout(F32,
                                     /*dimensions=*/{1, 2, 3, 4},
                                     /*minor_to_major=*/{0, 1, 2, 3});
  auto execute_status = ExecuteLocally(
      computation, {&arg},
      DefaultExecutableBuildOptions().set_result_layout(bogus_result_shape),
      DefaultExecutableRunOptions());

  EXPECT_FALSE(execute_status.ok());
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("not compatible with result shape"))
      << execute_status.status();
}
600
XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) {
  // Try to run a trivial computation on every device on the system. If a
  // specific device is not supported, check that the right error is returned.
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);
  auto computation = builder.Build().ConsumeValueOrDie();
  for (int d = 0; d < local_client_->device_count(); ++d) {
    // The ordinal is passed in both the build options and the run options.
    if (!local_client_->device_ordinal_supported(d)) {
      auto execute_status =
          ExecuteLocally(computation, {},
                         DefaultExecutableBuildOptions().set_device_ordinal(d),
                         DefaultExecutableRunOptions().set_device_ordinal(d));
      EXPECT_FALSE(execute_status.ok());
      // Stream the full status so a mismatch also reports the error code.
      EXPECT_THAT(execute_status.status().error_message(),
                  ContainsRegex("device .* not supported"))
          << execute_status.status();
    } else {
      auto result = ExecuteLocallyOrDie(
          computation, {},
          DefaultExecutableBuildOptions().set_device_ordinal(d),
          DefaultExecutableRunOptions().set_device_ordinal(d));
      EXPECT_EQ(d, result.device_ordinal());
      LiteralTestUtil::ExpectR0Equal<float>(42.0f,
                                            ShapedBufferToLiteral(result));
    }
  }
}
627
XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) {
  // Try running computations on devices with device ordinal values which do not
  // exist.
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);
  auto computation = builder.Build().ConsumeValueOrDie();

  // The per-device tests iterate ordinals 0 .. device_count() - 1, so
  // device_count() itself is one past the last real device.
  const auto invalid_ordinal = local_client_->device_count();
  auto execute_status = ExecuteLocally(
      computation, {},
      DefaultExecutableBuildOptions().set_device_ordinal(invalid_ordinal),
      DefaultExecutableRunOptions().set_device_ordinal(invalid_ordinal));
  EXPECT_FALSE(execute_status.ok());
  // Stream the full status so a mismatch also reports the error code.
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("Invalid device ordinal value"))
      << execute_status.status();
}
645
XLA_TEST_F(LocalClientExecuteTest, RunOnStream) {
  // Run a computation on a specific stream on each device on the system.
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);
  auto computation = builder.Build().ConsumeValueOrDie();

  for (int d = 0; d < local_client_->device_count(); ++d) {
    if (!local_client_->device_ordinal_supported(d)) {
      continue;
    }
    se::StreamExecutor* executor =
        local_client_->platform()->ExecutorForDevice(d).ValueOrDie();
    se::Stream stream(executor);
    stream.Init();

    // Only the stream is supplied; no device ordinal is set in the options.
    auto result =
        ExecuteLocallyOrDie(computation, {}, DefaultExecutableBuildOptions(),
                            DefaultExecutableRunOptions().set_stream(&stream));
    // Check that the computation ran on the device associated with the
    // stream. This is a weak check, but stronger verification is hard.
    EXPECT_EQ(d, result.device_ordinal());
    LiteralTestUtil::ExpectR0Equal<float>(42.0f, ShapedBufferToLiteral(result));
  }
}
670
671 // Disable this test on CPU because we're using the CPU as the platform
672 // which does not match the service platform.
XLA_TEST_F(LocalClientExecuteTest,DISABLED_ON_CPU (RunOnStreamForWrongPlatform))673 XLA_TEST_F(LocalClientExecuteTest,
674 DISABLED_ON_CPU(RunOnStreamForWrongPlatform)) {
675 // Try to run a computation on a stream for a platform (CPU) which does not
676 // match the platform of the service (!= CPU).
677 se::Platform* wrong_platform =
678 se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)
679 .ValueOrDie();
680 se::Stream wrong_stream(wrong_platform->ExecutorForDevice(0).ValueOrDie());
681 wrong_stream.Init();
682
683 XlaBuilder builder(TestName());
684 ConstantR0<float>(&builder, 42.0f);
685 auto execute_status = ExecuteLocally(
686 builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(),
687 DefaultExecutableRunOptions().set_stream(&wrong_stream));
688 EXPECT_FALSE(execute_status.ok());
689 EXPECT_THAT(execute_status.status().error_message(),
690 ContainsRegex("stream is for platform .*, but service targets"));
691 }
692
XLA_TEST_F(LocalClientExecuteTest,
           DISABLED_ON_CPU(AllocatorDoesNotMatchPlatform)) {
  // Run with an allocator created for the host platform. When the service
  // targets a different platform, execution must be rejected. The test is
  // disabled on CPU because there the host platform is the service platform.
  se::Platform* wrong_platform =
      se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)
          .ValueOrDie();
  TestAllocator allocator(wrong_platform);

  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 123.0f);

  auto execute_status = ExecuteLocally(
      builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(),
      DefaultExecutableRunOptions().set_allocator(&allocator));
  EXPECT_FALSE(execute_status.ok());
  // Stream the full status so a mismatch also reports the error code.
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("allocator platform .* does not match service"))
      << execute_status.status();
}
710
XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) {
  // Try to run a computation on a stream that has not been initialized.
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);

  LOG(INFO) << "default device = " << local_client_->default_device_ordinal();
  se::StreamExecutor* executor =
      local_client_->platform()
          ->ExecutorForDevice(local_client_->default_device_ordinal())
          .ValueOrDie();
  se::Stream stream(executor);
  // Don't call stream.Init(); the execution below must reject the stream.

  auto execute_status = ExecuteLocally(
      builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(),
      DefaultExecutableRunOptions().set_stream(&stream));
  EXPECT_FALSE(execute_status.ok());
  // Stream the full status so a mismatch also reports the error code.
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("stream is uninitialized or in an error state"))
      << execute_status.status();
}
731
// A Select whose scalar predicate is false, applied to two tuple operands,
// must produce the second (on-false) tuple element-for-element.
XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) {
  XlaBuilder builder(TestName());

  std::initializer_list<float> ascending = {1.f, 2.f, 3.f};
  std::initializer_list<float> doubled = {2.f, 4.f, 6.f};

  // Builds a fresh two-element tuple of R1 constants (first, second).
  auto make_pair_tuple = [&builder](std::initializer_list<float> first,
                                    std::initializer_list<float> second) {
    return Tuple(&builder, {ConstantR1<float>(&builder, first),
                            ConstantR1<float>(&builder, second)});
  };
  auto on_true = make_pair_tuple(ascending, doubled);
  auto on_false = make_pair_tuple(doubled, ascending);
  Select(ConstantR0<bool>(&builder, false), on_true, on_false);

  ScopedShapedBuffer output =
      ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
  Literal selected = ShapedBufferToLiteral(output);
  LiteralTestUtil::ExpectR1Equal<float>({2.0f, 4.0f, 6.0f},
                                        LiteralSlice(selected, {0}));
  LiteralTestUtil::ExpectR1Equal<float>({1.0f, 2.0f, 3.0f},
                                        LiteralSlice(selected, {1}));
}
751
// Compiles a parameterized computation ahead of time with
// LocalClient::Compile, runs the resulting LocalExecutable on a device buffer
// holding {0, 1, 2}, and checks that x + {2, 3, 4} == {2, 4, 6}.
XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  Add(x, y);

  Shape argument_layout =
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0});
  // Report compile/run failures as test failures instead of crashing the
  // binary through ConsumeValueOrDie.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<LocalExecutable> executable,
      local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
                             ExecutableBuildOptions()));

  auto x_array =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
  TF_ASSERT_OK_AND_ASSIGN(
      ScopedShapedBuffer result,
      executable->Run({&x_array}, DefaultExecutableRunOptions()));

  // Block on a stream for the client's default device, which is presumably the
  // device a default-constructed ExecutableBuildOptions targets. Do not assume
  // that device is ordinal 0.
  ASSERT_IS_OK(local_client_->mutable_backend()
                   ->BorrowStream(local_client_->default_device_ordinal())
                   .ValueOrDie()
                   ->BlockHostUntilDone());

  LiteralTestUtil::ExpectR1Near<float>(
      {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
780
// Copies a range of literals to the device as ShapedBuffers and back again,
// checking that each round trip reproduces the original literal. The cases
// cover arrays of several ranks and element types, the empty tuple, flat
// tuples, and a nested tuple.
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
  // A failed ASSERT inside this lambda returns from the lambda only, so later
  // literals are still checked.
  const auto check_round_trip = [this](const Literal& expected) {
    TF_ASSERT_OK_AND_ASSIGN(
        auto device_buffer,
        local_client_->LiteralToShapedBuffer(
            expected, local_client_->default_device_ordinal(), allocator_));
    TF_ASSERT_OK_AND_ASSIGN(
        auto round_tripped,
        local_client_->ShapedBufferToLiteral(device_buffer));
    EXPECT_EQ(expected, round_tripped);
  };

  // Arrays: scalars, a vector, and two matrices of different element types.
  check_round_trip(LiteralUtil::CreateR0<float>(42.0));
  check_round_trip(LiteralUtil::CreateR0<bool>(true));
  check_round_trip(LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
  check_round_trip(
      LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
  check_round_trip(LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}}));

  // The empty tuple (null shape).
  check_round_trip(LiteralUtil::MakeTuple({}));

  // Flat tuples with one and two elements.
  check_round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(12223.0)}));
  check_round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>({1.0, -42.0}),
       LiteralUtil::CreateR0<float>(123456.0)}));

  // A tuple that contains another tuple.
  check_round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR1<float>({1.0, -42.0}),
            LiteralUtil::CreateR0<float>(123456.0)}),
       LiteralUtil::CreateR0<bool>(false)}));
}
820
// Device round trip (Literal -> ShapedBuffer -> Literal) for 64-bit element
// types: f64, s64, u64 (including a value above 2^32), and a tuple that mixes
// f64 and s64.
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
  // A failed ASSERT inside this lambda returns from the lambda only.
  const auto check_round_trip = [this](const Literal& expected) {
    TF_ASSERT_OK_AND_ASSIGN(
        auto device_buffer,
        local_client_->LiteralToShapedBuffer(
            expected, local_client_->default_device_ordinal(), allocator_));
    TF_ASSERT_OK_AND_ASSIGN(
        auto round_tripped,
        local_client_->ShapedBufferToLiteral(device_buffer));
    EXPECT_EQ(expected, round_tripped);
  };

  check_round_trip(
      LiteralUtil::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
  check_round_trip(LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}}));
  check_round_trip(
      LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
  check_round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<double>({1.0, -42.0}),
       LiteralUtil::CreateR0<int64>(123456789000LL)}));
}
844
845 // Disabled on interpreter backend since infeed HLO is unsupported.
XLA_TEST_F(LocalClientExecuteTest,DISABLED_ON_INTERPRETER (InfeedTest))846 XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedTest)) {
847 XlaBuilder builder(TestName());
848 const Shape shape = ShapeUtil::MakeShape(F32, {3});
849 auto in = Infeed(&builder, shape);
850 auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
851 Add(in, constant);
852
853 Literal result;
854 std::unique_ptr<tensorflow::Thread> thread(
855 tensorflow::Env::Default()->StartThread(
856 tensorflow::ThreadOptions(), "execute_thread", [&] {
857 result = ShapedBufferToLiteral(ExecuteLocallyOrDie(
858 builder.Build().ValueOrDie(), /*arguments=*/{}));
859 }));
860
861 ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
862 LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
863 local_client_->default_device_ordinal()));
864
865 // Join the thread.
866 thread.reset();
867
868 LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
869 }
870
871 // Disabled on interpreter backend since infeed/outfeed HLOs are unsupported.
XLA_TEST_F(LocalClientExecuteTest,DISABLED_ON_INTERPRETER (InfeedOutfeedTest))872 XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) {
873 XlaBuilder builder(TestName());
874 const Shape shape = ShapeUtil::MakeShape(F32, {3});
875 auto in = Infeed(&builder, shape);
876 auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
877 auto sum = Add(in, constant);
878 Outfeed(sum, shape, /*outfeed_config=*/"");
879
880 std::unique_ptr<tensorflow::Thread> thread(
881 tensorflow::Env::Default()->StartThread(
882 tensorflow::ThreadOptions(), "execute_thread",
883 [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); }));
884
885 ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
886 LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
887 local_client_->default_device_ordinal()));
888
889 TF_ASSERT_OK_AND_ASSIGN(Literal result,
890 local_client_->TransferFromOutfeedLocal(
891 shape, local_client_->default_device_ordinal()));
892
893 LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
894 }
895
896 // Benchmark that measures the overhead of the LocalClient API when running a
897 // trivial computation
BM_LocalClientOverhead(int num_iters)898 void BM_LocalClientOverhead(int num_iters) {
899 tensorflow::testing::StopTiming();
900
901 se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
902 auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
903 StreamExecutorMemoryAllocator allocator(platform, executors);
904 LocalClient* client =
905 ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
906 auto* transfer_manager =
907 TransferManager::GetForPlatform(platform).ValueOrDie();
908 int device_ordinal = client->default_device_ordinal();
909
910 // Use a tiny add operation as the computation.
911 XlaBuilder builder("Add");
912 auto shape = ShapeUtil::MakeShape(F32, {2, 3});
913 auto x = Parameter(&builder, 0, shape, "x");
914 Add(x, x);
915 auto computation = builder.Build().ConsumeValueOrDie();
916
917 auto buffer =
918 transfer_manager
919 ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0)
920 .ConsumeValueOrDie();
921 auto literal = LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
922 auto stream =
923 client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
924 ASSERT_IS_OK(
925 transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer));
926
927 const int kWarmups = 2;
928
929 auto executable_status = client->Compile(
930 computation, {&buffer.on_host_shape()}, ExecutableBuildOptions());
931 ASSERT_IS_OK(executable_status);
932 std::unique_ptr<LocalExecutable> executable =
933 executable_status.ConsumeValueOrDie();
934
935 ExecutableRunOptions run_options;
936 run_options.set_allocator(&allocator).set_stream(stream.get());
937
938 for (int i = 0; i < kWarmups; ++i) {
939 auto result = executable->Run({&buffer}, run_options);
940 ASSERT_IS_OK(result);
941 }
942
943 tensorflow::testing::StartTiming();
944 for (int i = 0; i < num_iters; ++i) {
945 auto result = executable->Run({&buffer}, run_options);
946 ASSERT_IS_OK(result);
947 }
948 }
949
950 BENCHMARK(BM_LocalClientOverhead);
951
952 } // namespace
953 } // namespace xla
954