1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <string>
18 #include <vector>
19
20 #include "tensorflow/compiler/xla/client/client_library.h"
21 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
22 #include "tensorflow/compiler/xla/client/local_client.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/platform_util.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
32 #include "tensorflow/compiler/xla/tests/test_macros.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/test.h"
37 #include "tensorflow/core/platform/test_benchmark.h"
38
39 namespace xla {
40 namespace {
41
42 class WhileTest : public ClientLibraryTestBase {};
43
44 // Tests a while node when the result type T is S32.
45 //
46 // int32_t result = 0;
47 // while (result < 5) {
48 // result = result + 1;
49 // }
XLA_TEST_F(WhileTest,WhileWithScalarS32Result)50 XLA_TEST_F(WhileTest, WhileWithScalarS32Result) {
51 auto result_shape = ShapeUtil::MakeShape(S32, {});
52
53 // Create a computation for the condition: repeat for 5 iterations.
54 XlaComputation condition;
55 {
56 XlaBuilder builder("condition");
57 auto prev = Parameter(&builder, 0, result_shape, "prev");
58 Gt(ConstantR0<int32_t>(&builder, 5), prev);
59 condition = builder.Build().value();
60 }
61
62 // Create a computation for the body: add 1 to the result variable.
63 XlaComputation body;
64 {
65 XlaBuilder builder("body");
66 auto prev = Parameter(&builder, 0, result_shape, "prev");
67 auto input = ConstantR0<int32_t>(&builder, 1);
68 Add(input, prev);
69 body = builder.Build().value();
70 }
71
72 // Create a While node with computations for the condition and the body.
73 XlaBuilder builder(TestName());
74 auto init = ConstantR0<int32_t>(&builder, 0);
75 While(condition, body, init);
76
77 ComputeAndCompareR0<int32_t>(&builder, 5, {});
78 }
79
80 // Tests a while node when the result type T is S64.
81 //
82 // int32_t result = 0;
83 // while (result < 5) {
84 // result = result + 1;
85 // }
XLA_TEST_F(WhileTest,WhileWithScalarS64Result)86 XLA_TEST_F(WhileTest, WhileWithScalarS64Result) {
87 auto result_shape = ShapeUtil::MakeShape(S64, {});
88
89 // Create a computation for the condition: repeat for 5 iterations.
90 XlaComputation condition;
91 {
92 XlaBuilder builder("condition");
93 auto prev = Parameter(&builder, 0, result_shape, "prev");
94 Gt(ConstantR0<int64_t>(&builder, 5), prev);
95 condition = builder.Build().value();
96 }
97
98 // Create a computation for the body: add 1 to the result variable.
99 XlaComputation body;
100 {
101 XlaBuilder builder("body");
102 auto prev = Parameter(&builder, 0, result_shape, "prev");
103 auto input = ConstantR0<int64_t>(&builder, 1);
104 Add(input, prev);
105 body = builder.Build().value();
106 }
107
108 // Create a While node with computations for the condition and the body.
109 XlaBuilder builder(TestName());
110 auto init = ConstantR0<int64_t>(&builder, 0);
111 While(condition, body, init);
112
113 ComputeAndCompareR0<int64_t>(&builder, 5, {});
114 }
115
// Tests a while node whose S32 scalar init value is not a constant. The init
// is computed by reduce-adding the R1 constant {1, 1}, giving 2. The loop
// then increments the value until it reaches 5.
XLA_TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
  auto result_shape = ShapeUtil::MakeShape(S32, {});

  // Create a computation for the condition: continue while 5 > prev.
  // Starting from the reduced init value of 2, this runs 3 iterations.
  XlaComputation condition;
  {
    XlaBuilder builder("condition");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    Gt(ConstantR0<int32_t>(&builder, 5), prev);
    // Report a build failure as a test failure instead of crashing.
    TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
  }

  // Create a computation for the body: add 1 to the result variable.
  XlaComputation body;
  {
    XlaBuilder builder("body");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto input = ConstantR0<int32_t>(&builder, 1);
    Add(input, prev);
    TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
  }

  // Create a While node whose init operand is produced by a Reduce rather
  // than a constant: sum({1, 1}) = 2.
  XlaBuilder builder(TestName());
  auto init = Reduce(ConstantR1<int32_t>(&builder, 2, 1),
                     ConstantR0<int32_t>(&builder, 0),
                     CreateScalarAddComputation(S32, &builder), {0});
  While(condition, body, init);

  ComputeAndCompareR0<int32_t>(&builder, 5, {});
}
148
// Tests a while node whose loop-carried value is a PRED scalar.
//
// The init value is computed at run time as (false != true), i.e. true. The
// condition continues only while the value differs from true, so for this
// init it is false from the start. The body would OR the value with true.
// Either way the expected result is true.
XLA_TEST_F(WhileTest, WhileWithPredicateResult) {
  const auto pred_scalar = ShapeUtil::MakeShape(PRED, {});

  // Condition: true != flag.
  XlaComputation condition;
  {
    XlaBuilder cond_b("condition");
    auto flag = Parameter(&cond_b, 0, pred_scalar, "prev");
    auto true_const = ConstantR0<bool>(&cond_b, true);
    Ne(true_const, flag);
    condition = cond_b.Build().value();
  }

  // Body: flag | true.
  XlaComputation body;
  {
    XlaBuilder body_b("body");
    auto flag = Parameter(&body_b, 0, pred_scalar, "prev");
    auto true_const = ConstantR0<bool>(&body_b, true);
    Or(flag, true_const);
    body = body_b.Build().value();
  }

  // Init is a computed (non-constant) predicate: false != true.
  XlaBuilder builder(TestName());
  auto lhs = ConstantR0<bool>(&builder, false);
  auto rhs = ConstantR0<bool>(&builder, true);
  auto init = Ne(lhs, rhs);
  While(condition, body, init);

  ComputeAndCompareR0<bool>(&builder, true, {});
}
178
179 // Tests a while node when the result type T is a vector.
180 //
181 // All constants are chosen to produce exact results.
182 // vector<float> result(0);
183 // while (result.sum() < 15.5f) {
184 // result = result + vector<float>(0);
185 // }
XLA_TEST_F(WhileTest,DISABLED_ON_INTERPRETER (WhileWithEmptyVectorResult))186 XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) {
187 Shape result_shape = ShapeUtil::MakeShape(F32, {0});
188
189 // Create a computation for the reduction.
190 XlaComputation add;
191 {
192 XlaBuilder builder("add");
193 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
194 auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
195 Add(x, y);
196 add = builder.Build().value();
197 }
198
199 // Create a computation for the condition.
200 // Repeat until the sum of the result vector is less than 15.5f.
201 XlaComputation condition;
202 {
203 XlaBuilder builder("condition");
204 auto prev = Parameter(&builder, 0, result_shape, "prev");
205 auto sum = Reduce(prev, ConstantR0<float>(&builder, 0.0f), add,
206 /*dimensions_to_reduce=*/{0});
207 Gt(ConstantR0<float>(&builder, 15.5f), sum);
208 condition = builder.Build().value();
209 }
210
211 // Create a computation for the body.
212 // Add a constant vector of 1.f to the result vector.
213 XlaComputation body;
214 {
215 XlaBuilder builder("body");
216 auto prev = Parameter(&builder, 0, result_shape, "prev");
217 auto input = ConstantR1<float>(&builder, {});
218 Add(input, prev);
219 body = builder.Build().value();
220 }
221
222 // Create a While node with computations for the condition and the body.
223 XlaBuilder builder("while");
224 auto init = ConstantR1<float>(&builder, {});
225 auto result = While(condition, body, init);
226 VLOG(2) << "while = "
227 << ShapeUtil::HumanString(builder.GetShape(result).value());
228
229 ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.0001));
230 }
231
232 // Tests a while node when the result type T is a vector.
233 //
234 // All constants are chosen to produce exact results.
235 // vector<float> result(8, 0.0f);
236 // while (result.sum() < 15.5f) {
237 // result = result + vector<float>(8, 0.125f);
238 // }
XLA_TEST_F(WhileTest,WhileWithVectorResult)239 XLA_TEST_F(WhileTest, WhileWithVectorResult) {
240 Shape result_shape = ShapeUtil::MakeShape(F32, {8});
241
242 // Create a computation for the reduction.
243 XlaComputation add;
244 {
245 XlaBuilder builder("add");
246 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
247 auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
248 Add(x, y);
249 add = builder.Build().value();
250 }
251
252 // Create a computation for the condition.
253 // Repeat until the sum of the result vector is less than 5.5f.
254 XlaComputation condition;
255 {
256 XlaBuilder builder("condition");
257 auto prev = Parameter(&builder, 0, result_shape, "prev");
258 auto sum = Reduce(prev, ConstantR0<float>(&builder, 0.0f), add,
259 /*dimensions_to_reduce=*/{0});
260 Gt(ConstantR0<float>(&builder, 15.5f), sum);
261 condition = builder.Build().value();
262 }
263
264 // Create a computation for the body.
265 // Add a constant vector of 1.f to the result vector.
266 XlaComputation body;
267 {
268 XlaBuilder builder("body");
269 auto prev = Parameter(&builder, 0, result_shape, "prev");
270 auto input = ConstantR1<float>(&builder, 8, 0.125f);
271 Add(input, prev);
272 body = builder.Build().value();
273 }
274
275 // Create a While node with computations for the condition and the body.
276 XlaBuilder builder("while");
277 auto init = ConstantR1<float>(&builder, 8, 0.f);
278 auto result = While(condition, body, init);
279 VLOG(2) << "while = "
280 << ShapeUtil::HumanString(builder.GetShape(result).value());
281
282 // Individual elements with increase by 1/8 each time through the loop, so
283 // the sum will increase by 1.0. It will first be >15.5 when the elements
284 // have all reached 2.0.
285 std::vector<float> expected = {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f};
286 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
287 }
288
289 // Tests a while node when the result type is a vector which is part
290 // of the result tuple.
291 //
292 // All constants are chosen to produce exact results.
293 // vector<float> result(8, 0.0f);
294 // while (result.sum() < 15.5f) {
295 // result = result + vector<float>(8, 0.125f);
296 // }
297 // tuple = tuple { while }
XLA_TEST_F(WhileTest,WhileWithVectorResultIntoTuple)298 XLA_TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
299 Shape result_shape = ShapeUtil::MakeShape(F32, {8});
300
301 // Create a computation for the reduction.
302 XlaComputation add;
303 {
304 XlaBuilder builder("add");
305 auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
306 auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
307 Add(x, y);
308 add = builder.Build().value();
309 }
310
311 // Create a computation for the condition.
312 // Repeat until the sum of the result vector is less than 5.5f.
313 XlaComputation condition;
314 {
315 XlaBuilder builder("condition");
316 auto prev = Parameter(&builder, 0, result_shape, "prev");
317 auto sum = Reduce(prev, ConstantR0<float>(&builder, 0.0f), add,
318 /*dimensions_to_reduce=*/{0});
319 Gt(ConstantR0<float>(&builder, 15.5f), sum);
320 condition = builder.Build().value();
321 }
322
323 // Create a computation for the body.
324 // Add a constant vector of 1.f to the result vector.
325 XlaComputation body;
326 {
327 XlaBuilder builder("body");
328 auto prev = Parameter(&builder, 0, result_shape, "prev");
329 auto input = ConstantR1<float>(&builder, 8, 0.125f);
330 Add(input, prev);
331 body = builder.Build().value();
332 }
333
334 // Create a While node with computations for the condition and the body.
335 XlaBuilder builder("while");
336 auto init = ConstantR1<float>(&builder, 8, 0.f);
337 auto result = While(condition, body, init);
338 VLOG(2) << "while = "
339 << ShapeUtil::HumanString(builder.GetShape(result).value());
340 Tuple(&builder, {result});
341
342 // Individual elements with increase by 1/8 each time through the loop, so
343 // the sum will increase by 1.0. It will first be >15.5 when the elements
344 // have all reached 2.0.
345 auto expected_data =
346 LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
347 auto expected = LiteralUtil::MakeTuple({&expected_data});
348 VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
349 ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
350 }
351
// Tests a while node whose state tuple holds an S32 counter plus three F32[3]
// vectors that are rotated on every iteration:
//   (counter, w1, w2, w3) <- (counter + 1, w3, w1, w2)
// Starting from (0, 1s, 2s, 3s), two iterations yield (2, 2s, 3s, 1s).
XLA_TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
  const Shape s32_scalar = ShapeUtil::MakeShape(S32, {});
  const Shape f32_vec3 = ShapeUtil::MakeShape(F32, {3});
  const std::vector<Shape> element_shapes = {s32_scalar, f32_vec3, f32_vec3,
                                             f32_vec3};
  Shape result_shape = ShapeUtil::MakeTupleShape(element_shapes);

  // Condition: continue while kIterations > counter.
  constexpr int kIterations = 2;
  XlaComputation condition;
  {
    XlaBuilder cond_b("condition");
    auto state = Parameter(&cond_b, 0, result_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto limit = ConstantR0<int32_t>(&cond_b, kIterations);
    Gt(limit, counter);
    condition = cond_b.Build().value();
  }

  // Body: bump the counter and rotate the three vectors right by one slot.
  XlaComputation body;
  {
    XlaBuilder body_b("body");
    auto state = Parameter(&body_b, 0, result_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto first = GetTupleElement(state, 1);
    auto second = GetTupleElement(state, 2);
    auto third = GetTupleElement(state, 3);
    auto next_counter = Add(counter, ConstantR0<int32_t>(&body_b, 1));
    Tuple(&body_b, {next_counter, third, first, second});
    body = body_b.Build().value();
  }

  XlaBuilder builder("while");
  auto init = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0),
                               ConstantR1<float>(&builder, 3, 1.f),
                               ConstantR1<float>(&builder, 3, 2.f),
                               ConstantR1<float>(&builder, 3, 3.f)});
  auto result = While(condition, body, init);
  VLOG(2) << "result = "
          << ShapeUtil::HumanString(builder.GetShape(result).value());

  // After two rotations the slots hold the 2s, 3s and 1s vectors, in order.
  auto want_counter = LiteralUtil::CreateR0<int32_t>(kIterations);
  auto want_ones = LiteralUtil::CreateR1<float>({1.0f, 1.0f, 1.0f});
  auto want_twos = LiteralUtil::CreateR1<float>({2.0f, 2.0f, 2.0f});
  auto want_threes = LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f});
  auto expected = LiteralUtil::MakeTuple(
      {&want_counter, &want_twos, &want_threes, &want_ones});
  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
