1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
19
20 #include "tensorflow/compiler/xla/layout_util.h"
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
23 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
24 #include "tensorflow/compiler/xla/service/stream_pool.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
28 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
29 #include "tensorflow/compiler/xla/tests/test_macros.h"
30 #include "tensorflow/compiler/xla/types.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
34 #include "tensorflow/core/platform/test_benchmark.h"
35 #include "tensorflow/core/platform/types.h"
36 #include "tensorflow/stream_executor/device_memory_allocator.h"
37
38 namespace xla {
39 namespace {
40
41 class TransferManagerTest : public LocalClientTestBase {
42 protected:
TransferManagerTest()43 TransferManagerTest()
44 : shape_size_fn_([this](const Shape& shape) {
45 return transfer_manager_->GetByteSizeRequirement(shape);
46 }) {
47 stream_ptr_ = local_client_->mutable_backend()
48 ->BorrowStream(stream_executor_)
49 .ValueOrDie();
50 stream_ = stream_ptr_.get();
51 }
52
53 ~TransferManagerTest() override = default;
54
AllocateDeviceBuffer(const Shape & shape)55 ScopedShapedBuffer AllocateDeviceBuffer(const Shape& shape) {
56 return transfer_manager_
57 ->AllocateScopedShapedBuffer(
58 shape, GetOrCreateAllocator(local_client_->platform()),
59 /*device_ordinal=*/0)
60 .ValueOrDie();
61 }
62
63 protected:
64 StreamPool::Ptr stream_ptr_;
65 se::Stream* stream_;
66
67 private:
68 std::function<int64(const Shape&)> shape_size_fn_;
69 };
70
// Round-trips a scalar uint32 through the device and checks its value.
XLA_TEST_F(TransferManagerTest, TransferR0U32) {
  constexpr uint32 kValue = 42;
  Literal host_literal = LiteralUtil::CreateR0<uint32>(kValue);
  auto on_device = AllocateDeviceBuffer(host_literal.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
      stream_, host_literal, on_device));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, on_device));

  LiteralTestUtil::ExpectR0Equal<uint32>(kValue, round_tripped);
}
85
// Round-trips a small rank-1 float array through the device.
XLA_TEST_F(TransferManagerTest, TransferR1F32) {
  const std::vector<float> values = {1.25f, 2.5f, -17.0f, -20.125f};
  Literal host_literal = LiteralUtil::CreateR1<float>(values);
  auto on_device = AllocateDeviceBuffer(host_literal.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
      stream_, host_literal, on_device));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, on_device));

  LiteralTestUtil::ExpectR1Equal<float>(values, round_tripped);
}
102
// Round-trips rank-1 float arrays of every length from 0 up to (but not
// including) kMaxR1Size, to flush out bugs tied to "awkward" element counts.
XLA_TEST_F(TransferManagerTest, TransferR1F32AwkwardSizes) {
  constexpr int kMaxR1Size = (1 << 11);
  for (int length = 0; length != kMaxR1Size; ++length) {
    // Element k holds the value k.
    std::vector<float> values(length);
    std::iota(values.begin(), values.end(), 0);

    Literal host_literal = LiteralUtil::CreateR1<float>(values);
    auto on_device = AllocateDeviceBuffer(host_literal.shape());

    // Host -> device, then device -> host.
    ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
        stream_, host_literal, on_device));
    TF_ASSERT_OK_AND_ASSIGN(
        Literal round_tripped,
        transfer_manager_->TransferLiteralFromDevice(stream_, on_device));

    LiteralTestUtil::ExpectR1Equal<float>(values, round_tripped);
  }
}
124
// Round-trips a rank-1 float array with 1024 * 1024 elements.
XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) {
  // Element k holds the value k.
  std::vector<float> values(1024 * 1024);
  std::iota(values.begin(), values.end(), 0);

  Literal host_literal = LiteralUtil::CreateR1<float>(values);
  auto on_device = AllocateDeviceBuffer(host_literal.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
      stream_, host_literal, on_device));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, on_device));

  LiteralTestUtil::ExpectR1Equal<float>(values, round_tripped);
}
141
// Round-trips 1024 floats whose host storage starts one element into a
// larger vector, so the source pointer is offset from the start of the
// allocation.
XLA_TEST_F(TransferManagerTest, TransferR1LargeUnalignedF32) {
  // Backing storage holds 0..1024; the literal views elements 1..1024.
  std::vector<float> backing(1025);
  std::iota(backing.begin(), backing.end(), 0);

  Shape shape = ShapeUtil::MakeShape(F32, {1024});
  const float* first_viewed = backing.data() + 1;
  BorrowingLiteral host_literal(reinterpret_cast<const char*>(first_viewed),
                                shape);
  auto on_device = AllocateDeviceBuffer(shape);

  // Host -> device, then device -> host.
  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
      stream_, host_literal, on_device));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, on_device));

  // The values read back are 1..1024.
  std::vector<float> expected(1024);
  std::iota(expected.begin(), expected.end(), 1);
  LiteralTestUtil::ExpectR1Equal<float>(expected, round_tripped);
}
161
// Round-trips a rank-1 U8 array built from a 16-character string and checks
// that the same string is read back.
XLA_TEST_F(TransferManagerTest, TransferR1U8) {
  const char* const payload = "0123456789abcdef";
  Literal host_literal = LiteralUtil::CreateR1U8(payload);
  auto on_device = AllocateDeviceBuffer(host_literal.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
      stream_, host_literal, on_device));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, on_device));

  EXPECT_EQ(round_tripped.GetR1U8AsString(), payload);
}
177
// Round-trips a 2x3 float matrix through the device.
XLA_TEST_F(TransferManagerTest, TransferR2F32) {
  Literal host_literal =
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
  auto on_device = AllocateDeviceBuffer(host_literal.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
      stream_, host_literal, on_device));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, on_device));

  LiteralTestUtil::ExpectR2Equal<float>(
      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, round_tripped);
}
194
// Round-trips a 2x3 float matrix where the host literal uses layout {0, 1}
// but the device buffer is allocated with layout {1, 0}. The literal read
// back must hold the same values while carrying a layout different from the
// original host literal.
XLA_TEST_F(TransferManagerTest,
           TransferR2F32AndChangeLayoutTransferringToDevice) {
  Literal host_literal = LiteralUtil::CreateR2WithLayout<float>(
      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1}));

  // Deliberately pick an on-device layout that differs from the host one.
  const Shape device_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
  auto on_device = AllocateDeviceBuffer(device_shape);

  // Host -> device, then device -> host.
  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
      stream_, host_literal, on_device));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, on_device));

  EXPECT_FALSE(LayoutUtil::Equal(round_tripped.shape().layout(),
                                 host_literal.shape().layout()));
  LiteralTestUtil::ExpectR2Equal<float>(
      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, round_tripped);
}
216
// Round-trips a flat tuple holding a scalar, a 2x2 matrix and a 3-element
// vector, all F32.
XLA_TEST_F(TransferManagerTest, TransferTuple) {
  Literal host_tuple = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(123.0f),
       LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
       LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})});
  auto on_device = AllocateDeviceBuffer(host_tuple.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, host_tuple,
                                                          on_device));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, on_device));

  EXPECT_TRUE(LiteralTestUtil::Equal(host_tuple, round_tripped));
}
233
// Round-trips a tuple with no elements.
XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
  Literal host_tuple = LiteralUtil::MakeTuple({});
  auto on_device = AllocateDeviceBuffer(host_tuple.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, host_tuple,
                                                          on_device));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, on_device));

  EXPECT_TRUE(LiteralTestUtil::Equal(host_tuple, round_tripped));
}
247
// Round-trips a tuple whose middle element is itself a tuple.
XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
  Literal host_tuple = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(123.0f),
       LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
            LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})}),
       LiteralUtil::CreateR1<float>({-10.0f, 123.0f})});
  auto on_device = AllocateDeviceBuffer(host_tuple.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, host_tuple,
                                                          on_device));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, on_device));

  EXPECT_TRUE(LiteralTestUtil::Equal(host_tuple, round_tripped));
}
266
// Round-trips a two-element rank-1 complex64 array.
XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
  Literal host_literal = LiteralUtil::CreateR1<complex64>(
      {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)});
  auto on_device = AllocateDeviceBuffer(host_literal.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
      stream_, host_literal, on_device));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, on_device));

  EXPECT_TRUE(LiteralTestUtil::Equal(host_literal, round_tripped));
}
281
// Round-trips a tuple mixing complex64 arrays with an int32 array: a rank-1
// complex64, a rank-1 int32 and a scalar complex64.
XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
  Literal host_tuple = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<complex64>(
           {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}),
       LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6}),
       LiteralUtil::CreateR0<complex64>(complex64(0.3f, -0.4f))});
  auto on_device = AllocateDeviceBuffer(host_tuple.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, host_tuple,
                                                          on_device));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, on_device));

  EXPECT_TRUE(LiteralTestUtil::Equal(host_tuple, round_tripped));
}
299
// Reads a token-shaped buffer back from the device. A token has no physical
// representation, so nothing is really copied; the call just has to succeed
// and yield a token literal.
// TODO(b/110532604): Add transferring the token to device when this is
// supported.
XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) {
  auto on_device = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape());
  TF_ASSERT_OK_AND_ASSIGN(
      Literal from_device,
      transfer_manager_->TransferLiteralFromDevice(stream_, on_device));
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateToken(), from_device));
}
311
// Soak test: repeatedly round-trips two different nested-tuple literals
// through the device, one on the fixture's stream and one on a sub-stream of
// it, then checks that the literals read back on the last iteration match
// what was written.
XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) {
  const int64_t kIterationCount = 5000;
  Literal literal1 = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(123.0f),
       LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
            LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})}),
       LiteralUtil::CreateR1<float>({-10.0f, 123.0f})});
  Literal literal2 = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(456.0f),
       LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR2<float>({{5.0f, 7.0f}, {9.0f, 4.0f}}),
            LiteralUtil::CreateR1<float>({44.0f, -11.0f, 3333333.3f})}),
       LiteralUtil::CreateR1<float>({-98.0f, 153.0f})});

  auto device_buffer1 = AllocateDeviceBuffer(literal1.shape());
  auto device_buffer2 = AllocateDeviceBuffer(literal2.shape());

  se::Stream* stream1 = stream_;

  // The sub-stream must be handed back to its parent once we are done with
  // it. The ASSERT macros below return from the test body on failure, so the
  // return is tied to scope exit instead of being a trailing statement.
  auto return_sub_stream = [this](se::Stream* sub_stream) {
    stream_->ReturnSubStream(sub_stream);
  };
  std::unique_ptr<se::Stream, decltype(return_sub_stream)> stream2_owner(
      stream_->GetOrCreateSubStream(), return_sub_stream);
  se::Stream* stream2 = stream2_owner.get();

  Literal result1;
  Literal result2;

  // Round trip literals through device in multiple streams asynchronously.
  for (int64_t i = 0; i < kIterationCount; ++i) {
    ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, literal1,
                                                            device_buffer1));
    ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, literal2,
                                                            device_buffer2));
    TF_ASSERT_OK_AND_ASSIGN(
        Literal this_result1,
        transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1));
    TF_ASSERT_OK_AND_ASSIGN(
        Literal this_result2,
        transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2));
    // Keep only the most recent results; they are compared after the loop.
    result1 = std::move(this_result1);
    result2 = std::move(this_result2);
  }

  EXPECT_TRUE(LiteralTestUtil::Equal(literal1, result1));
  EXPECT_TRUE(LiteralTestUtil::Equal(literal2, result2));
}
354
355 class TransferDeviceToHostBenchmark : public TransferManagerTest {
356 public:
357 using TransferManagerTest::TransferManagerTest;
~TransferDeviceToHostBenchmark()358 ~TransferDeviceToHostBenchmark() override {}
359
Run(::testing::benchmark::State & state,int num_tuple_elements,int array_size)360 void Run(::testing::benchmark::State& state, int num_tuple_elements,
361 int array_size) {
362 SetUp();
363
364 std::vector<Literal> tuple_elements;
365 for (int i = 0; i < num_tuple_elements; ++i) {
366 tuple_elements.push_back(
367 LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
368 }
369 Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
370 auto device_buffer = AllocateDeviceBuffer(literal.shape());
371 TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
372 device_buffer));
373 for (auto s : state) {
374 TF_ASSERT_OK_AND_ASSIGN(
375 Literal result,
376 transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
377 }
378 TearDown();
379 }
380
TestBody()381 void TestBody() override {}
382 };
383
384 class TransferHostToDeviceBenchmark : public TransferManagerTest {
385 public:
386 using TransferManagerTest::TransferManagerTest;
~TransferHostToDeviceBenchmark()387 ~TransferHostToDeviceBenchmark() override {}
388
Run(::testing::benchmark::State & state,int num_tuple_elements,int array_size)389 void Run(::testing::benchmark::State& state, int num_tuple_elements,
390 int array_size) {
391 SetUp();
392
393 std::vector<Literal> tuple_elements;
394 for (int i = 0; i < num_tuple_elements; ++i) {
395 tuple_elements.push_back(
396 LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
397 }
398 Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
399 auto device_buffer = AllocateDeviceBuffer(literal.shape());
400
401 for (auto s : state) {
402 TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
403 device_buffer));
404 }
405 TearDown();
406 }
407
TestBody()408 void TestBody() override {}
409 };
410
BM_TransferDeviceToHost(::testing::benchmark::State & state)411 void BM_TransferDeviceToHost(::testing::benchmark::State& state) {
412 const int num_tuple_elements = state.range(0);
413 const int array_size = state.range(1);
414
415 TransferDeviceToHostBenchmark bm;
416 bm.Run(state, num_tuple_elements, array_size);
417 }
418
BM_TransferHostToDevice(::testing::benchmark::State & state)419 void BM_TransferHostToDevice(::testing::benchmark::State& state) {
420 const int num_tuple_elements = state.range(0);
421 const int array_size = state.range(1);
422
423 TransferHostToDeviceBenchmark bm;
424 bm.Run(state, num_tuple_elements, array_size);
425 }
426
// Registers the host-to-device benchmark. Each ArgPair is
// (num_tuple_elements, array_size) as read by BM_TransferHostToDevice.
// NOTE(review): 256 vs. 257 presumably contrasts a power-of-two size with an
// odd one -- confirm.
BENCHMARK(BM_TransferHostToDevice)
    ->ArgPair(1, 256)
    ->ArgPair(1, 257)
    ->ArgPair(100, 256)
    ->ArgPair(100, 257);
432
// Registers the device-to-host benchmark. Each ArgPair is
// (num_tuple_elements, array_size) as read by BM_TransferDeviceToHost.
// NOTE(review): 256 vs. 257 presumably contrasts a power-of-two size with an
// odd one -- confirm.
BENCHMARK(BM_TransferDeviceToHost)
    ->ArgPair(1, 256)
    ->ArgPair(1, 257)
    ->ArgPair(100, 256)
    ->ArgPair(100, 257);
438
main(int argc,char ** argv)439 int main(int argc, char** argv) {
440 ::testing::InitGoogleTest(&argc, argv);
441 tensorflow::testing::RunBenchmarks();
442 return RUN_ALL_TESTS();
443 }
444
445 } // namespace
446 } // namespace xla
447