1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <memory>
18 #include <string>
19 #include <vector>
20
21 #include "tensorflow/compiler/xla/array2d.h"
22 #include "tensorflow/compiler/xla/client/global_data.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/client/xla_computation.h"
26 #include "tensorflow/compiler/xla/layout_util.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
32 #include "tensorflow/compiler/xla/tests/test_macros.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/platform/protobuf.h"
35 #include "tensorflow/core/platform/test.h"
36 #include "tensorflow/core/platform/types.h"
37
38 namespace xla {
39 namespace {
40
// Test fixture shared by the parameter tests in this file. It declares no
// members of its own; everything the tests use comes from
// ClientLibraryTestBase.
class ParamsTest : public ClientLibraryTestBase {};
42
// Passes a scalar F32 value through a computation that contains only a single
// parameter and expects the same value back.
XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
  constexpr float kValue = 3.14159f;

  // Build the computation: a lone scalar parameter.
  XlaBuilder builder(TestName());
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  Parameter(&builder, 0, scalar_shape, "param0");

  // Stage the argument on the server.
  Literal input = LiteralUtil::CreateR0<float>(kValue);
  std::unique_ptr<GlobalData> input_handle =
      client_->TransferToServer(input).ConsumeValueOrDie();

  ComputeAndCompareR0<float>(&builder, kValue, {input_handle.get()},
                             ErrorSpec(0.0001f));
}
54
// Passes an empty (zero-element) rank-1 F32 array through a single parameter
// and expects an empty result.
XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
  // Build the computation: a lone parameter of shape f32[0].
  XlaBuilder builder(TestName());
  const Shape empty_r1_shape = ShapeUtil::MakeShape(F32, {0});
  Parameter(&builder, 0, empty_r1_shape, "param0");

  // Stage the (empty) argument on the server.
  Literal empty_input = LiteralUtil::CreateR1<float>({});
  std::unique_ptr<GlobalData> input_handle =
      client_->TransferToServer(empty_input).ConsumeValueOrDie();

  ComputeAndCompareR1<float>(&builder, {}, {input_handle.get()},
                             ErrorSpec(0.01f));
}
66
// Passes a two-element rank-1 F32 array through a single parameter and expects
// the same elements back.
XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
  const std::vector<float> values = {3.14f, -100.25f};

  // Build the computation: a lone parameter of shape f32[2].
  XlaBuilder builder(TestName());
  Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");

  // Stage the argument on the server.
  Literal input = LiteralUtil::CreateR1<float>(values);
  std::unique_ptr<GlobalData> input_handle =
      client_->TransferToServer(input).ConsumeValueOrDie();

  ComputeAndCompareR1<float>(&builder, values, {input_handle.get()},
                             ErrorSpec(0.01f));
}
78
// Passes the bytes of a string through a single rank-1 U8 parameter and
// expects the same string back.
XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
  const string text("hello world");
  const int64 length = static_cast<int64>(text.size());

  // Build the computation: a lone parameter with one U8 element per character.
  XlaBuilder builder(TestName());
  Parameter(&builder, 0, ShapeUtil::MakeShape(U8, {length}), "param0");

  // Stage the argument on the server.
  Literal input = LiteralUtil::CreateR1U8(text);
  std::unique_ptr<GlobalData> input_handle =
      client_->TransferToServer(input).ConsumeValueOrDie();

  ComputeAndCompareR1U8(&builder, text, {input_handle.get()});
}
92
// Passes a 3x0 (zero-element) rank-2 F32 array through a single parameter and
// expects a 3x0 result.
XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
  const Array2D<float> empty_array(3, 0);

  // Build the computation: a lone parameter of shape f32[3,0].
  XlaBuilder builder(TestName());
  Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0");

  // Stage the argument on the server.
  Literal input = LiteralUtil::CreateR2FromArray2D<float>(empty_array);
  std::unique_ptr<GlobalData> input_handle =
      client_->TransferToServer(input).ConsumeValueOrDie();

  ComputeAndCompareR2<float>(&builder, empty_array, {input_handle.get()},
                             ErrorSpec(0.01f));
}
105
// Passes a 3x2 rank-2 F32 array through a single parameter and expects the
// same elements back.
XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
  // The same values serve as both the argument and the expected result.
  const Array2D<float> values(
      {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});

  // Build the computation: a lone parameter of shape f32[3,2].
  XlaBuilder builder(TestName());
  Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0");

  // Stage the argument on the server.
  Literal input = LiteralUtil::CreateR2FromArray2D<float>(values);
  std::unique_ptr<GlobalData> input_handle =
      client_->TransferToServer(input).ConsumeValueOrDie();

  ComputeAndCompareR2<float>(&builder, values, {input_handle.get()},
                             ErrorSpec(0.01f));
}
120
// Feeds two distinct rank-1 parameters into an expression that is asymmetric
// in its inputs, so that swapping the two arguments would change the result.
XLA_TEST_F(ParamsTest, TwoParameters) {
  XlaBuilder builder(TestName());

  Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
  std::unique_ptr<GlobalData> param0_data =
      client_->TransferToServer(literal0).ConsumeValueOrDie();
  auto param0 = Parameter(&builder, 0, literal0.shape(), "param0");

  Literal literal1 = LiteralUtil::CreateR1<float>({10, 20});
  std::unique_ptr<GlobalData> param1_data =
      client_->TransferToServer(literal1).ConsumeValueOrDie();
  auto param1 = Parameter(&builder, 1, literal1.shape(), "param1");

  // Use both parameters
  //
  // {1, 2} + {10, 20} = {11, 22}
  auto sum = Add(param0, param1);

  // Use only the second parameter again, to show that it can be used
  // twice and to make the computation asymmetric in the two
  // parameters to test that the parameters are not swapped.
  //
  // {11, 22} * {10, 20} = {110, 440}
  Mul(sum, param1);

  ComputeAndCompareR1<float>(&builder, {110, 440},
                             {param0_data.get(), param1_data.get()},
                             ErrorSpec(0.0001f));
}
151
// Test that an error is returned when a computation with an incomplete set of
// parameters (parameter numbers not contiguous from 0) is built.
XLA_TEST_F(ParamsTest, MissingParameter) {
  // Only parameter number 2 is declared; numbers 0 and 1 are never added. The
  // check below is on the result of Build() alone, so no argument data has to
  // be transferred to the server for this test.
  XlaBuilder builder(TestName());
  Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2");
  auto computation_status = builder.Build();

  ASSERT_NE(computation_status.status(), Status::OK());
}
165
// Declares two parameters without combining them in any operation and expects
// the value of the second (last declared) parameter as the result.
XLA_TEST_F(ParamsTest, UnusedParameter) {
  Literal first = LiteralUtil::CreateR1<float>({1, 2});
  Literal second = LiteralUtil::CreateR1<float>({10, 20});

  // Stage both arguments on the server.
  std::unique_ptr<GlobalData> first_handle =
      client_->TransferToServer(first).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> second_handle =
      client_->TransferToServer(second).ConsumeValueOrDie();

  // Build the computation: two parameters and nothing else.
  XlaBuilder builder(TestName());
  Parameter(&builder, 0, first.shape(), "param0");
  Parameter(&builder, 1, second.shape(), "param1");

  ComputeAndCompareR1<float>(&builder, {10, 20},
                             {first_handle.get(), second_handle.get()},
                             ErrorSpec(0.0001f));
}
183
// Builds a computation in which two of the three parameters only feed an
// expression whose value is never consumed; the result depends solely on the
// first parameter.
XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) {
  Literal short_literal = LiteralUtil::CreateR1<float>({1, 2});
  Literal long_literal = LiteralUtil::CreateR1<float>({10, 20, 30});

  // Stage the arguments on the server.
  std::unique_ptr<GlobalData> short_handle =
      client_->TransferToServer(short_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> long_handle =
      client_->TransferToServer(long_literal).ConsumeValueOrDie();

  XlaBuilder builder(TestName());
  auto used_param = Parameter(&builder, 0, short_literal.shape(), "param0");
  auto ignored_lhs = Parameter(&builder, 1, long_literal.shape(), "param1");
  auto ignored_rhs = Parameter(&builder, 2, long_literal.shape(), "param2");

  // This add is unused.
  Add(ignored_lhs, ignored_rhs);

  // The negation is the last operation added and yields the compared result.
  Neg(used_param);

  // Parameters 1 and 2 share a shape, so the same server-side buffer is
  // supplied for both of them.
  ComputeAndCompareR1<float>(
      &builder, {-1, -2},
      {short_handle.get(), long_handle.get(), long_handle.get()},
      ErrorSpec(0.0001f));
}
211
// Sums one hundred rank-1 parameters of 2048 elements each onto a constant and
// compares against the same sum accumulated on the host.
XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
  constexpr int kElementCount = 8 * 128 * 2;
  constexpr int kParameterCount = 100;

  XlaBuilder builder(TestName());

  // Starting value {0, 1, 0, 0, ...}; only the first two entries of any vector
  // in this test are ever non-zero.
  std::vector<float> initial(kElementCount, 0.0f);
  initial[1] = 1.0f;
  XlaOp running_sum = ConstantR1<float>(&builder, initial);

  // Host-side mirror of the device computation.
  std::vector<float> expected_sum = initial;

  std::vector<std::unique_ptr<GlobalData>> owned_arguments;
  for (int index = 0; index < kParameterCount; ++index) {
    const float first_entry = static_cast<float>(index);
    const float second_entry = static_cast<float>(2 * index);
    expected_sum[0] += first_entry;
    expected_sum[1] += second_entry;

    std::vector<float> values(kElementCount, 0.0f);
    values[0] = first_entry;
    values[1] = second_entry;
    Literal argument = LiteralUtil::CreateR1<float>(values);
    owned_arguments.push_back(
        client_->TransferToServer(argument).ConsumeValueOrDie());

    XlaOp parameter = Parameter(&builder, index, argument.shape(), "param");
    running_sum = Add(running_sum, parameter);
  }

  // Collect non-owning pointers in parameter order for the execution call.
  std::vector<GlobalData*> arguments(owned_arguments.size());
  std::transform(owned_arguments.begin(), owned_arguments.end(),
                 arguments.begin(),
                 [](const std::unique_ptr<GlobalData>& owned) {
                   return owned.get();
                 });

  ComputeAndCompareR1<float>(&builder, expected_sum, arguments,
                             ErrorSpec(0.0001f));
}
248
// Only run the tests with thousands of parameters (the 3,000-parameter tests
// and the 1,900-parameter while-loop test) in opt mode to avoid test timeouts.
// Timeout last observed on 2017-11-20.
251 #ifdef NDEBUG
252
253 // TODO(b/65525254) Fails on GPU on 2017-09-10 because we try to reserve too
254 // much space in parameter memory for the kernel.
255 //
256 // TODO(b/65526061) Failed on CPU on 2017-09-10 due to timeout in LLVM
257 // compilation.
XLA_TEST_F(ParamsTest,DISABLED_ON_CPU (DISABLED_ON_GPU (ThreeThousandParameters)))258 XLA_TEST_F(ParamsTest,
259 DISABLED_ON_CPU(DISABLED_ON_GPU(ThreeThousandParameters))) {
260 XlaBuilder builder(TestName());
261
262 std::vector<std::unique_ptr<GlobalData>> param_data_owner;
263 XlaOp sum_handle = ConstantR0<float>(&builder, 0.0f);
264 float target = 0.0;
265 constexpr int kParamCount = 3000;
266 for (int i = 0; i < kParamCount; ++i) {
267 target += i;
268 Literal literal = LiteralUtil::CreateR0<float>(i);
269 param_data_owner.push_back(
270 std::move(client_->TransferToServer(literal)).ValueOrDie());
271 XlaOp param = Parameter(&builder, i, literal.shape(), "param");
272 sum_handle = Add(sum_handle, param);
273 }
274
275 std::vector<GlobalData*> param_data;
276 param_data.reserve(param_data_owner.size());
277 for (const std::unique_ptr<GlobalData>& data : param_data_owner) {
278 param_data.push_back(data.get());
279 }
280
281 ComputeAndCompareR0<float>(&builder, target, param_data, ErrorSpec(0.0001f));
282 }
283
284 // TODO(b/65525254) Fails on GPU on 2017-09-10 because we try to reserve too
285 // much space in parameter memory for the kernel.
286 //
287 // TODO(b/65526061) Failed on CPU on 2017-09-10 due to timeout in LLVM
288 // compilation.
XLA_TEST_F(ParamsTest,DISABLED_ON_CPU (DISABLED_ON_GPU (ThreeThousandParametersAndOutputElements)))289 XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
290 ThreeThousandParametersAndOutputElements))) {
291 XlaBuilder builder(TestName());
292
293 std::vector<std::unique_ptr<GlobalData>> param_data_owner;
294 XlaOp sum_handle = ConstantR1<int32>(&builder, {0, 0});
295 int32 target = 0;
296 constexpr int kParamCount = 3000;
297 std::vector<XlaOp> params;
298 for (int i = 0; i < kParamCount; ++i) {
299 target += i;
300 Literal literal = LiteralUtil::CreateR1<int32>({i, i});
301 param_data_owner.push_back(
302 std::move(client_->TransferToServer(literal)).ValueOrDie());
303 XlaOp param = Parameter(&builder, i, literal.shape(), "param");
304 params.push_back(param);
305 sum_handle = Add(sum_handle, param);
306 }
307
308 std::vector<XlaOp> outputs;
309 for (int i = 0; i < kParamCount; ++i) {
310 outputs.push_back(Add(params[i], sum_handle));
311 }
312
313 Tuple(&builder, outputs);
314
315 std::vector<GlobalData*> param_data;
316 param_data.reserve(param_data_owner.size());
317 for (const std::unique_ptr<GlobalData>& data : param_data_owner) {
318 param_data.push_back(data.get());
319 }
320
321 std::vector<Literal> elements;
322 std::vector<const Literal*> ptrs;
323 elements.reserve(kParamCount);
324 for (int i = 0; i < kParamCount; ++i) {
325 elements.push_back(LiteralUtil::CreateR1<int32>({target + i, target + i}));
326 ptrs.push_back(&elements.back());
327 }
328 ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data);
329 }
330
331 // Test large number of parameters flowing into a while-loop.
332 // Construct conceptually the following HLO graph:
333 //
334 // p0 = parameter(0)
335 // p1 = parameter(1)
336 // ...
337 // pN = parameter(N)
338 // result = while (false) {
339 // p0 += (1, 1);
340 // p1 += (1, 1);
341 // ...
342 // pN += (1, 1)
343 // }
344 // result = {p0, p1, ..., pN}
345 //
346 // TODO(b/70173746): Times out during compilation on GPU and CPU backends as of
347 // 2017-12-12.
XLA_TEST_F(ParamsTest,DISABLED_ON_CPU (DISABLED_ON_GPU (ManyParametersIntoWhileLoop)))348 XLA_TEST_F(ParamsTest,
349 DISABLED_ON_CPU(DISABLED_ON_GPU(ManyParametersIntoWhileLoop))) {
350 XlaBuilder builder(TestName());
351
352 std::vector<std::unique_ptr<GlobalData>> param_data_owner;
353 constexpr int kParamCount = 1900;
354 std::vector<XlaOp> params;
355 std::vector<Shape> parameter_shapes;
356 for (int i = 0; i < kParamCount; ++i) {
357 Literal literal = LiteralUtil::CreateR1<int32>({i, i});
358 param_data_owner.push_back(
359 std::move(client_->TransferToServer(literal)).ValueOrDie());
360 XlaOp param = Parameter(&builder, i, literal.shape(), "param");
361 params.push_back(param);
362 parameter_shapes.push_back(literal.shape());
363 }
364
365 // Add bool parameter for the loop condition. Use a parameter HLO instead of a
366 // constant because DCE may eliminate the while-body otherwise.
367 Literal bool_literal = LiteralUtil::CreateR0<bool>(false);
368 param_data_owner.push_back(
369 std::move(client_->TransferToServer(bool_literal)).ValueOrDie());
370 XlaOp bool_param =
371 Parameter(&builder, kParamCount, bool_literal.shape(), "bool_param");
372 params.push_back(bool_param);
373 parameter_shapes.push_back(bool_literal.shape());
374
375 auto init = Tuple(&builder, params);
376
377 // Create a computation for the condition: while(bool_param).
378 Shape while_shape = ShapeUtil::MakeTupleShape(parameter_shapes);
379 XlaComputation condition;
380 {
381 XlaBuilder builder("condition");
382 auto condition_parameter =
383 Parameter(&builder, 0, while_shape, "condition_parameter");
384 GetTupleElement(condition_parameter, kParamCount);
385 condition = builder.Build().ConsumeValueOrDie();
386 }
387
388 // Create a computation for the body.
389 // Add {1, 1} to the each tuple element.
390 XlaComputation body;
391 {
392 XlaBuilder builder("body");
393 auto body_parameter = Parameter(&builder, 0, while_shape, "body_parameter");
394 std::vector<XlaOp> updates;
395 for (int i = 0; i < kParamCount; ++i) {
396 auto add = Add(GetTupleElement(body_parameter, i),
397 ConstantR1<int32>(&builder, {1, 1}));
398 updates.push_back(add);
399 }
400 // Add bool parameter.
401 updates.push_back(GetTupleElement(body_parameter, kParamCount));
402
403 Tuple(&builder, updates);
404 body = builder.Build().ConsumeValueOrDie();
405 }
406
407 auto loop = While(condition, body, init);
408
409 std::vector<XlaOp> outputs;
410 for (int i = 0; i < kParamCount; ++i) {
411 outputs.push_back(GetTupleElement(loop, i));
412 }
413 Tuple(&builder, outputs);
414
415 std::vector<GlobalData*> param_data;
416 param_data.reserve(param_data_owner.size());
417 for (const std::unique_ptr<GlobalData>& data : param_data_owner) {
418 param_data.push_back(data.get());
419 }
420
421 std::vector<Literal> elements;
422 std::vector<const Literal*> ptrs;
423 elements.reserve(kParamCount);
424 for (int i = 0; i < kParamCount; ++i) {
425 elements.push_back(LiteralUtil::CreateR1<int32>({i, i}));
426 ptrs.push_back(&elements.back());
427 }
428 ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data);
429 }
430
431 #endif
432
// Passes a single tuple parameter holding two rank-1 F32 arrays and adds the
// two tuple elements together.
XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
  const Shape element_shape = ShapeUtil::MakeShape(F32, {3});
  const Shape input_shape =
      ShapeUtil::MakeTupleShape({element_shape, element_shape});

  XlaBuilder builder(TestName());
  auto tuple_param = Parameter(&builder, 0, input_shape, "input");
  auto first = GetTupleElement(tuple_param, 0);
  auto second = GetTupleElement(tuple_param, 1);
  Add(first, second);

  // Assemble the tuple argument on the host, then stage it on the server.
  Literal first_values = LiteralUtil::CreateR1<float>({1, 2, 3});
  Literal second_values = LiteralUtil::CreateR1<float>({4, 5, 6});
  Literal tuple_literal =
      LiteralUtil::MakeTupleFromSlices({first_values, second_values});
  std::unique_ptr<GlobalData> tuple_handle =
      client_->TransferToServer(tuple_literal).ConsumeValueOrDie();

  const std::vector<GlobalData*> arguments = {tuple_handle.get()};
  const std::vector<float> expected_sum = {5, 7, 9};
  ComputeAndCompareR1<float>(&builder, expected_sum, arguments,
                             ErrorSpec(1e-5));
}
455
456 // Verifies that passing a 2x2 with {0, 1} layout returns the same value back
457 // when (transferred to the server and) passed through a parameter.
XLA_TEST_F(ParamsTest,R2_2x2_Layout_01)458 XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
459 Literal literal = LiteralUtil::CreateR2WithLayout<float>(
460 {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));
461 XlaBuilder builder(TestName());
462 Parameter(&builder, 0, literal.shape(), "input");
463
464 std::unique_ptr<GlobalData> data =
465 client_->TransferToServer(literal).ConsumeValueOrDie();
466 ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3));
467 }
468
469 // As above, but for {1, 0} layout.
XLA_TEST_F(ParamsTest,R2_2x2_Layout_10)470 XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
471 Literal literal = LiteralUtil::CreateR2WithLayout<float>(
472 {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0}));
473 XlaBuilder builder(TestName());
474 Parameter(&builder, 0, literal.shape(), "input");
475
476 std::unique_ptr<GlobalData> data =
477 client_->TransferToServer(literal).ConsumeValueOrDie();
478 ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3));
479 }
480
// Transfers a 2x2 literal whose layout has been reversed relative to the shape
// the computation's parameter was declared with, then slices out a single
// off-diagonal element and checks its value.
XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
  Literal input = LiteralUtil::CreateR2<float>({
      {1, 3},
      {2, 4},
  });
  // Snapshot of the shape (including layout) before the layout is modified.
  const Shape declared_shape = input.shape();

  // Reverse the layout present in the snapshot, and make that the layout of
  // the literal.
  const auto& minor_to_major = declared_shape.layout().minor_to_major();
  std::vector<int64> flipped_order(minor_to_major.begin(),
                                   minor_to_major.end());
  std::reverse(flipped_order.begin(), flipped_order.end());
  *input.mutable_shape_do_not_use()->mutable_layout() =
      LayoutUtil::MakeLayout(flipped_order);
  ASSERT_EQ(2, input.Get<float>({0, 1}));

  // Use the original shape in building the computation.
  XlaBuilder builder(TestName());
  auto param = Parameter(&builder, 0, declared_shape, "input");
  // Use the slice operator to get an off-diagonal element.
  Slice(param, {0, 1}, {1, 2}, {1, 1});

  std::unique_ptr<GlobalData> input_handle =
      client_->TransferToServer(input).ConsumeValueOrDie();

  // Check that we got the off-diagonal value that we expected.
  Array2D<float> expected_slice(1, 1);
  expected_slice(0, 0) = 2;
  ComputeAndCompareR2(&builder, expected_slice, {input_handle.get()},
                      ErrorSpec(1e-3));
}
511
512 } // namespace
513 } // namespace xla
514