404
// Tests a while node whose state tuple holds an S32 counter plus three F32[3]
// vectors that are rotated on every iteration. Only the element-wise sum of
// the three vectors is checked after the loop. The body merely permutes them,
// so that sum is 1 + 2 + 3 = 6 for every element.
XLA_TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
  const Shape s32_scalar = ShapeUtil::MakeShape(S32, {});
  const Shape f32_vec3 = ShapeUtil::MakeShape(F32, {3});
  const std::vector<Shape> element_shapes = {s32_scalar, f32_vec3, f32_vec3,
                                             f32_vec3};
  Shape result_shape = ShapeUtil::MakeTupleShape(element_shapes);

  // Condition: continue while kIterations > counter.
  constexpr int kIterations = 2;
  XlaComputation condition;
  {
    XlaBuilder cond_b("condition");
    auto state = Parameter(&cond_b, 0, result_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto limit = ConstantR0<int32_t>(&cond_b, kIterations);
    Gt(limit, counter);
    condition = cond_b.Build().value();
  }

  // Body: add 1 to the iteration variable and permute the weights,
  // (w1, w2, w3) -> (w3, w1, w2).
  XlaComputation body;
  {
    XlaBuilder body_b("body");
    auto state = Parameter(&body_b, 0, result_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto first = GetTupleElement(state, 1);
    auto second = GetTupleElement(state, 2);
    auto third = GetTupleElement(state, 3);
    auto next_counter = Add(counter, ConstantR0<int32_t>(&body_b, 1));
    Tuple(&body_b, {next_counter, third, first, second});
    body = body_b.Build().value();
  }

  XlaBuilder builder("while");
  auto init = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0),
                               ConstantR1<float>(&builder, 3, 1.f),
                               ConstantR1<float>(&builder, 3, 2.f),
                               ConstantR1<float>(&builder, 3, 3.f)});
  auto final_state = While(condition, body, init);

  // Sum the three vectors; this Add is the computation's root.
  auto v1 = GetTupleElement(final_state, 1);
  auto v2 = GetTupleElement(final_state, 2);
  auto v3 = GetTupleElement(final_state, 3);
  auto result = Add(Add(v1, v2), v3);
  VLOG(2) << "result = "
          << ShapeUtil::HumanString(builder.GetShape(result).value());
  const std::vector<float> expected(3, 6.f);
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
454
455 // Tests a while node when the result type T is a Tuple.
456 //
457 // tuple<int32_t, vector<float>> result(0, vector<float>(10, 0.0f));
458 // while (get<0>(result) < 5) {
459 // get<0>(result) = get<0>(result) + 1;
460 // get<1>(result) = get<1>(result) + vector<float>(10, 1.0f);
461 // }
XLA_TEST_F(WhileTest,WhileWithTupleResult)462 XLA_TEST_F(WhileTest, WhileWithTupleResult) {
463 std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
464 ShapeUtil::MakeShape(F32, {10})};
465 Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
466
467 // Create a computation for the condition.
468 // Repeat for 5 iterations.
469 XlaComputation condition;
470 {
471 XlaBuilder builder("condition");
472 auto prev = Parameter(&builder, 0, result_shape, "prev");
473 auto iteration = GetTupleElement(prev, 0);
474 Gt(ConstantR0<int32_t>(&builder, 5), iteration);
475 condition = builder.Build().value();
476 }
477
478 // Create a computation for the body.
479 // Add 1 to the iteration variable and add a constant vector of 1.0f to
480 // the weight variable, both of which are tuple elements.
481 XlaComputation body;
482 {
483 XlaBuilder builder("body");
484 auto prev = Parameter(&builder, 0, result_shape, "prev");
485 auto iteration = GetTupleElement(prev, 0);
486 auto weights = GetTupleElement(prev, 1);
487 auto input = ConstantR1<float>(&builder, 10, 1.f);
488 auto new_weights = Add(weights, input);
489 Tuple(&builder,
490 {Add(iteration, ConstantR0<int32_t>(&builder, 1)), new_weights});
491 body = builder.Build().value();
492 }
493
494 // Create a While node with computations for the condition and the body.
495 XlaBuilder builder("while");
496 auto init = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0),
497 ConstantR1<float>(&builder, 10, 0.f)});
498 auto result = While(condition, body, init);
499 VLOG(2) << "while = "
500 << ShapeUtil::HumanString(builder.GetShape(result).value());
501
502 auto expected_counter = LiteralUtil::CreateR0<int32_t>(5);
503 auto expected_data = LiteralUtil::CreateR1<float>(
504 {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
505 auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
506 VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
507 ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
508 }
509
// Tests a while node whose state tuple is (S32 counter, PRED flag). The
// counter runs from 0 to 5. Each iteration ORs the flag with true. The flag
// starts as the computed value (false != true), i.e. true.
XLA_TEST_F(WhileTest, WhileWithPredicateTupleResult) {
  const std::vector<Shape> element_shapes = {ShapeUtil::MakeShape(S32, {}),
                                             ShapeUtil::MakeShape(PRED, {})};
  Shape state_shape = ShapeUtil::MakeTupleShape(element_shapes);

  // Condition: continue while 5 > counter (five iterations from 0).
  XlaComputation condition;
  {
    XlaBuilder cond_b("condition");
    auto state = Parameter(&cond_b, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto limit = ConstantR0<int32_t>(&cond_b, 5);
    Gt(limit, counter);
    condition = cond_b.Build().value();
  }

  // Body: counter + 1, and flag | true.
  XlaComputation body;
  {
    XlaBuilder body_b("body");
    auto state = Parameter(&body_b, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto flag = GetTupleElement(state, 1);
    auto next_flag = Or(flag, ConstantR0<bool>(&body_b, true));
    auto next_counter = Add(counter, ConstantR0<int32_t>(&body_b, 1));
    Tuple(&body_b, {next_counter, next_flag});
    body = body_b.Build().value();
  }

  XlaBuilder builder("while");
  auto start_counter = ConstantR0<int32_t>(&builder, 0);
  auto start_flag =
      Ne(ConstantR0<bool>(&builder, false), ConstantR0<bool>(&builder, true));
  auto init = Tuple(&builder, {start_counter, start_flag});
  auto result = While(condition, body, init);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(builder.GetShape(result).value());

  auto want_counter = LiteralUtil::CreateR0<int32_t>(5);
  auto want_flag = LiteralUtil::CreateR0<bool>(true);
  auto expected = LiteralUtil::MakeTuple({&want_counter, &want_flag});
  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0));
}
555
// Tests a while node whose tuple state is (S32 counter, S32 value). The
// counter runs from 0 to 5. The body always overwrites the second element
// with the constant 7, which is also its initial value.
XLA_TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
  const Shape s32_scalar = ShapeUtil::MakeShape(S32, {});
  const std::vector<Shape> element_shapes = {s32_scalar, s32_scalar};
  Shape state_shape = ShapeUtil::MakeTupleShape(element_shapes);

  // Condition: continue while 5 > counter (five iterations from 0).
  XlaComputation condition;
  {
    XlaBuilder cond_b("condition");
    auto state = Parameter(&cond_b, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto limit = ConstantR0<int32_t>(&cond_b, 5);
    Gt(limit, counter);
    condition = cond_b.Build().value();
  }

  // Body: counter + 1; the second element is replaced by the constant 7.
  XlaComputation body;
  {
    XlaBuilder body_b("body");
    auto state = Parameter(&body_b, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto next_counter = Add(counter, ConstantR0<int32_t>(&body_b, 1));
    auto seven = ConstantR0<int32_t>(&body_b, 7);
    Tuple(&body_b, {next_counter, seven});
    body = body_b.Build().value();
  }

  XlaBuilder builder("while");
  auto start_counter = ConstantR0<int32_t>(&builder, 0);
  auto start_value = ConstantR0<int32_t>(&builder, 7);
  auto init = Tuple(&builder, {start_counter, start_value});
  auto result = While(condition, body, init);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(builder.GetShape(result).value());

  auto want_counter = LiteralUtil::CreateR0<int32_t>(5);
  auto want_value = LiteralUtil::CreateR0<int32_t>(7);
  auto expected = LiteralUtil::MakeTuple({&want_counter, &want_value});
  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
599
600 // Tests two while nodes when the result type T is a Tuple and the second
601 // while node uses the result of the first while node which is used in two
602 // nodes.
603 // tuple<int32_t, vector<float>> w0(0, vector<float>(10, 0.0f));
604 // w0 = while (get<0>(w0) < c1) {
605 // get<0>(w0) = get<0>(w0) + 1;
606 // get<1>(w0) = get<1>(w0) + vector<float>(10, 1.0f);
607 // }
608 // tuple<int32_t, vector<float>> w1(get<0>(w0), get<1>(w0));
609 // w1 = while (get<0>(w1) < c2) {
610 // get<0>(w1) = get<0>(w1) + 1;
611 // get<1>(w1) = get<1>(w1) + vector<float>(10, 1.0f);
612 // }
613 // result = get<1>(w0) + get<1>(w1)
XLA_TEST_F(WhileTest,TwoWhileWithTupleResult)614 XLA_TEST_F(WhileTest, TwoWhileWithTupleResult) {
615 std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
616 ShapeUtil::MakeShape(F32, {10})};
617 Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
618
619 // Create a computation for the condition.
620 // Repeat for 5 iterations.
621 XlaComputation condition;
622 const int c1 = 5;
623 {
624 XlaBuilder builder("condition");
625 auto prev = Parameter(&builder, 0, result_shape, "prev");
626 auto iteration = GetTupleElement(prev, 0);
627 Lt(iteration, ConstantR0<int32_t>(&builder, c1));
628 TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
629 }
630
631 XlaComputation condition2;
632 const int c2 = 7;
633 {
634 XlaBuilder builder("condition2");
635 auto prev = Parameter(&builder, 0, result_shape, "prev");
636 auto iteration = GetTupleElement(prev, 0);
637 Lt(iteration, ConstantR0<int32_t>(&builder, c2));
638 TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build());
639 }
640
641 // Create a computation for the body.
642 // Add 1 to the iteration variable and add a constant vector of 1.0f to
643 // the weight variable, both of which are tuple elements.
644 XlaComputation body;
645 {
646 XlaBuilder builder("body");
647 auto prev = Parameter(&builder, 0, result_shape, "prev");
648 auto iteration = GetTupleElement(prev, 0);
649 auto weights = GetTupleElement(prev, 1);
650 auto input = ConstantR1<float>(&builder, 10, 1.f);
651 auto new_weights = Add(weights, input);
652 Tuple(&builder,
653 {Add(iteration, ConstantR0<int32_t>(&builder, 1)), new_weights});
654 TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
655 }
656
657 XlaComputation body2;
658 {
659 XlaBuilder builder("body");
660 auto prev = Parameter(&builder, 0, result_shape, "prev");
661 auto iteration = GetTupleElement(prev, 0);
662 auto weights = GetTupleElement(prev, 1);
663 auto input = ConstantR1<float>(&builder, 10, 1.f);
664 auto new_weights = Add(weights, input);
665 Tuple(&builder,
666 {Add(iteration, ConstantR0<int32_t>(&builder, 1)), new_weights});
667 TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build());
668 }
669
670 // Create a While node with computations for the condition and the body.
671 XlaBuilder builder("while");
672 auto init = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0),
673 ConstantR1<float>(&builder, 10, 0.f)});
674 auto while1 = While(condition, body, init);
675
676 auto while2 = While(condition2, body2, while1);
677
678 auto while_result1 = GetTupleElement(while1, 1);
679 auto while_result2 = GetTupleElement(while2, 1);
680 VLOG(2) << "while_result2 = "
681 << ShapeUtil::HumanString(builder.GetShape(while_result2).value());
682 auto result = Add(while_result1, while_result2);
683 VLOG(2) << "result = "
684 << ShapeUtil::HumanString(builder.GetShape(result).value());
685 const float sum = c1 + c2;
686 std::vector<float> expected(10, sum);
687 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
688 }
689
690 // Test while nodes that share the while body computation.
XLA_TEST_F(WhileTest,TwoWhileLoopsAndSharedBody)691 XLA_TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
692 std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
693 ShapeUtil::MakeShape(F32, {10})};
694 Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
695
696 // Create a computation for the condition.
697 // Repeat for 5 iterations.
698 XlaComputation condition;
699 const int c1 = 5;
700 {
701 XlaBuilder builder("condition");
702 auto prev = Parameter(&builder, 0, result_shape, "prev");
703 auto iteration = GetTupleElement(prev, 0);
704 Lt(iteration, ConstantR0<int32_t>(&builder, c1));
705 TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
706 }
707
708 XlaComputation condition2;
709 const int c2 = 7;
710 {
711 XlaBuilder builder("condition2");
712 auto prev = Parameter(&builder, 0, result_shape, "prev");
713 auto iteration = GetTupleElement(prev, 0);
714 Lt(iteration, ConstantR0<int32_t>(&builder, c2));
715 TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build());
716 }
717
718 // Create a computation for the body.
719 // Add 1 to the iteration variable and add a constant vector of 1.0f to
720 // the weight variable, both of which are tuple elements.
721 XlaComputation body;
722 {
723 XlaBuilder builder("body");
724 auto prev = Parameter(&builder, 0, result_shape, "prev");
725 auto iteration = GetTupleElement(prev, 0);
726 auto weights = GetTupleElement(prev, 1);
727 auto input = ConstantR1<float>(&builder, 10, 1.f);
728 auto new_weights = Add(weights, input);
729 Tuple(&builder,
730 {Add(iteration, ConstantR0<int32_t>(&builder, 1)), new_weights});
731 TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
732 }
733
734 // Create a While node with computations for the condition and the body.
735 XlaBuilder builder("while");
736 auto init = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0),
737 ConstantR1<float>(&builder, 10, 0.f)});
738 auto while1 = While(condition, body, init);
739
740 auto while2 = While(condition2, body, while1);
741
742 auto while_result1 = GetTupleElement(while1, 1);
743 auto while_result2 = GetTupleElement(while2, 1);
744 VLOG(2) << "while_result2 = "
745 << ShapeUtil::HumanString(builder.GetShape(while_result2).value());
746 auto result = Add(while_result1, while_result2);
747 VLOG(2) << "result = "
748 << ShapeUtil::HumanString(builder.GetShape(result).value());
749 const float sum = c1 + c2;
750 std::vector<float> expected(10, sum);
751 ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
752 }
753
// Tests two independent while nodes that share both the body computation and
// the init operand. They differ only in their limits. One loop's weights stop
// at kFirstLimit and the other's at kSecondLimit. The test checks their
// element-wise sum.
XLA_TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) {
  const std::vector<Shape> element_shapes = {ShapeUtil::MakeShape(S32, {}),
                                             ShapeUtil::MakeShape(F32, {10})};
  Shape state_shape = ShapeUtil::MakeTupleShape(element_shapes);

  // First condition: counter < kFirstLimit.
  constexpr int kFirstLimit = 5;
  XlaComputation condition;
  {
    XlaBuilder cond_b("condition");
    auto state = Parameter(&cond_b, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto limit = ConstantR0<int32_t>(&cond_b, kFirstLimit);
    Lt(counter, limit);
    TF_ASSERT_OK_AND_ASSIGN(condition, cond_b.Build());
  }

  // Second condition: counter < kSecondLimit.
  constexpr int kSecondLimit = 7;
  XlaComputation condition2;
  {
    XlaBuilder cond_b("condition2");
    auto state = Parameter(&cond_b, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto limit = ConstantR0<int32_t>(&cond_b, kSecondLimit);
    Lt(counter, limit);
    TF_ASSERT_OK_AND_ASSIGN(condition2, cond_b.Build());
  }

  // Shared body: counter + 1, and weights plus a vector of ten 1.0f.
  XlaComputation body;
  {
    XlaBuilder body_b("body");
    auto state = Parameter(&body_b, 0, state_shape, "prev");
    auto counter = GetTupleElement(state, 0);
    auto weights = GetTupleElement(state, 1);
    auto ones = ConstantR1<float>(&body_b, 10, 1.f);
    auto next_weights = Add(weights, ones);
    auto next_counter = Add(counter, ConstantR0<int32_t>(&body_b, 1));
    Tuple(&body_b, {next_counter, next_weights});
    TF_ASSERT_OK_AND_ASSIGN(body, body_b.Build());
  }

  // Both loops start from the same init tuple.
  XlaBuilder builder("while");
  auto init = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0),
                               ConstantR1<float>(&builder, 10, 0.f)});
  auto short_loop = While(condition, body, init);
  auto long_loop = While(condition2, body, init);

  auto short_weights = GetTupleElement(short_loop, 1);
  auto long_weights = GetTupleElement(long_loop, 1);
  VLOG(2) << "while_result2 = "
          << ShapeUtil::HumanString(builder.GetShape(long_weights).value());
  auto result = Add(short_weights, long_weights);
  VLOG(2) << "result = "
          << ShapeUtil::HumanString(builder.GetShape(result).value());
  const std::vector<float> expected(
      10, static_cast<float>(kFirstLimit + kSecondLimit));
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
815
// WhileTest that uses a DynamicUpdateSlice instruction in the body computation.
// The only user of loop-state tuple element 1 is operand 0 of the
// DynamicUpdateSlice, which will trigger an in-place dynamic-slice update on
// GPU.
XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
  std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
                                       ShapeUtil::MakeShape(F32, {10})};
  Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);

  // Create a computation for the condition.
  // Repeat for 5 iterations.
  XlaComputation condition;
  {
    XlaBuilder builder("condition");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto iteration = GetTupleElement(prev, 0);
    Gt(ConstantR0<int32_t>(&builder, 5), iteration);
    TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
  }

  // Create a computation for the body.
  // Add 1 to the iteration variable (tuple element 0). In the F32[10] buffer
  // (tuple element 1), overwrite the two elements starting at 2 * iteration
  // with the incremented counter value, using DynamicUpdateSlice.
  XlaComputation body;
  {
    XlaBuilder builder("body");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    // TupleElement 0
    auto iteration = GetTupleElement(prev, 0);
    auto out0 = Add(iteration, ConstantR0<int32_t>(&builder, 1));
    // TupleElement 1
    auto input = GetTupleElement(prev, 1);
    // Update: a 2-element F32 vector filled with the new counter value.
    auto update = ConvertElementType(Broadcast(out0, {2}), F32);
    // Starts = iteration * 2;
    auto starts = Mul(iteration, ConstantR0<int32_t>(&builder, 2));
    // UpdateSlice.
    auto out1 = DynamicUpdateSlice(input, update, {starts});

    Tuple(&builder, {out0, out1});
    TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
  }

  // Create a While node with computations for the condition and the body.
  XlaBuilder builder("while");
  auto init = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0),
                               ConstantR1<float>(&builder, 10, 0.f)});
  auto result = While(condition, body, init);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(builder.GetShape(result).value());

  // Iteration i (0-based) writes i + 1 into elements [2i, 2i + 1].
  auto expected_counter = LiteralUtil::CreateR0<int32_t>(5);
  auto expected_data = LiteralUtil::CreateR1<float>(
      {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f});
  auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
