1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <initializer_list>
17 #include <memory>
18 #include <vector>
19
20 #include "tensorflow/compiler/xla/client/client_library.h"
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/compiler/xla/client/sharding_builder.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/local_service.h"
27 #include "tensorflow/compiler/xla/service/platform_util.h"
28 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
29 #include "tensorflow/compiler/xla/service/transfer_manager.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/test.h"
33 #include "tensorflow/compiler/xla/test_helpers.h"
34 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
35 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
36 #include "tensorflow/compiler/xla/tests/test_macros.h"
37 #include "tensorflow/compiler/xla/tests/test_utils.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39 #include "tensorflow/core/platform/env.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
42 #include "tensorflow/core/platform/test.h"
43 #include "tensorflow/core/platform/test_benchmark.h"
44 #include "tensorflow/stream_executor/device_memory_allocator.h"
45
46 namespace xla {
47 namespace {
48
49 using ::testing::ContainsRegex;
50
51 class LocalClientExecuteTest : public LocalClientTestBase {
52 protected:
53 ErrorSpec error_spec_{0.0001};
54 };
55
XLA_TEST_F(LocalClientExecuteTest, Constant) {
  // A parameterless computation that just produces the scalar 123.0.
  XlaBuilder b(TestName());
  ConstantR0<float>(&b, 123.0f);
  auto computation = b.Build().ConsumeValueOrDie();

  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {});
  Literal output_literal = ShapedBufferToLiteral(output);
  LiteralTestUtil::ExpectR0Near<float>(123.f, output_literal, error_spec_);
}
65
XLA_TEST_F(LocalClientExecuteTest, AddScalars) {
  // Scalar parameter plus scalar constant: 42 + 123 == 165.
  XlaBuilder b(TestName());
  XlaOp param = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
  XlaOp addend = ConstantR0<float>(&b, 123.0f);
  Add(param, addend);
  auto computation = b.Build().ConsumeValueOrDie();

  ScopedShapedBuffer param_buffer =
      LiteralToShapedBuffer(LiteralUtil::CreateR0<float>(42.0f));
  ScopedShapedBuffer sum = ExecuteLocallyOrDie(computation, {&param_buffer});
  Literal sum_literal = ShapedBufferToLiteral(sum);
  LiteralTestUtil::ExpectR0Near<float>(165.f, sum_literal, error_spec_);
}
78
XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) {
  // Adding two empty F32[0] vectors must succeed and produce an empty vector.
  XlaBuilder b(TestName());
  const Shape empty_vector_shape = ShapeUtil::MakeShape(F32, {0});
  XlaOp lhs = Parameter(&b, 0, empty_vector_shape, "x");
  XlaOp rhs = ConstantR1<float>(&b, {});
  Add(lhs, rhs);
  auto computation = b.Build().ConsumeValueOrDie();

  ScopedShapedBuffer empty_arg =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({}));
  ScopedShapedBuffer sum = ExecuteLocallyOrDie(computation, {&empty_arg});
  LiteralTestUtil::ExpectR1Near<float>({}, ShapedBufferToLiteral(sum),
                                       error_spec_);
}
91
XLA_TEST_F(LocalClientExecuteTest, AddVectors) {
  // Element-wise sum of a parameter vector and a constant vector:
  // {0, 1, 2} + {2, 3, 4} == {2, 4, 6}.
  XlaBuilder b(TestName());
  XlaOp lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  XlaOp rhs = ConstantR1<float>(&b, {2.0f, 3.0f, 4.0f});
  Add(lhs, rhs);
  auto computation = b.Build().ConsumeValueOrDie();

  ScopedShapedBuffer lhs_buffer =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
  ScopedShapedBuffer sum = ExecuteLocallyOrDie(computation, {&lhs_buffer});
  LiteralTestUtil::ExpectR1Near<float>({2.0f, 4.0f, 6.0f},
                                       ShapedBufferToLiteral(sum), error_spec_);
}
105
XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
  // Same computation as AddVectors, executed with an ExecutionProfile attached
  // to the run options; the profile must report a positive elapsed time.
  XlaBuilder b(TestName());
  XlaOp lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  XlaOp rhs = ConstantR1<float>(&b, {2.0f, 3.0f, 4.0f});
  Add(lhs, rhs);
  auto computation = b.Build().ConsumeValueOrDie();

  ScopedShapedBuffer lhs_buffer =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));

  ExecutionProfile profile;
  auto run_options = DefaultExecutableRunOptions();
  run_options.set_execution_profile(&profile);
  ScopedShapedBuffer sum =
      ExecuteLocallyOrDie(computation, {&lhs_buffer},
                          DefaultExecutableBuildOptions(), run_options);

  LiteralTestUtil::ExpectR1Near<float>({2.0f, 4.0f, 6.0f},
                                       ShapedBufferToLiteral(sum), error_spec_);
  EXPECT_GT(profile.compute_and_transfer_time_ns(), 0);
}
123
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
  // Adds two 2x2 arrays whose device buffers use different layouts (x is
  // col-major, y is row-major). The sum must be correct for both argument
  // orders.
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  XlaBuilder b(TestName());
  XlaOp lhs = Parameter(&b, 0, matrix_shape, "x");
  XlaOp rhs = Parameter(&b, 1, matrix_shape, "y");
  Add(lhs, rhs);
  auto computation = b.Build().ConsumeValueOrDie();

  const Layout col_major = LayoutUtil::MakeLayout({0, 1});
  const Layout row_major = LayoutUtil::MakeLayout({1, 0});

  // x lives on the device in col-major order.
  auto x_buffer = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
      {{1.0f, 2.0f}, {3.0f, 4.0f}}, col_major));
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      x_buffer.on_device_shape().layout(), col_major));

  // y lives on the device in row-major order.
  auto y_buffer = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
      {{10.0f, 20.0f}, {30.0f, 40.0f}}, row_major));
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      y_buffer.on_device_shape().layout(), row_major));

  ScopedShapedBuffer sum_xy =
      ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer});
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       ShapedBufferToLiteral(sum_xy),
                                       error_spec_);

  // Feed the same buffers in the opposite parameter order.
  ScopedShapedBuffer sum_yx =
      ExecuteLocallyOrDie(computation, {&y_buffer, &x_buffer});
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       ShapedBufferToLiteral(sum_yx),
                                       error_spec_);
}
156
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
  // Runs the same addition twice, first requesting a col-major and then a
  // row-major result layout. Each run checks the device layout and the values.
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  XlaBuilder b(TestName());
  XlaOp lhs = Parameter(&b, 0, matrix_shape, "x");
  XlaOp rhs = Parameter(&b, 1, matrix_shape, "y");
  Add(lhs, rhs);
  auto computation = b.Build().ConsumeValueOrDie();

  auto x_buffer = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto y_buffer = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  // Checks that `sum` has `expected_layout` on the device and holds x + y.
  auto check_sum = [&](const ScopedShapedBuffer& sum,
                       const Layout& expected_layout) {
    EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
        sum.on_device_shape().layout(), expected_layout));
    LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                         ShapedBufferToLiteral(sum),
                                         error_spec_);
  };

  // Col-major result layout.
  ScopedShapedBuffer col_major_sum = ExecuteLocallyOrDie(
      computation, {&x_buffer, &y_buffer},
      DefaultExecutableBuildOptions().set_result_layout(
          ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {0, 1})),
      DefaultExecutableRunOptions());
  check_sum(col_major_sum, LayoutUtil::MakeLayout({0, 1}));

  // Row-major result layout.
  ScopedShapedBuffer row_major_sum = ExecuteLocallyOrDie(
      computation, {&x_buffer, &y_buffer},
      DefaultExecutableBuildOptions().set_result_layout(
          ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {1, 0})),
      DefaultExecutableRunOptions());
  check_sum(row_major_sum, LayoutUtil::MakeLayout({1, 0}));
}
195
XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
  // Returns the flat tuple (x, y, x) built from two 2x2 parameters and checks
  // the result's shape and every element.
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  XlaBuilder b(TestName());
  XlaOp x = Parameter(&b, 0, matrix_shape, "x");
  XlaOp y = Parameter(&b, 1, matrix_shape, "y");
  Tuple(&b, {x, y, x});
  auto computation = b.Build().ConsumeValueOrDie();

  auto x_buffer = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto y_buffer = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  ScopedShapedBuffer output =
      ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer});

  const Shape& host_shape = output.on_host_shape();
  EXPECT_TRUE(host_shape.IsTuple());
  EXPECT_EQ(3, ShapeUtil::TupleElementCount(host_shape));

  Literal output_literal = ShapedBufferToLiteral(output);
  // Elements 0 and 2 are both x; element 1 is y.
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(output_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
                                        LiteralSlice(output_literal, {1}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(output_literal, {2}));
}
222
XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
  // Returns the nested tuple ((x, y, x), x) and checks every leaf by its
  // shape index.
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  XlaBuilder b(TestName());
  XlaOp x = Parameter(&b, 0, matrix_shape, "x");
  XlaOp y = Parameter(&b, 1, matrix_shape, "y");
  XlaOp xyx = Tuple(&b, {x, y, x});
  Tuple(&b, {xyx, x});
  auto computation = b.Build().ConsumeValueOrDie();

  auto x_buffer = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto y_buffer = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  ScopedShapedBuffer output =
      ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer});

  const Shape& host_shape = output.on_host_shape();
  EXPECT_TRUE(host_shape.IsTuple());
  EXPECT_EQ(2, ShapeUtil::TupleElementCount(host_shape));

  Literal output_literal = ShapedBufferToLiteral(output);
  // Outer element 1 is x.
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(output_literal, {1}));
  // Outer element 0 is the inner tuple (x, y, x).
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(output_literal, {0, 0}));
  LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
                                        LiteralSlice(output_literal, {0, 1}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(output_literal, {0, 2}));
}
252
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
  // Verify setting the result layout of a computation with a tuple output.
  // Element 0 is requested in col-major order ({0, 1}) and element 1 in
  // row-major order ({1, 0}).
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  Tuple(&builder, {x, y});

  auto array = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));

  ExecutableBuildOptions options = DefaultExecutableBuildOptions();
  Shape shape_with_layout = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2},
                                      /*minor_to_major=*/{0, 1}),
       ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2},
                                      /*minor_to_major=*/{1, 0})});
  options.set_result_layout(shape_with_layout);
  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array},
                          options, DefaultExecutableRunOptions());

  // The value comparisons below are layout-agnostic: a col-major result is
  // compared against a row-major expectation elsewhere in this file. On their
  // own they cannot show that the requested layout was applied, so check each
  // element's device layout explicitly.
  ASSERT_TRUE(result.on_device_shape().IsTuple());
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      ShapeUtil::GetTupleElementShape(result.on_device_shape(), 0).layout(),
      LayoutUtil::MakeLayout({0, 1})));
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      ShapeUtil::GetTupleElementShape(result.on_device_shape(), 1).layout(),
      LayoutUtil::MakeLayout({1, 0})));

  Literal result_literal = ShapedBufferToLiteral(result);
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {1}));
}
280
XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
  // Takes two tuple arguments whose elements appear in opposite order,
  //   x = (F32[2,2], F32[3]),  y = (F32[3], F32[2,2]),
  // and returns the tuple (x[0] + y[1], x[1] - y[0]).
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape vec3_shape = ShapeUtil::MakeShape(F32, {3});
  const Shape matrix_first =
      ShapeUtil::MakeTupleShape({matrix_shape, vec3_shape});
  const Shape vector_first =
      ShapeUtil::MakeTupleShape({vec3_shape, matrix_shape});

  XlaBuilder b(TestName());
  XlaOp x = Parameter(&b, 0, matrix_first, "x");
  XlaOp y = Parameter(&b, 1, vector_first, "y");
  XlaOp x_matrix = GetTupleElement(x, 0);
  XlaOp x_vector = GetTupleElement(x, 1);
  XlaOp y_vector = GetTupleElement(y, 0);
  XlaOp y_matrix = GetTupleElement(y, 1);
  // Braced-init-list elements are evaluated left to right: Add, then Sub.
  Tuple(&b, {Add(x_matrix, y_matrix), Sub(x_vector, y_vector)});
  auto computation = b.Build().ConsumeValueOrDie();

  auto x_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}),
       LiteralUtil::CreateR1<float>({42.0f, 75.0f, 123.0f})});
  auto y_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>({2.0f, 4.0f, 6.0f}),
       LiteralUtil::CreateR2<float>({{55.0f, 44.0f}, {33.0f, 22.0f}})});

  auto x_buffer = LiteralToShapedBuffer(x_literal);
  auto y_buffer = LiteralToShapedBuffer(y_literal);

  ScopedShapedBuffer output =
      ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer});

  const Shape& host_shape = output.on_host_shape();
  EXPECT_TRUE(host_shape.IsTuple());
  EXPECT_EQ(2, ShapeUtil::TupleElementCount(host_shape));

  Literal output_literal = ShapedBufferToLiteral(output);
  LiteralTestUtil::ExpectR2Equal<float>({{56.0f, 46.0f}, {36.0f, 26.0f}},
                                        LiteralSlice(output_literal, {0}));
  LiteralTestUtil::ExpectR1Equal<float>({40.0f, 71.0f, 117.0f},
                                        LiteralSlice(output_literal, {1}));
}
326
XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
  // The single parameter has shape ((F32[2,2], F32[3]), F32[3]). The
  // computation returns (-matrix, inner_vector + outer_vector).
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape vec3_shape = ShapeUtil::MakeShape(F32, {3});
  const Shape inner_shape =
      ShapeUtil::MakeTupleShape({matrix_shape, vec3_shape});
  const Shape outer_shape = ShapeUtil::MakeTupleShape({inner_shape, vec3_shape});

  XlaBuilder b(TestName());
  XlaOp param = Parameter(&b, 0, outer_shape, "param");
  XlaOp inner = GetTupleElement(param, 0);
  XlaOp matrix = GetTupleElement(inner, 0);
  XlaOp inner_vec = GetTupleElement(inner, 1);
  XlaOp outer_vec = GetTupleElement(param, 1);

  XlaOp negated = Neg(matrix);
  XlaOp vec_sum = Add(inner_vec, outer_vec);
  Tuple(&b, {negated, vec_sum});
  auto computation = b.Build().ConsumeValueOrDie();

  auto inner_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}),
       LiteralUtil::CreateR1<float>({42.0f, 75.0f, 123.0f})});
  auto arg_literal = LiteralUtil::MakeTupleFromSlices(
      {std::move(inner_literal),
       LiteralUtil::CreateR1<float>({222.0f, -2.0f, 10.0f})});
  auto arg_buffer = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {&arg_buffer});

  Literal output_literal = ShapedBufferToLiteral(output);
  LiteralTestUtil::ExpectR2Equal<float>({{-1.0f, -2.0f}, {-3.0f, -4.0f}},
                                        LiteralSlice(output_literal, {0}));
  LiteralTestUtil::ExpectR1Equal<float>({264.0f, 73.0f, 133.0f},
                                        LiteralSlice(output_literal, {1}));
}
365
XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
  // The computation maps (a, b) -> (-a, b + b), so its input and output shapes
  // match. Running it a second time on its own output checks that the returned
  // tuple buffer is well-formed enough to be consumed as an argument.
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({matrix_shape, matrix_shape});

  XlaBuilder b(TestName());
  XlaOp param = Parameter(&b, 0, pair_shape, "param");
  XlaOp first = GetTupleElement(param, 0);
  XlaOp second = GetTupleElement(param, 1);
  Tuple(&b, {Neg(first), Add(second, second)});
  auto computation = b.Build().ConsumeValueOrDie();

  auto arg_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}),
       LiteralUtil::CreateR2<float>({{11.0f, 3.0f}, {4.0f, 5.0f}})});
  auto arg_buffer = LiteralToShapedBuffer(arg_literal);

  // First pass: fed from a host-created buffer.
  ScopedShapedBuffer first_pass = ExecuteLocallyOrDie(computation, {&arg_buffer});
  Literal first_pass_literal = ShapedBufferToLiteral(first_pass);
  LiteralTestUtil::ExpectR2Equal<float>({{-1.0f, -2.0f}, {-3.0f, -4.0f}},
                                        LiteralSlice(first_pass_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{22.0f, 6.0f}, {8.0f, 10.0f}},
                                        LiteralSlice(first_pass_literal, {1}));

  // Second pass: fed directly from the first pass's result buffer.
  ScopedShapedBuffer second_pass =
      ExecuteLocallyOrDie(computation, {&first_pass});
  Literal second_pass_literal = ShapedBufferToLiteral(second_pass);
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(second_pass_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{44.0f, 12.0f}, {16.0f, 20.0f}},
                                        LiteralSlice(second_pass_literal, {1}));
}
401
XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
  // Runs a computation whose single parameter is a tuple with many elements.
  //
  // A larger number of elements would make for a better, more strenuous test,
  // but:
  // TODO(b/66959878): On cpu a large number of elements results in long
  // compilation time.
  // TODO(b/66954197): On gpu a large number of elements OOMs.
  const int kElementCount = 100;

  // The parameter is a tuple of kElementCount F32[2] vectors.
  const Shape vec2_shape = ShapeUtil::MakeShape(F32, {2});
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      std::vector<Shape>(kElementCount, vec2_shape));

  XlaBuilder b(TestName());
  XlaOp param = Parameter(&b, 0, tuple_shape, "param");

  // Output element i is input element i plus the scalar i.
  std::vector<XlaOp> outputs;
  outputs.reserve(kElementCount);
  for (int i = 0; i < kElementCount; ++i) {
    XlaOp input_i = GetTupleElement(param, i);
    outputs.push_back(Add(input_i, ConstantR0<float>(&b, i)));
  }
  Tuple(&b, outputs);
  auto computation = b.Build().ConsumeValueOrDie();

  // Argument element i is {i, -i}, so output element i should be {2i, 0}.
  std::vector<Literal> arg_elements;
  arg_elements.reserve(kElementCount);
  for (int i = 0; i < kElementCount; ++i) {
    const float value = static_cast<float>(i);
    arg_elements.push_back(LiteralUtil::CreateR1<float>({value, -value}));
  }
  Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_elements));
  auto arg_buffer = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {&arg_buffer});
  Literal output_literal = ShapedBufferToLiteral(output);

  for (int i = 0; i < kElementCount; ++i) {
    LiteralTestUtil::ExpectR1Near<float>(
        {2.0f * i, 0.0f}, LiteralSlice(output_literal, {i}), error_spec_);
  }
}
447
XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
  // Runs a computation whose single parameter is a two-level nested tuple with
  // a large fanout at both levels (kFanout * kFanout scalar leaves).
  const int kFanout = 40;

  // Full two-level tree: kFanout inner tuples, each holding kFanout scalars.
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  const Shape inner_tuple_shape =
      ShapeUtil::MakeTupleShape(std::vector<Shape>(kFanout, scalar_shape));
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape(std::vector<Shape>(kFanout, inner_tuple_shape));

  XlaBuilder b(TestName());
  XlaOp param = Parameter(&b, 0, tuple_shape, "param");

  // Leaf (i, j) is incremented by its ordinal i * kFanout + j in a traversal
  // of the tuple.
  std::vector<XlaOp> outer_results;
  outer_results.reserve(kFanout);
  for (int i = 0; i < kFanout; ++i) {
    XlaOp inner_param = GetTupleElement(param, i);
    std::vector<XlaOp> inner_results;
    inner_results.reserve(kFanout);
    for (int j = 0; j < kFanout; ++j) {
      const int ordinal = i * kFanout + j;
      XlaOp leaf = GetTupleElement(inner_param, j);
      inner_results.push_back(Add(leaf, ConstantR0<float>(&b, ordinal)));
    }
    outer_results.push_back(Tuple(&b, inner_results));
  }
  Tuple(&b, outer_results);
  auto computation = b.Build().ConsumeValueOrDie();

  // Argument leaf (i, j) holds i + j.
  std::vector<Literal> outer_literals;
  outer_literals.reserve(kFanout);
  for (int i = 0; i < kFanout; ++i) {
    std::vector<Literal> inner_literals;
    inner_literals.reserve(kFanout);
    for (int j = 0; j < kFanout; ++j) {
      inner_literals.push_back(LiteralUtil::CreateR0<float>(i + j));
    }
    outer_literals.push_back(
        LiteralUtil::MakeTupleOwned(std::move(inner_literals)));
  }
  auto arg_literal = LiteralUtil::MakeTupleOwned(std::move(outer_literals));
  auto arg_buffer = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {&arg_buffer});
  Literal output_literal = ShapedBufferToLiteral(output);

  for (int i = 0; i < kFanout; ++i) {
    for (int j = 0; j < kFanout; ++j) {
      const int input = i + j;
      const int ordinal = i * kFanout + j;
      LiteralTestUtil::ExpectR0Near<float>(input + ordinal,
                                           LiteralSlice(output_literal, {i, j}),
                                           error_spec_);
    }
  }
}
504
XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
  // Construct and run a computation which takes a very deep tuple. The tuple
  // has no fan out and a single scalar element at the bottom.
  const int kTupleDepth = 100;

  // Tuple shape is a chain of kTupleDepth single-element tuples wrapped around
  // an F32 scalar, i.e. (((...(f32)...))).
  Shape shape = ShapeUtil::MakeShape(F32, {});
  for (int i = 0; i < kTupleDepth; ++i) {
    shape = ShapeUtil::MakeTupleShape({shape});
  }

  // Unwrap the parameter down to its scalar leaf.
  XlaBuilder builder(TestName());
  auto element = Parameter(&builder, 0, shape, "param");
  for (int i = 0; i < kTupleDepth; ++i) {
    element = GetTupleElement(element, 0);
  }

  // Add 42 to the leaf, then re-wrap the sum to the same tuple depth.
  auto output = Add(element, ConstantR0<float>(&builder, 42.0));
  for (int i = 0; i < kTupleDepth; ++i) {
    output = Tuple(&builder, {output});
  }
  auto computation = builder.Build().ConsumeValueOrDie();

  // Construct the argument to pass to the computation: the same nested
  // structure with 123 at the leaf.
  Literal arg_literal = LiteralUtil::CreateR0<float>(123.0);
  for (int i = 0; i < kTupleDepth; ++i) {
    std::vector<Literal> arg_vector;
    arg_vector.push_back(std::move(arg_literal));
    arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector));
  }
  auto arg_buffer = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
  Literal result_literal = ShapedBufferToLiteral(result);

  // The result leaf is reached by following index 0 kTupleDepth times; it
  // should hold 123 + 42.
  ShapeIndex index;
  for (int i = 0; i < kTupleDepth; ++i) {
    index.push_back(0);
  }
  LiteralTestUtil::ExpectR0Equal<float>(165.0,
                                        LiteralSlice(result_literal, index));
}
547
XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
  // Test passing in an invalid number of arguments: the computation declares
  // two parameters but is executed with only one.
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {3}), "y");
  Add(x, y);

  auto x_array =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
  auto execute_status =
      ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});

  EXPECT_FALSE(execute_status.ok());
  // Stream the full status so an unexpected error message is visible in the
  // failure output, as the other error-path tests in this file do.
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("Invalid number of arguments"))
      << execute_status.status();
}
564
XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) {
  // The computation expects an F32[3] argument but is handed a 2x2 array;
  // execution must be rejected with an argument-shape error.
  XlaBuilder b(TestName());
  const Shape expected_shape = ShapeUtil::MakeShape(F32, {3});
  XlaOp param = Parameter(&b, 0, expected_shape, "x");
  Neg(param);

  auto mismatched_arg = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
  auto status_or_result =
      ExecuteLocally(b.Build().ValueOrDie(), {&mismatched_arg});

  EXPECT_FALSE(status_or_result.ok());
  EXPECT_THAT(status_or_result.status().error_message(),
              ContainsRegex("Invalid argument shape"))
      << status_or_result.status();
}
581
XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) {
  // Requests a rank-4 result layout for a computation whose result is a 2x2
  // array; compilation must reject the incompatible layout.
  XlaBuilder b(TestName());
  XlaOp param = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  Neg(param);

  auto arg = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));

  const Shape incompatible_result =
      ShapeUtil::MakeShapeWithLayout(F32,
                                     /*dimensions=*/{1, 2, 3, 4},
                                     /*minor_to_major=*/{0, 1, 2, 3});
  auto build_options = DefaultExecutableBuildOptions();
  build_options.set_result_layout(incompatible_result);

  auto status_or_result = ExecuteLocally(b.Build().ValueOrDie(), {&arg},
                                         build_options,
                                         DefaultExecutableRunOptions());

  EXPECT_FALSE(status_or_result.ok());
  EXPECT_THAT(status_or_result.status().error_message(),
              ContainsRegex("not compatible with result shape"))
      << status_or_result.status();
}
603
XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) {
  // Try to run a trivial computation on every device on the system. If a
  // specific device is not supported, check that the right error is returned.
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);
  auto computation = builder.Build().ConsumeValueOrDie();
  for (int d = 0; d < local_client_->device_count(); ++d) {
    if (!local_client_->device_ordinal_supported(d)) {
      auto execute_status =
          ExecuteLocally(computation, {},
                         DefaultExecutableBuildOptions().set_device_ordinal(d),
                         DefaultExecutableRunOptions().set_device_ordinal(d));
      // Report which ordinal failed and the actual status, since this check
      // runs once per device.
      EXPECT_FALSE(execute_status.ok()) << "device ordinal " << d;
      EXPECT_THAT(execute_status.status().error_message(),
                  ContainsRegex("device .* not supported"))
          << "device ordinal " << d << ": " << execute_status.status();
    } else {
      auto result = ExecuteLocallyOrDie(
          computation, {},
          DefaultExecutableBuildOptions().set_device_ordinal(d),
          DefaultExecutableRunOptions().set_device_ordinal(d));
      EXPECT_EQ(d, result.device_ordinal());
      LiteralTestUtil::ExpectR0Equal<float>(42.0f,
                                            ShapedBufferToLiteral(result));
    }
  }
}
630
XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) {
  // Try running computations on devices with device ordinal values which do not
  // exist. Ordinals run from 0 to device_count() - 1, so device_count() itself
  // is the first invalid value.
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);
  auto computation = builder.Build().ConsumeValueOrDie();

  const int invalid_ordinal = local_client_->device_count();
  auto execute_status = ExecuteLocally(
      computation, {},
      DefaultExecutableBuildOptions().set_device_ordinal(invalid_ordinal),
      DefaultExecutableRunOptions().set_device_ordinal(invalid_ordinal));
  EXPECT_FALSE(execute_status.ok());
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("Invalid device ordinal value"))
      << execute_status.status();
}
648
XLA_TEST_F(LocalClientExecuteTest, RunOnStream) {
  // Run a computation on a specific stream on each device on the system.
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);
  auto computation = builder.Build().ConsumeValueOrDie();

  for (int d = 0; d < local_client_->device_count(); ++d) {
    // Skip ordinals this client cannot run on.
    if (!local_client_->device_ordinal_supported(d)) {
      continue;
    }
    se::StreamExecutor* executor =
        local_client_->platform()->ExecutorForDevice(d).ValueOrDie();
    se::Stream stream(executor);
    stream.Init();

    // Only the stream is set in the run options (no explicit device ordinal),
    // so the device is expected to be taken from the stream.
    auto result =
        ExecuteLocallyOrDie(computation, {}, DefaultExecutableBuildOptions(),
                            DefaultExecutableRunOptions().set_stream(&stream));
    // Check that the computation ran on the device associated with the
    // stream. This is a weak check, but stronger verification is hard.
    EXPECT_EQ(d, result.device_ordinal());
    LiteralTestUtil::ExpectR0Equal<float>(42.0f, ShapedBufferToLiteral(result));
  }
}
673
674 // Disable this test on CPU because we're using the CPU as the platform
675 // which does not match the service platform.
XLA_TEST_F(LocalClientExecuteTest,DISABLED_ON_CPU (RunOnStreamForWrongPlatform))676 XLA_TEST_F(LocalClientExecuteTest,
677 DISABLED_ON_CPU(RunOnStreamForWrongPlatform)) {
678 // Try to run a computation on a stream for a platform (CPU) which does not
679 // match the platform of the service (!= CPU).
680 se::Platform* wrong_platform =
681 se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)
682 .ValueOrDie();
683 se::Stream wrong_stream(wrong_platform->ExecutorForDevice(0).ValueOrDie());
684 wrong_stream.Init();
685
686 XlaBuilder builder(TestName());
687 ConstantR0<float>(&builder, 42.0f);
688 auto execute_status = ExecuteLocally(
689 builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(),
690 DefaultExecutableRunOptions().set_stream(&wrong_stream));
691 EXPECT_FALSE(execute_status.ok());
692 EXPECT_THAT(execute_status.status().error_message(),
693 ContainsRegex("stream is for platform .*, but service targets"));
694 }
695
XLA_TEST_F(LocalClientExecuteTest,
           DISABLED_ON_CPU(AllocatorDoesNotMatchPlatform)) {
  // Execute with a device-memory allocator bound to the host (CPU) platform.
  // Off CPU this differs from the service platform, so execution must be
  // rejected with an allocator/platform mismatch error.
  se::Platform* wrong_platform =
      se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)
          .ValueOrDie();
  TestAllocator allocator(wrong_platform);

  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 123.0f);

  auto execute_status = ExecuteLocally(
      builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(),
      DefaultExecutableRunOptions().set_allocator(&allocator));
  EXPECT_FALSE(execute_status.ok());
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("allocator platform .* does not match service"))
      << execute_status.status();
}
713
XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) {
  // Try to run a computation on a stream that has not been initialized.
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);

  LOG(INFO) << "default device = " << local_client_->default_device_ordinal();
  se::StreamExecutor* executor =
      local_client_->platform()
          ->ExecutorForDevice(local_client_->default_device_ordinal())
          .ValueOrDie();
  se::Stream stream(executor);
  // Don't call stream.Init().

  auto execute_status = ExecuteLocally(
      builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(),
      DefaultExecutableRunOptions().set_stream(&stream));
  EXPECT_FALSE(execute_status.ok());
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("stream is uninitialized or in an error state"))
      << execute_status.status();
}
734
XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(SelectBetweenTuples)) {
  // Select with a constant `false` predicate over two tuple operands; the
  // result must equal the second operand, element by element.
  XlaBuilder builder(TestName());

  std::initializer_list<float> ascending = {1.f, 2.f, 3.f};
  std::initializer_list<float> doubled = {2.f, 4.f, 6.f};

  // Ops are created in the same order as the tuple literals are listed.
  XlaOp first_a = ConstantR1<float>(&builder, ascending);
  XlaOp first_b = ConstantR1<float>(&builder, doubled);
  XlaOp ascending_then_doubled = Tuple(&builder, {first_a, first_b});
  XlaOp second_b = ConstantR1<float>(&builder, doubled);
  XlaOp second_a = ConstantR1<float>(&builder, ascending);
  XlaOp doubled_then_ascending = Tuple(&builder, {second_b, second_a});

  XlaOp predicate = ConstantR0<bool>(&builder, false);
  Select(predicate, ascending_then_doubled, doubled_then_ascending);

  ScopedShapedBuffer device_result =
      ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
  Literal host_result = ShapedBufferToLiteral(device_result);

  LiteralTestUtil::ExpectR1Equal<float>({2.0f, 4.0f, 6.0f},
                                        LiteralSlice(host_result, {0}));
  LiteralTestUtil::ExpectR1Equal<float>({1.0f, 2.0f, 3.0f},
                                        LiteralSlice(host_result, {1}));
}
754
XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
  // Compiles x + {2, 3, 4} via LocalClient::Compile with an explicit
  // F32[3]{0} argument layout, runs the single executable on {0, 1, 2}, and
  // checks the result read back to the host.
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  Add(x, y);

  Shape argument_layout =
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0});
  TF_ASSERT_OK_AND_ASSIGN(
      auto executables,
      local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
                             ExecutableBuildOptions()));
  // Fatal check: executables[0] below would be out of bounds if Compile
  // returned an empty vector.
  ASSERT_EQ(1, executables.size());

  auto x_array =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
  ScopedShapedBuffer result =
      executables[0]
          ->Run({&x_array}, DefaultExecutableRunOptions())
          .ConsumeValueOrDie();
  // NOTE(review): this blocks on a stream freshly borrowed for device 0;
  // presumably meant to wait for outstanding device work before the
  // read-back. Confirm it actually synchronizes with the stream Run used.
  ASSERT_IS_OK(local_client_->mutable_backend()
                   ->BorrowStream(0)
                   .ValueOrDie()
                   ->BlockHostUntilDone());

  LiteralTestUtil::ExpectR1Near<float>(
      {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
783
XLA_TEST_F(LocalClientExecuteTest, CompilePartitionedExecutable) {
  // Builds (x + y) + z with the second Add assigned to device 1, compiles it
  // with two partitions, and expects Compile to return two executables.
  if (local_client_->device_count() < 2) {
    // GTEST_SKIP() is gtest's public skip macro; GTEST_SKIP_ is internal.
    GTEST_SKIP() << "requires two devices";
  }

  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  auto z = ConstantR1<float>(&builder, {5.0f, 6.0f, 7.0f});
  auto r = Add(x, y);
  // Only the op created while this sharding is active gets device 1.
  builder.SetSharding(sharding_builder::AssignDevice(1));
  Add(r, z);
  builder.ClearSharding();

  Shape argument_layout =
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0});
  ExecutableBuildOptions build_options;
  build_options.set_num_partitions(2);
  TF_ASSERT_OK_AND_ASSIGN(
      auto executables,
      local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
                             build_options));
  EXPECT_EQ(2, executables.size());
}
808
XLA_TEST_F(LocalClientExecuteTest,
           DISABLED_ON_INTERPRETER(SizeOfGeneratedCodeInBytes)) {
  // Embeds a 100000-element random F32 constant in a computation and checks
  // that the executable's reported generated-code size is larger than the
  // constant's raw byte size.
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
  constexpr int size = 100000;
  TF_ASSERT_OK_AND_ASSIGN(auto literal,
                          LiteralUtil::CreateRandomLiteral<F32>(
                              ShapeUtil::MakeShape(F32, {size}), 0.0, 1.0));
  auto y = ConstantLiteral(&builder, literal);
  Add(x, y);

  Shape argument_layout =
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{}, {});
  TF_ASSERT_OK_AND_ASSIGN(
      auto executables,
      local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
                             ExecutableBuildOptions()));
  // Fatal check: executables.front() below is undefined on an empty vector.
  ASSERT_EQ(1, executables.size());
  // The executable should be at least as large as the constant it contains.
  EXPECT_GT(executables.front()->executable()->SizeOfGeneratedCodeInBytes(),
            int64{sizeof(float) * size});
}
831
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
  // Each Literal below is copied to the device as a ShapedBuffer and copied
  // back again; the returned Literal must compare equal to the original.
  auto round_trip = [this](const Literal& original) {
    const int device_ordinal = local_client_->default_device_ordinal();
    TF_ASSERT_OK_AND_ASSIGN(auto on_device,
                            local_client_->LiteralToShapedBuffer(
                                original, device_ordinal, allocator_));
    TF_ASSERT_OK_AND_ASSIGN(auto on_host,
                            local_client_->ShapedBufferToLiteral(on_device));
    EXPECT_EQ(original, on_host);
  };

  // Arrays of several ranks and element types.
  round_trip(LiteralUtil::CreateR0<float>(42.0));
  round_trip(LiteralUtil::CreateR0<bool>(true));
  round_trip(LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
  round_trip(LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
  round_trip(LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}}));

  // The empty tuple (null shape).
  round_trip(LiteralUtil::MakeTuple({}));

  // Flat tuples with one and two elements.
  round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(12223.0)}));
  round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>({1.0, -42.0}),
       LiteralUtil::CreateR0<float>(123456.0)}));

  // A tuple whose first element is itself a tuple.
  Literal inner_tuple = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>({1.0, -42.0}),
       LiteralUtil::CreateR0<float>(123456.0)});
  round_trip(LiteralUtil::MakeTupleFromSlices(
      {inner_tuple, LiteralUtil::CreateR0<bool>(false)}));
}
871
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
  // Same device round trip as the 32-bit variant, but for 64-bit element
  // types (double, int64, uint64) and a tuple mixing them.
  auto round_trip = [this](const Literal& original) {
    const int device_ordinal = local_client_->default_device_ordinal();
    TF_ASSERT_OK_AND_ASSIGN(auto on_device,
                            local_client_->LiteralToShapedBuffer(
                                original, device_ordinal, allocator_));
    TF_ASSERT_OK_AND_ASSIGN(auto on_host,
                            local_client_->ShapedBufferToLiteral(on_device));
    EXPECT_EQ(original, on_host);
  };

  round_trip(LiteralUtil::CreateR2<double>(
      {{1.0, 2.0, 3.0}, {44.0, 0.099999999999999978, -3}}));
  round_trip(LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}}));
  // The first value exceeds the 32-bit range.
  round_trip(LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
  round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<double>({1.0, -42.0}),
       LiteralUtil::CreateR0<int64>(123456789000LL)}));
}
895
896 // Disabled on interpreter backend since infeed HLO is unsupported.
XLA_TEST_F(LocalClientExecuteTest,DISABLED_ON_INTERPRETER (InfeedTest))897 XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedTest)) {
898 XlaBuilder builder(TestName());
899 const Shape shape = ShapeUtil::MakeShape(F32, {3});
900 auto in = Infeed(&builder, shape);
901 auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
902 Add(in, constant);
903
904 Literal result;
905 std::unique_ptr<tensorflow::Thread> thread(
906 tensorflow::Env::Default()->StartThread(
907 tensorflow::ThreadOptions(), "execute_thread", [&] {
908 result = ShapedBufferToLiteral(ExecuteLocallyOrDie(
909 builder.Build().ValueOrDie(), /*arguments=*/{}));
910 }));
911
912 ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
913 LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
914 local_client_->default_device_ordinal()));
915
916 // Join the thread.
917 thread.reset();
918
919 LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
920 }
921
922 // Disabled on interpreter backend since infeed/outfeed HLOs are unsupported.
XLA_TEST_F(LocalClientExecuteTest,DISABLED_ON_INTERPRETER (InfeedOutfeedTest))923 XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) {
924 XlaBuilder builder(TestName());
925 const Shape shape = ShapeUtil::MakeShape(F32, {3});
926 auto in = Infeed(&builder, shape);
927 auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
928 auto sum = Add(in, constant);
929 Outfeed(sum, shape, /*outfeed_config=*/"");
930
931 std::unique_ptr<tensorflow::Thread> thread(
932 tensorflow::Env::Default()->StartThread(
933 tensorflow::ThreadOptions(), "execute_thread",
934 [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); }));
935
936 ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
937 LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
938 local_client_->default_device_ordinal()));
939
940 Literal result(shape);
941 ASSERT_IS_OK(local_client_->TransferFromOutfeedLocal(
942 local_client_->default_device_ordinal(), &result));
943
944 LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
945 }
946
947 // Benchmark that measures the overhead of the LocalClient API when running a
948 // trivial computation
BM_LocalClientOverhead(::testing::benchmark::State & state)949 void BM_LocalClientOverhead(::testing::benchmark::State& state) {
950 se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
951 auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
952 se::StreamExecutorMemoryAllocator allocator(platform, executors);
953 LocalClient* client =
954 ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
955 auto* transfer_manager =
956 TransferManager::GetForPlatform(platform).ValueOrDie();
957 int device_ordinal = client->default_device_ordinal();
958
959 // Use a tiny add operation as the computation.
960 XlaBuilder builder("Add");
961 auto shape = ShapeUtil::MakeShape(F32, {2, 3});
962 auto x = Parameter(&builder, 0, shape, "x");
963 Add(x, x);
964 auto computation = builder.Build().ConsumeValueOrDie();
965
966 auto buffer =
967 transfer_manager
968 ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0)
969 .ConsumeValueOrDie();
970 auto literal = LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
971 auto stream =
972 client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
973 ASSERT_IS_OK(
974 transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer));
975
976 const int kWarmups = 2;
977
978 TF_ASSERT_OK_AND_ASSIGN(
979 auto executables, client->Compile(computation, {&buffer.on_host_shape()},
980 ExecutableBuildOptions()));
981 std::unique_ptr<LocalExecutable> executable = std::move(executables[0]);
982
983 ExecutableRunOptions run_options;
984 run_options.set_allocator(&allocator).set_stream(stream.get());
985
986 for (int i = 0; i < kWarmups; ++i) {
987 auto result = executable->Run({&buffer}, run_options);
988 ASSERT_IS_OK(result);
989 }
990
991 for (auto s : state) {
992 auto result = executable->Run({&buffer}, run_options);
993 ASSERT_IS_OK(result);
994 }
995 }
996
// Registers BM_LocalClientOverhead with the benchmark framework.
BENCHMARK(BM_LocalClientOverhead);
998
999 } // namespace
1000 } // namespace xla
1001