1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <unistd.h>
17
18 #include <memory>
19
20 #include "tensorflow/compiler/xla/client/global_data.h"
21 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
22 #include "tensorflow/compiler/xla/client/local_client.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/test_helpers.h"
27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
28 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
29 #include "tensorflow/core/lib/math/math_util.h"
30 #include "tensorflow/core/platform/env.h"
31
32 namespace xla {
33 namespace {
34
35 class InfeedTest : public ClientLibraryTestBase {
36 protected:
37 // Transfers the given literal to the infeed interface of the device, and
38 // check if the returned data from Infeed HLO is same as the literal.
TestInfeedRoundTrip(const Literal & literal)39 void TestInfeedRoundTrip(const Literal& literal) {
40 // TODO(b/30481585) Explicitly reset the Infeed state so that the
41 // test is not affected by the state from the previous tests.
42 ASSERT_IS_OK(client_->TransferToInfeed(literal));
43 XlaBuilder builder(TestName());
44 Infeed(&builder, literal.shape());
45 if (literal.shape().IsTuple()) {
46 // TODO(b/30609564): Use ComputeAndCompareLiteral instead.
47 ComputeAndCompareTuple(&builder, literal, {});
48 } else {
49 ComputeAndCompareLiteral(&builder, literal, {});
50 }
51 }
52 };
53
TEST_F(InfeedTest, SingleInfeedR0Bool) {
  // Round-trips a rank-0 (scalar) bool literal.
  const auto scalar = LiteralUtil::CreateR0<bool>(true);
  TestInfeedRoundTrip(scalar);
}
57
TEST_F(InfeedTest, SingleInfeedR1U32) {
  // Round-trips a three-element uint32 vector.
  const auto vector = LiteralUtil::CreateR1<uint32_t>({1, 2, 3});
  TestInfeedRoundTrip(vector);
}
61
TEST_F(InfeedTest, SingleInfeedR2F32) {
  // Round-trips a rank-2 F32 literal built by CreateR2F32Linspace; presumably
  // 128 rows x 64 columns of values spanning [0, 1] — confirm in LiteralUtil.
  const auto matrix = LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64);
  TestInfeedRoundTrip(matrix);
}
65
TEST_F(InfeedTest, SingleInfeedR3F32) {
  // Round-trips a 2x2x3 F32 literal.
  const auto cube = LiteralUtil::CreateR3(
      {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
       {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}});
  TestInfeedRoundTrip(cube);
}
71
// Round-trips the same 2x2x3 F32 values through infeed twice, once with a
// layout where dimension 0 is minor-most ({0, 1, 2}) and once where it is
// major-most ({2, 1, 0}).
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
  const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
  const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});

  // A fatal failure inside TestInfeedRoundTrip only returns from the helper,
  // not from the test. ASSERT_NO_FATAL_FAILURE stops the test right there, so
  // the second round trip never runs against infeed data the first one may
  // have left queued (e.g. if the computation failed after the transfer).
  ASSERT_NO_FATAL_FAILURE(TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
      {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
       {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
      r3_dim0minor)));

  ASSERT_NO_FATAL_FAILURE(TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
      {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
       {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
      r3_dim0major)));
}
86
TEST_F(InfeedTest, SingleInfeedR4S32) {
  // Round-trips a 2x2x3x2 integer literal containing negative values.
  const auto tensor = LiteralUtil::CreateR4(
      {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
       {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}});
  TestInfeedRoundTrip(tensor);
}
92
93 // Tests that a large infeed can be handled.
TEST_F(InfeedTest,LargeInfeed)94 TEST_F(InfeedTest, LargeInfeed) {
95 Array4D<float> array(80, 100, 8, 128);
96 array.FillIota(1.0f);
97 TestInfeedRoundTrip(LiteralUtil::CreateR4FromArray4D<float>(array));
98 }
99
TEST_F(InfeedTest, SingleInfeedTuple) {
  // Round-trips a two-element tuple: a uint32 vector followed by a bool.
  const auto vector_part = LiteralUtil::CreateR1<uint32_t>({1, 2, 3});
  const auto scalar_part = LiteralUtil::CreateR0<bool>(false);
  TestInfeedRoundTrip(
      LiteralUtil::MakeTupleFromSlices({vector_part, scalar_part}));
}
105
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
  // Round-trips a tuple with no elements.
  const auto empty_tuple = LiteralUtil::MakeTuple({});
  TestInfeedRoundTrip(empty_tuple);
}
109
110 // Tests that a large tuple infeed can be handled.
TEST_F(InfeedTest,SingleInfeedLargeTuple)111 TEST_F(InfeedTest, SingleInfeedLargeTuple) {
112 Array4D<float> array(40, 100, 8, 128);
113 array.FillIota(1.0f);
114 TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
115 {LiteralUtil::CreateR4FromArray4D<float>(array),
116 LiteralUtil::CreateR0<int32_t>(5)}));
117 }
118
119 class BlockingInfeedTest : public ClientLibraryTestBase {};
120
TEST_F(BlockingInfeedTest,TestNoOoms)121 TEST_F(BlockingInfeedTest, TestNoOoms) {
122 Array3D<float> array(1024, 1024, 64);
123 array.FillIota(1.0f);
124 auto literal = LiteralUtil::CreateR3FromArray3D<float>(array);
125
126 int64_t kMemoryPressure = 32ul * 1024 * 1024 * 1024;
127 int64_t infeed_count =
128 kMemoryPressure / (array.num_elements() * sizeof(float));
129
130 auto transfer_infeeds = [&] {
131 for (int i = 0; i < infeed_count; i++) {
132 ASSERT_IS_OK(client_->TransferToInfeed(literal));
133 }
134 };
135
136 auto* env = tensorflow::Env::Default();
137
138 std::unique_ptr<tensorflow::Thread> thread{env->StartThread(
139 tensorflow::ThreadOptions{}, "transfer_infeeds", transfer_infeeds)};
140
141 // Sleep for 30s waiting for the infeed thread to "catch up".
142 //
143 // Without the fix accompanying this test, transfer_infeeds causes an OOM if
144 // the consumer (XLA computation running on the main thread) is unable to keep
145 // up with the producer (the transfer_infeeds thread). When that happens, the
146 // GPU buffers from the producer pile up and consume all of GPU memory.
147 //
148 // To reliably reproduce the issue we need to slow down the consumer, and we
149 // do that by inserting a sleep here.
150 //
151 // The fix is to back TransferToInfeed by a blocking queue that does not allow
152 // more than kMaxInfeedsInFlight infeeds in flight.
153 env->SleepForMicroseconds(30ul * 1000 * 1000);
154
155 XlaBuilder builder(TestName());
156 for (int i = 0; i < infeed_count; i++) {
157 Infeed(&builder, literal.shape());
158 }
159
160 ComputeAndCompareLiteral(&builder, literal, {});
161 }
162
163 class BlockingInfeedTest : public ClientLibraryTestBase {};
164
TEST_F(BlockingInfeedTest,TestNoOoms)165 TEST_F(BlockingInfeedTest, TestNoOoms) {
166 Array3D<float> array(1024, 1024, 64);
167 array.FillIota(1.0f);
168 auto literal = LiteralUtil::CreateR3FromArray3D<float>(array);
169
170 int64_t kMemoryPressure = 32ul * 1024 * 1024 * 1024;
171 int64_t infeed_count =
172 kMemoryPressure / (array.num_elements() * sizeof(float));
173
174 auto transfer_infeeds = [&] {
175 for (int i = 0; i < infeed_count; i++) {
176 ASSERT_IS_OK(client_->TransferToInfeed(literal));
177 }
178 };
179
180 auto* env = tensorflow::Env::Default();
181
182 std::unique_ptr<tensorflow::Thread> thread{env->StartThread(
183 tensorflow::ThreadOptions{}, "transfer_infeeds", transfer_infeeds)};
184
185 // Sleep for 30s waiting for the infeed thread to "catch up".
186 //
187 // Without the fix accompanying this test, transfer_infeeds causes an OOM if
188 // the consumer (XLA computation running on the main thread) is unable to keep
189 // up with the producer (the transfer_infeeds thread). When that happens, the
190 // GPU buffers from the producer pile up and consume all of GPU memory.
191 //
192 // To reliably reproduce the issue we need to slow down the consumer, and we
193 // do that by inserting a sleep here.
194 //
195 // The fix is to back TransferToInfeed by a blocking queue that does not allow
196 // more than kMaxInfeedsInFlight infeeds in flight.
197 env->SleepForMicroseconds(30ul * 1000 * 1000);
198
199 XlaBuilder builder(TestName());
200 for (int i = 0; i < infeed_count; i++) {
201 Infeed(&builder, literal.shape());
202 }
203
204 ComputeAndCompareLiteral(&builder, literal, {});
205 }
206
207 } // namespace
208 } // namespace xla
209