873
874 // Tests a while node when the result type T is a vector of S32.
875 //
876 // int32_t result = (0, 0, 0, 0, 0, 0);
877 // while (result[0] < count) {
878 // result += (1, U[0, 100], U[0, 100], U[0, 100], U[0, 100], U[0, 100]);
879 // }
880 //
// This test misuses a vector to represent a pair:
// ((iteration, (random vector))).
884 //
885 // Note: this test currently only tests generating random values within a loop.
886 // Per backend the values generated can be different as the different backends
887 // use different random number generators.
888 // TODO(b/32240857): Extend test to verify outputs.
XLA_TEST_F(WhileTest, WhileWithPrngScalarResult) {
  auto v6s32 = ShapeUtil::MakeShape(S32, {6});

  // Create a computation for the condition: repeat for count iterations.
  // Element 0 of the S32[6] state is the iteration counter; it is sliced out
  // and reshaped to a scalar before the comparison.
  auto build_condition = [this, v6s32](int count) {
    XlaBuilder builder(TestName());
    auto prev = Reshape(
        Slice(Parameter(&builder, 0, v6s32, "prev"), {0}, {1}, {1}), {0}, {});
    Gt(ConstantR0<int32_t>(&builder, count), prev);
    return builder.Build().value();
  };

  // Create a computation for the body: add 1 to the counter (element 0) and
  // add a value produced by RngUniform(0, 100) to each of elements 1..5.
  XlaComputation body;
  {
    XlaBuilder builder("body");
    auto prev = Parameter(&builder, 0, v6s32, "prev");
    auto inc = ConcatInDim(&builder,
                           {ConstantR1<int32_t>(&builder, {1}),
                            RngUniform(ConstantR0<int32_t>(&builder, 0),
                                       ConstantR0<int32_t>(&builder, 100),
                                       ShapeUtil::MakeShape(S32, {5}))},
                           0);
    Add(inc, prev);
    body = builder.Build().value();
  }

  // Create a While node with computations for the condition and the body.
  auto while_loop = [this, &body, build_condition](int count) {
    XlaBuilder builder(TestName());
    auto init = ConstantR1<int32_t>(&builder, {0, 0, 0, 0, 0, 0});
    While(build_condition(count), body, init);
    return builder.Build();
  };

  for (int i = 1; i < 4; ++i) {
    TF_ASSERT_OK_AND_ASSIGN(auto computation, while_loop(i));

    ExecutionOptions execution_options = execution_options_;
    execution_options.set_seed(65);
    // Backends use different random number generators, so the produced
    // values are not compared; only successful execution is checked.
    TF_ASSERT_OK(
        client_->ExecuteAndTransfer(computation, {}, &execution_options)
            .status());
  }
}
934
// Checks a while loop whose body replaces the tuple slot that initially holds
// the outer parameter with the other (constant) tuple element.
XLA_TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) {
  const Shape vec2_f32 = ShapeUtil::MakeShape(F32, {2});

  XlaBuilder outer("outer");
  XlaOp param = Parameter(&outer, 0, vec2_f32, "param");
  XlaOp loop_init = Tuple(&outer, {param, ConstantR1<float>(&outer, {1, 1})});

  TF_ASSERT_OK_AND_ASSIGN(Shape state_shape, outer.GetShape(loop_init));

  // Condition: keep looping while any lane of element 0 equals 42.
  XlaBuilder cond("cond");
  {
    XlaOp state = Parameter(&cond, 0, state_shape, "t");
    Any(Eq(GetTupleElement(state, 0), ConstantR1<float>(&cond, {42, 42})));
  }

  // Body: copy element 1 into both tuple slots.
  XlaBuilder body("body");
  {
    XlaOp state = Parameter(&body, 0, state_shape, "t");
    XlaOp second = GetTupleElement(state, 1);
    Tuple(&body, {second, second});
  }

  TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
  TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
  While(cond_computation, body_computation, loop_init);

  // Starting from ([42, 42], [1, 1]), one iteration yields ([1, 1], [1, 1]).
  auto ones = LiteralUtil::CreateR1<float>({1, 1});
  auto expected = LiteralUtil::MakeTuple({&ones, &ones});
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> param_data,
      client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
  ComputeAndCompareTuple(&outer, expected, {param_data.get()},
                         ErrorSpec(1e-6));
}
966
// Checks a while loop whose body ignores its parameter and returns a
// broadcast constant of the same shape.
XLA_TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) {
  const Shape vec2_f32 = ShapeUtil::MakeShape(F32, {2});

  XlaBuilder outer("outer");
  XlaOp param = Parameter(&outer, 0, vec2_f32, "param");

  // Condition: keep looping while any lane equals 42.
  XlaBuilder cond("cond");
  {
    XlaOp value = Parameter(&cond, 0, vec2_f32, "t");
    Any(Eq(value, ConstantR1<float>(&cond, {42, 42})));
  }

  // Body: discard the incoming value and produce [1.0, 1.0].
  XlaBuilder body("body");
  {
    Parameter(&body, 0, vec2_f32, "t");
    Broadcast(ConstantR0<float>(&body, 1.0), {2});
  }

  TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
  TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
  While(cond_computation, body_computation, param);

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> param_data,
      client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
  ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {param_data.get()},
                             ErrorSpec(1e-6));
}
991
// Checks a while loop whose body packs its scalar parameter into a tuple and
// returns one element of that tuple (the incremented value).
XLA_TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) {
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  XlaBuilder outer("outer");
  XlaOp param = Parameter(&outer, 0, scalar_f32, "param");

  // Condition: keep looping while the value equals 42.
  XlaBuilder cond("cond");
  {
    XlaOp value = Parameter(&cond, 0, scalar_f32, "t");
    Eq(value, ConstantR0<float>(&cond, 42));
  }

  // Body: build (t, t + 1) and return element 1, i.e. t + 1.
  XlaBuilder body("body");
  {
    XlaOp value = Parameter(&body, 0, scalar_f32, "t");
    XlaOp incremented = Add(value, ConstantR0<float>(&body, 1));
    XlaOp packed = Tuple(&body, {value, incremented});
    GetTupleElement(packed, 1);
  }

  TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
  TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
  While(cond_computation, body_computation, param);

  // 42 becomes 43 after one iteration, which ends the loop.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> param_data,
      client_->TransferToServer(LiteralUtil::CreateR0<float>(42)));
  ComputeAndCompareR0<float>(&outer, 43.0f, {param_data.get()},
                             ErrorSpec(1e-6));
}
1017
1018 // Tests loop where the init value comes from two sources (constant and
1019 // parameter).
1020 //
1021 // int32_t result = (0, 1);
1022 // while (result[0] + result[1] < 30) {
1023 // result[0] = result[0] + 1;
1024 // result[1] = result[1] + 1;
1025 // }
XLA_TEST_F(WhileTest, WhileWithMixedTupleElements) {
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});
  const Shape state_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});

  // Initial loop state: (constant 0, runtime parameter).
  XlaBuilder outer("outer");
  XlaOp loop_init =
      Tuple(&outer, {ConstantR0<int32_t>(&outer, 0),
                     Parameter(&outer, 0, scalar_s32, "t")});

  // Condition: element 1 + element 0 < 30.
  XlaBuilder cond("cond");
  {
    XlaOp state = Parameter(&cond, 0, state_shape, "prev");
    XlaOp total = Add(GetTupleElement(state, 1), GetTupleElement(state, 0));
    Lt(total, ConstantR0<int32_t>(&cond, 30));
  }

  // Body: increment both elements by 1.
  XlaBuilder body("body");
  {
    XlaOp state = Parameter(&body, 0, state_shape, "t");
    XlaOp next_first =
        Add(GetTupleElement(state, 0), ConstantR0<int32_t>(&body, 1));
    XlaOp next_second =
        Add(GetTupleElement(state, 1), ConstantR0<int32_t>(&body, 1));
    Tuple(&body, {next_first, next_second});
  }

  TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
  TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
  While(cond_computation, body_computation, loop_init);

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> param_data,
      client_->TransferToServer(LiteralUtil::CreateR0<int32_t>(1)));

  // From (0, 1) the sum grows by 2 per iteration: 15 iterations -> (15, 16).
  auto final_first = LiteralUtil::CreateR0<int32_t>(15);
  auto final_second = LiteralUtil::CreateR0<int32_t>(16);
  auto expected = LiteralUtil::MakeTuple({&final_first, &final_second});
  ComputeAndCompareTuple(&outer, expected, {param_data.get()},
                         ErrorSpec(1e-6));
}
1061
1062 // Tests nested while loops.
1063 //
1064 // int32_t result = 0;
1065 // while (result < 30) {
1066 // int i = 0;
1067 // while (i < 7) {
1068 // result = result + 2;
1069 // i = i + 1;
1070 // }
1071 // }
XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});
  // Outer loop state: the running result.
  const Shape outer_state_shape = scalar_s32;
  // Inner loop state: (i, result).
  const Shape inner_state_shape =
      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});

  // Inner loop condition: continue while i (element 0) < 7.
  XlaComputation inner_condition;
  {
    XlaBuilder b("inner_condition");
    XlaOp state = Parameter(&b, 0, inner_state_shape, "prev");
    XlaOp counter = GetTupleElement(state, 0);
    Lt(counter, ConstantR0<int32_t>(&b, 7));
    inner_condition = b.Build().value();
  }

  // Outer loop condition: continue while result < 30.
  XlaComputation outer_condition;
  {
    XlaBuilder b("outer_condition");
    XlaOp running = Parameter(&b, 0, outer_state_shape, "prev");
    Lt(running, ConstantR0<int32_t>(&b, 30));
    outer_condition = b.Build().value();
  }

  // Inner loop body: i += 1, result += 2.
  XlaComputation inner_body;
  {
    XlaBuilder b("inner_body");
    XlaOp state = Parameter(&b, 0, inner_state_shape, "prev");
    XlaOp counter = GetTupleElement(state, 0);
    XlaOp running = GetTupleElement(state, 1);
    XlaOp next_counter = Add(ConstantR0<int32_t>(&b, 1), counter);
    XlaOp next_running = Add(ConstantR0<int32_t>(&b, 2), running);
    Tuple(&b, {next_counter, next_running});
    inner_body = b.Build().value();
  }

  // Outer loop body: run the inner loop from (0, result) and keep its result.
  XlaComputation outer_body;
  {
    XlaBuilder b("outer_body");
    XlaOp running = Parameter(&b, 0, outer_state_shape, "prev");
    XlaOp inner_init = Tuple(&b, {ConstantR0<int32_t>(&b, 0), running});
    XlaOp inner_loop = While(inner_condition, inner_body, inner_init);
    GetTupleElement(inner_loop, 1);
    outer_body = b.Build().value();
  }

  // Each outer iteration adds 7 * 2 = 14: 0 -> 14 -> 28 -> 42.
  XlaBuilder builder(TestName());
  While(outer_condition, outer_body, ConstantR0<int32_t>(&builder, 0));

  ComputeAndCompareR0<int32_t>(&builder, 42, {});
}
1128
1129 // Tests a while node when the result type T is S32.
1130 // f = lambda result: tuple({result < 5})
1131 // int32_t result = 0;
1132 // while (f(result).get<0>()) {
1133 // result = result + 1;
1134 // }
XLA_TEST_F(WhileTest, WhileWithCallInsideCondition) {
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});

  // Callee invoked from the condition: returns the 1-tuple (5 > prev).
  XlaComputation condition_callee;
  {
    XlaBuilder b("condition_callee");
    XlaOp prev = Parameter(&b, 0, scalar_s32, "prev");
    XlaOp keep_going = Gt(ConstantR0<int32_t>(&b, 5), prev);
    Tuple(&b, {keep_going});
    condition_callee = b.Build().value();
  }

  // Condition: call the callee and unwrap element 0 of its tuple result.
  XlaComputation condition;
  {
    XlaBuilder b("condition");
    XlaOp prev = Parameter(&b, 0, scalar_s32, "prev");
    XlaOp called = Call(&b, condition_callee, {prev});
    GetTupleElement(called, 0);
    condition = b.Build().value();
  }

  // Body: 1 + prev.
  XlaComputation body;
  {
    XlaBuilder b("body");
    XlaOp prev = Parameter(&b, 0, scalar_s32, "prev");
    XlaOp one = ConstantR0<int32_t>(&b, 1);
    Add(one, prev);
    body = b.Build().value();
  }

  // Counting up from 0 stops once the value reaches 5.
  XlaBuilder builder(TestName());
  While(condition, body, ConstantR0<int32_t>(&builder, 0));

  ComputeAndCompareR0<int32_t>(&builder, 5, {});
}
1174
// Checks a while loop whose body recomputes tanh(dot(a, b)) from two state
// elements that it passes through unchanged (a loop-invariant computation).
XLA_TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
  const Shape f32_2x2 = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape s32_scalar = ShapeUtil::MakeShape(S32, {});
  // Loop state: (counter, a, b, output).
  const Shape state_shape =
      ShapeUtil::MakeTupleShape({s32_scalar, f32_2x2, f32_2x2, f32_2x2});

  // Condition: keep looping while 5 > counter.
  XlaComputation condition;
  {
    XlaBuilder b("condition");
    XlaOp state = Parameter(&b, 0, state_shape, "state");
    Gt(ConstantR0<int32_t>(&b, 5), GetTupleElement(state, 0));
    TF_ASSERT_OK_AND_ASSIGN(condition, b.Build());
  }

  // Body: output = tanh(dot(a, b)); counter += 1; a and b are unchanged.
  XlaComputation body;
  {
    XlaBuilder b("body");
    XlaOp state = Parameter(&b, 0, state_shape, "state");
    XlaOp counter = GetTupleElement(state, 0);
    XlaOp lhs = GetTupleElement(state, 1);
    XlaOp rhs = GetTupleElement(state, 2);
    XlaOp output = Tanh(Dot(lhs, rhs));
    XlaOp next_counter = Add(counter, ConstantR0<int32_t>(&b, 1));
    Tuple(&b, {next_counter, lhs, rhs, output});
    TF_ASSERT_OK_AND_ASSIGN(body, b.Build());
  }

  XlaBuilder builder(TestName());
  XlaOp matrix = Parameter(&builder, 0, f32_2x2, "matrix");
  XlaOp loop_init = Tuple(
      &builder, {ConstantR0<int32_t>(&builder, 0), matrix, matrix, matrix});
  XlaOp loop = While(condition, body, loop_init);
  GetTupleElement(loop, 3);

  TF_ASSERT_OK_AND_ASSIGN(
      auto param_value, client_->TransferToServer(LiteralUtil::CreateR2<float>(
                            {{1.0, 2.0}, {-1.0, -2.0}})));

  // For M = [[1, 2], [-1, -2]], dot(M, M) = [[-1, -2], [1, 2]]; tanh of that.
  ComputeAndCompareR2<float>(
      &builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}},
      {param_value.get()}, ErrorSpec(4e-5));
}
1218
// Checks a while loop whose condition takes its predicate from infeed instead
// of deriving it from the loop state.
XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) {
  const Shape counter_shape = ShapeUtil::MakeShape(S32, {});

  // Condition: ignore the state and return the next PRED read from infeed.
  XlaComputation condition;
  {
    XlaBuilder b("condition");
    Parameter(&b, 0, counter_shape, "state");
    Infeed(&b, ShapeUtil::MakeShape(PRED, {}));
    TF_ASSERT_OK_AND_ASSIGN(condition, b.Build());
  }

  // Body: counter + 1.
  XlaComputation body;
  {
    XlaBuilder b("body");
    XlaOp counter = Parameter(&b, 0, counter_shape, "state");
    Add(counter, ConstantR0<int32_t>(&b, 1));
    TF_ASSERT_OK_AND_ASSIGN(body, b.Build());
  }

  XlaBuilder builder(TestName());
  While(condition, body, ConstantR0<int32_t>(&builder, 0));

  // Feed true, true, false: the body runs twice before the loop exits.
  for (bool keep_going : {true, true, false}) {
    TF_ASSERT_OK(
        client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(keep_going)));
  }

  ComputeAndCompareR0<int32_t>(&builder, 2, {});
}
1247
BM_WhileLoop(::testing::benchmark::State & state)1248 void BM_WhileLoop(::testing::benchmark::State& state) {
1249 // Benchmark a simple kernel to measure while loop overheads.
1250
1251 se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
1252 auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
1253 se::StreamExecutorMemoryAllocator allocator(platform, executors);
1254 LocalClient* client =
1255 ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
1256
1257 const int64_t seq_len = 100;
1258 Shape loop_state_shape = ShapeUtil::MakeTupleShape(
1259 {ShapeUtil::MakeShape(S32, {}),
1260 ShapeUtil::MakeShape(F32, {seq_len, 1024, 1024})});
1261
1262 // Create while condition computation with 'loop_limit'.
1263 const int32_t loop_limit = 100;
1264 XlaComputation condition;
1265 {
1266 XlaBuilder builder("condition");
1267 auto prev = Parameter(&builder, 0, loop_state_shape, "prev");
1268 auto iteration = GetTupleElement(prev, 0);
1269 Lt(iteration, ConstantR0<int32_t>(&builder, loop_limit));
1270 condition = builder.Build().value();
1271 }
1272
1273 // Create while body computation with unit loop increment.
1274 XlaComputation body;
1275 {
1276 XlaBuilder builder("body");
1277 auto prev = Parameter(&builder, 0, loop_state_shape, "prev");
1278 // TupleElement 0
1279 auto iteration = GetTupleElement(prev, 0);
1280 auto out0 = Add(iteration, ConstantR0<int32_t>(&builder, 1));
1281 // TupleElement 1
1282 auto input = GetTupleElement(prev, 1);
1283 // Update.
1284 auto one = ConstantR0<float>(&builder, 1.0);
1285 auto update = Broadcast(one, {1, 1024, 1024});
1286 // Starts = iteration * 2;
1287 auto zero = ConstantR0<int32_t>(&builder, 0);
1288 // UpdateSlice.
1289 auto out1 = DynamicUpdateSlice(input, update, {zero, zero, zero});
1290 Tuple(&builder, {out0, out1});
1291 body = builder.Build().value();
1292 }
1293
1294 // Create a While instruction.
1295 XlaBuilder builder("while");
1296 auto zero = ConstantR0<float>(&builder, 0.0);
1297 auto input = Broadcast(zero, {seq_len, 1024, 1024});
1298 auto init = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0), input});
1299 While(condition, body, init);
1300 auto computation = builder.Build().value();
1301
1302 TF_ASSERT_OK_AND_ASSIGN(
1303 auto executables,
1304 client->Compile(computation, {}, ExecutableBuildOptions()));
1305 auto executable = std::move(executables[0]);
1306
1307 // Run some warm-up executions.
1308 ExecutableRunOptions options;
1309 options.set_allocator(&allocator);
1310 const int kWarmups = 2;
1311 for (int i = 0; i < kWarmups; ++i) {
1312 auto result =
1313 executable->Run(absl::Span<const ShapedBuffer* const>(), options);
1314 ASSERT_TRUE(result.ok());
1315 }
1316
1317 // Run benchmark.
1318 for (auto s : state) {
1319 auto result =
1320 executable->Run(absl::Span<const ShapedBuffer* const>(), options);
1321 ASSERT_TRUE(result.ok());
1322 }
1323 }
1324
1325 BENCHMARK(BM_WhileLoop);
1326 } // namespace
1327 } // namespace xla
1328