• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
19 
20 #include "tensorflow/compiler/xla/layout_util.h"
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
23 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
24 #include "tensorflow/compiler/xla/service/stream_pool.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
28 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
29 #include "tensorflow/compiler/xla/tests/test_macros.h"
30 #include "tensorflow/compiler/xla/types.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
34 #include "tensorflow/core/platform/test_benchmark.h"
35 #include "tensorflow/core/platform/types.h"
36 #include "tensorflow/stream_executor/device_memory_allocator.h"
37 
38 namespace xla {
39 namespace {
40 
41 class TransferManagerTest : public LocalClientTestBase {
42  protected:
TransferManagerTest()43   TransferManagerTest()
44       : shape_size_fn_([this](const Shape& shape) {
45           return transfer_manager_->GetByteSizeRequirement(shape);
46         }) {
47     stream_ptr_ = local_client_->mutable_backend()
48                       ->BorrowStream(stream_executor_)
49                       .ValueOrDie();
50     stream_ = stream_ptr_.get();
51   }
52 
53   ~TransferManagerTest() override = default;
54 
AllocateDeviceBuffer(const Shape & shape)55   ScopedShapedBuffer AllocateDeviceBuffer(const Shape& shape) {
56     return transfer_manager_
57         ->AllocateScopedShapedBuffer(
58             shape, GetOrCreateAllocator(local_client_->platform()),
59             /*device_ordinal=*/0)
60         .ValueOrDie();
61   }
62 
63  protected:
64   StreamPool::Ptr stream_ptr_;
65   se::Stream* stream_;
66 
67  private:
68   std::function<int64(const Shape&)> shape_size_fn_;
69 };
70 
// A scalar uint32 literal must come back unchanged after being copied to the
// device and read back.
XLA_TEST_F(TransferManagerTest, TransferR0U32) {
  constexpr uint32 kValue = 42;
  Literal input = LiteralUtil::CreateR0<uint32>(kValue);
  auto buffer = AllocateDeviceBuffer(input.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(
      transfer_manager_->TransferLiteralToDevice(stream_, input, buffer));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, buffer));

  LiteralTestUtil::ExpectR0Equal<uint32>(kValue, round_tripped);
}
85 
// A small rank-1 float literal must come back unchanged after being copied to
// the device and read back.
XLA_TEST_F(TransferManagerTest, TransferR1F32) {
  const std::vector<float> values = {1.25f, 2.5f, -17.0f, -20.125f};
  Literal input = LiteralUtil::CreateR1<float>(values);
  auto buffer = AllocateDeviceBuffer(input.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(
      transfer_manager_->TransferLiteralToDevice(stream_, input, buffer));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, buffer));

  LiteralTestUtil::ExpectR1Equal<float>(values, round_tripped);
}
102 
// Round-trips rank-1 float literals of every length in [0, kMaxR1Size), i.e.
// kMaxR1Size itself is excluded. The goal is to find bugs related to
// "awkwardly" sized R1s.
XLA_TEST_F(TransferManagerTest, TransferR1F32AwkwardSizes) {
  constexpr int kMaxR1Size = (1 << 11);
  for (int size = 0; size < kMaxR1Size; ++size) {
    // Attach the current length to any failure reported in this iteration;
    // without it a failure does not say which size was being transferred.
    SCOPED_TRACE(::testing::Message() << "R1 size: " << size);

    // Element k holds the value k, so misplaced elements are detectable.
    std::vector<float> inputs(size);
    std::iota(inputs.begin(), inputs.end(), 0);
    Literal literal = LiteralUtil::CreateR1<float>(inputs);
    auto device_buffer = AllocateDeviceBuffer(literal.shape());

    // Round trip literal through device.
    ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
                                                            device_buffer));
    TF_ASSERT_OK_AND_ASSIGN(
        Literal result,
        transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));

    LiteralTestUtil::ExpectR1Equal<float>(inputs, result);
  }
}
124 
// Round-trips a rank-1 float literal with 1024 * 1024 elements.
XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) {
  constexpr int kElementCount = 1024 * 1024;
  // Element k holds the value k.
  std::vector<float> values(kElementCount);
  std::iota(values.begin(), values.end(), 0);

  Literal input = LiteralUtil::CreateR1<float>(values);
  auto buffer = AllocateDeviceBuffer(input.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(
      transfer_manager_->TransferLiteralToDevice(stream_, input, buffer));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, buffer));

  LiteralTestUtil::ExpectR1Equal<float>(values, round_tripped);
}
141 
// Round-trips 1024 floats whose host storage starts one element past the
// beginning of the backing vector, so the source pointer is offset from the
// start of the allocation.
XLA_TEST_F(TransferManagerTest, TransferR1LargeUnalignedF32) {
  constexpr int kElementCount = 1024;

  // One extra leading element; the literal borrows elements [1, 1025).
  std::vector<float> backing(kElementCount + 1);
  std::iota(backing.begin(), backing.end(), 0);

  Shape shape = ShapeUtil::MakeShape(F32, {kElementCount});
  BorrowingLiteral input(reinterpret_cast<const char*>(backing.data() + 1),
                         shape);
  auto buffer = AllocateDeviceBuffer(shape);

  // Host -> device, then device -> host.
  ASSERT_IS_OK(
      transfer_manager_->TransferLiteralToDevice(stream_, input, buffer));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, buffer));

  // The borrowed window starts at value 1 because of the one-element offset.
  std::vector<float> expected(kElementCount);
  std::iota(expected.begin(), expected.end(), 1);
  LiteralTestUtil::ExpectR1Equal<float>(expected, round_tripped);
}
161 
// Round-trips a rank-1 U8 literal built from a C string and compares the
// result as a string.
XLA_TEST_F(TransferManagerTest, TransferR1U8) {
  const char* text = "0123456789abcdef";
  Literal input = LiteralUtil::CreateR1U8(text);
  auto buffer = AllocateDeviceBuffer(input.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(
      transfer_manager_->TransferLiteralToDevice(stream_, input, buffer));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, buffer));

  EXPECT_EQ(round_tripped.GetR1U8AsString(), text);
}
177 
// Round-trips a 2x3 float literal.
XLA_TEST_F(TransferManagerTest, TransferR2F32) {
  Literal input =
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
  auto buffer = AllocateDeviceBuffer(input.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(
      transfer_manager_->TransferLiteralToDevice(stream_, input, buffer));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, buffer));

  LiteralTestUtil::ExpectR2Equal<float>(
      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, round_tripped);
}
194 
// Transfers a 2x3 float literal built with layout {0, 1} into a device buffer
// allocated with layout {1, 0}. The literal read back must hold the same
// values but must not share the host literal's layout.
XLA_TEST_F(TransferManagerTest,
           TransferR2F32AndChangeLayoutTransferringToDevice) {
  Literal input = LiteralUtil::CreateR2WithLayout<float>(
      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1}));
  const Shape device_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
  auto buffer = AllocateDeviceBuffer(device_shape);

  // Host -> device (into the differently laid out buffer), then back.
  ASSERT_IS_OK(
      transfer_manager_->TransferLiteralToDevice(stream_, input, buffer));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, buffer));

  const bool same_layout = LayoutUtil::Equal(round_tripped.shape().layout(),
                                             input.shape().layout());
  EXPECT_FALSE(same_layout);
  LiteralTestUtil::ExpectR2Equal<float>(
      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, round_tripped);
}
216 
// Round-trips a flat tuple holding an R0, an R2 and an R1 float literal.
XLA_TEST_F(TransferManagerTest, TransferTuple) {
  Literal input = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(123.0f),
       LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
       LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})});
  auto buffer = AllocateDeviceBuffer(input.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(
      transfer_manager_->TransferLiteralToDevice(stream_, input, buffer));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, buffer));

  EXPECT_TRUE(LiteralTestUtil::Equal(input, round_tripped));
}
233 
// Round-trips a tuple with no elements.
XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
  Literal input = LiteralUtil::MakeTuple({});
  auto buffer = AllocateDeviceBuffer(input.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(
      transfer_manager_->TransferLiteralToDevice(stream_, input, buffer));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, buffer));

  EXPECT_TRUE(LiteralTestUtil::Equal(input, round_tripped));
}
247 
// Round-trips a tuple whose second element is itself a tuple.
XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
  Literal input = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(123.0f),
       LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
            LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})}),
       LiteralUtil::CreateR1<float>({-10.0f, 123.0f})});
  auto buffer = AllocateDeviceBuffer(input.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(
      transfer_manager_->TransferLiteralToDevice(stream_, input, buffer));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, buffer));

  EXPECT_TRUE(LiteralTestUtil::Equal(input, round_tripped));
}
266 
// Round-trips a rank-1 literal of two complex64 values.
XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
  Literal input = LiteralUtil::CreateR1<complex64>(
      {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)});
  auto buffer = AllocateDeviceBuffer(input.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(
      transfer_manager_->TransferLiteralToDevice(stream_, input, buffer));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, buffer));

  EXPECT_TRUE(LiteralTestUtil::Equal(input, round_tripped));
}
281 
// Round-trips a tuple mixing complex64 literals (R1 and R0) with an int32 R1.
XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
  Literal input = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<complex64>(
           {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}),
       LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6}),
       LiteralUtil::CreateR0<complex64>(complex64(0.3f, -0.4f))});
  auto buffer = AllocateDeviceBuffer(input.shape());

  // Host -> device, then device -> host.
  ASSERT_IS_OK(
      transfer_manager_->TransferLiteralToDevice(stream_, input, buffer));
  TF_ASSERT_OK_AND_ASSIGN(
      Literal round_tripped,
      transfer_manager_->TransferLiteralFromDevice(stream_, buffer));

  EXPECT_TRUE(LiteralTestUtil::Equal(input, round_tripped));
}
299 
// Reads a token back from the device. A token has no physical representation,
// so nothing is actually copied, but the transfer must still succeed.
// TODO(b/110532604): Add transferring the token to device when this is
// supported.
XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) {
  const Shape token_shape = ShapeUtil::MakeTokenShape();
  auto buffer = AllocateDeviceBuffer(token_shape);

  TF_ASSERT_OK_AND_ASSIGN(
      Literal fetched,
      transfer_manager_->TransferLiteralFromDevice(stream_, buffer));

  const Literal expected_token = LiteralUtil::CreateToken();
  EXPECT_TRUE(LiteralTestUtil::Equal(expected_token, fetched));
}
311 
// Soak test: repeatedly round-trips two different nested-tuple literals, one on
// the fixture's stream and one on a sub-stream of it, and checks that the
// literals read back on the final iteration match what was sent.
XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) {
  constexpr int64_t kIterationCount = 5000;
  Literal literal1 = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(123.0f),
       LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
            LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})}),
       LiteralUtil::CreateR1<float>({-10.0f, 123.0f})});
  Literal literal2 = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(456.0f),
       LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR2<float>({{5.0f, 7.0f}, {9.0f, 4.0f}}),
            LiteralUtil::CreateR1<float>({44.0f, -11.0f, 3333333.3f})}),
       LiteralUtil::CreateR1<float>({-98.0f, 153.0f})});

  auto device_buffer1 = AllocateDeviceBuffer(literal1.shape());
  auto device_buffer2 = AllocateDeviceBuffer(literal2.shape());

  se::Stream* const stream1 = stream_;
  se::Stream* const stream2 = stream_->GetOrCreateSubStream();

  // Hands the sub-stream back to its parent when the test body exits, including
  // the early returns taken by the ASSERT macros below. The guard is declared
  // after the device buffers, so it runs before they are released.
  // NOTE(review): assumes Stream::ReturnSubStream is the counterpart of
  // GetOrCreateSubStream and that an unreturned sub-stream stays marked as in
  // use on the (pooled) parent stream -- confirm against stream.h.
  struct SubStreamReturner {
    se::Stream* parent;
    se::Stream* sub_stream;
    ~SubStreamReturner() { parent->ReturnSubStream(sub_stream); }
  };
  const SubStreamReturner return_sub_stream{stream_, stream2};

  Literal result1, result2;

  // Round trip literals through device on both streams. Only the results of
  // the last iteration are kept for comparison.
  for (int64_t i = 0; i < kIterationCount; ++i) {
    ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, literal1,
                                                            device_buffer1));
    ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, literal2,
                                                            device_buffer2));
    TF_ASSERT_OK_AND_ASSIGN(
        Literal this_result1,
        transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1));
    TF_ASSERT_OK_AND_ASSIGN(
        Literal this_result2,
        transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2));
    result1 = std::move(this_result1);
    result2 = std::move(this_result2);
  }

  EXPECT_TRUE(LiteralTestUtil::Equal(literal1, result1));
  EXPECT_TRUE(LiteralTestUtil::Equal(literal2, result2));
}
354 
355 class TransferDeviceToHostBenchmark : public TransferManagerTest {
356  public:
357   using TransferManagerTest::TransferManagerTest;
~TransferDeviceToHostBenchmark()358   ~TransferDeviceToHostBenchmark() override {}
359 
Run(::testing::benchmark::State & state,int num_tuple_elements,int array_size)360   void Run(::testing::benchmark::State& state, int num_tuple_elements,
361            int array_size) {
362     SetUp();
363 
364     std::vector<Literal> tuple_elements;
365     for (int i = 0; i < num_tuple_elements; ++i) {
366       tuple_elements.push_back(
367           LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
368     }
369     Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
370     auto device_buffer = AllocateDeviceBuffer(literal.shape());
371     TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
372                                                            device_buffer));
373     for (auto s : state) {
374       TF_ASSERT_OK_AND_ASSIGN(
375           Literal result,
376           transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
377     }
378     TearDown();
379   }
380 
TestBody()381   void TestBody() override {}
382 };
383 
384 class TransferHostToDeviceBenchmark : public TransferManagerTest {
385  public:
386   using TransferManagerTest::TransferManagerTest;
~TransferHostToDeviceBenchmark()387   ~TransferHostToDeviceBenchmark() override {}
388 
Run(::testing::benchmark::State & state,int num_tuple_elements,int array_size)389   void Run(::testing::benchmark::State& state, int num_tuple_elements,
390            int array_size) {
391     SetUp();
392 
393     std::vector<Literal> tuple_elements;
394     for (int i = 0; i < num_tuple_elements; ++i) {
395       tuple_elements.push_back(
396           LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
397     }
398     Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
399     auto device_buffer = AllocateDeviceBuffer(literal.shape());
400 
401     for (auto s : state) {
402       TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
403                                                              device_buffer));
404     }
405     TearDown();
406   }
407 
TestBody()408   void TestBody() override {}
409 };
410 
BM_TransferDeviceToHost(::testing::benchmark::State & state)411 void BM_TransferDeviceToHost(::testing::benchmark::State& state) {
412   const int num_tuple_elements = state.range(0);
413   const int array_size = state.range(1);
414 
415   TransferDeviceToHostBenchmark bm;
416   bm.Run(state, num_tuple_elements, array_size);
417 }
418 
BM_TransferHostToDevice(::testing::benchmark::State & state)419 void BM_TransferHostToDevice(::testing::benchmark::State& state) {
420   const int num_tuple_elements = state.range(0);
421   const int array_size = state.range(1);
422 
423   TransferHostToDeviceBenchmark bm;
424   bm.Run(state, num_tuple_elements, array_size);
425 }
426 
427 BENCHMARK(BM_TransferHostToDevice)
428     ->ArgPair(1, 256)
429     ->ArgPair(1, 257)
430     ->ArgPair(100, 256)
431     ->ArgPair(100, 257);
432 
433 BENCHMARK(BM_TransferDeviceToHost)
434     ->ArgPair(1, 256)
435     ->ArgPair(1, 257)
436     ->ArgPair(100, 256)
437     ->ArgPair(100, 257);
438 
main(int argc,char ** argv)439 int main(int argc, char** argv) {
440   ::testing::InitGoogleTest(&argc, argv);
441   tensorflow::testing::RunBenchmarks();
442   return RUN_ALL_TESTS();
443 }
444 
445 }  // namespace
446 }  // namespace xla
447