• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 #include <vector>
19 
20 #include "tensorflow/compiler/xla/client/client_library.h"
21 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
22 #include "tensorflow/compiler/xla/client/local_client.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/platform_util.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
31 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
32 #include "tensorflow/compiler/xla/tests/test_macros.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/test.h"
37 #include "tensorflow/core/platform/test_benchmark.h"
38 
39 namespace xla {
40 namespace {
41 
42 class WhileTest : public ClientLibraryTestBase {};
43 
44 // Tests a while node when the result type T is S32.
45 //
46 // int32_t result = 0;
47 // while (result < 5) {
48 //   result = result + 1;
49 // }
XLA_TEST_F(WhileTest,WhileWithScalarS32Result)50 XLA_TEST_F(WhileTest, WhileWithScalarS32Result) {
51   auto result_shape = ShapeUtil::MakeShape(S32, {});
52 
53   // Create a computation for the condition: repeat for 5 iterations.
54   XlaComputation condition;
55   {
56     XlaBuilder builder("condition");
57     auto prev = Parameter(&builder, 0, result_shape, "prev");
58     Gt(ConstantR0<int32_t>(&builder, 5), prev);
59     condition = builder.Build().value();
60   }
61 
62   // Create a computation for the body: add 1 to the result variable.
63   XlaComputation body;
64   {
65     XlaBuilder builder("body");
66     auto prev = Parameter(&builder, 0, result_shape, "prev");
67     auto input = ConstantR0<int32_t>(&builder, 1);
68     Add(input, prev);
69     body = builder.Build().value();
70   }
71 
72   // Create a While node with computations for the condition and the body.
73   XlaBuilder builder(TestName());
74   auto init = ConstantR0<int32_t>(&builder, 0);
75   While(condition, body, init);
76 
77   ComputeAndCompareR0<int32_t>(&builder, 5, {});
78 }
79 
80 // Tests a while node when the result type T is S64.
81 //
82 // int32_t result = 0;
83 // while (result < 5) {
84 //   result = result + 1;
85 // }
XLA_TEST_F(WhileTest,WhileWithScalarS64Result)86 XLA_TEST_F(WhileTest, WhileWithScalarS64Result) {
87   auto result_shape = ShapeUtil::MakeShape(S64, {});
88 
89   // Create a computation for the condition: repeat for 5 iterations.
90   XlaComputation condition;
91   {
92     XlaBuilder builder("condition");
93     auto prev = Parameter(&builder, 0, result_shape, "prev");
94     Gt(ConstantR0<int64_t>(&builder, 5), prev);
95     condition = builder.Build().value();
96   }
97 
98   // Create a computation for the body: add 1 to the result variable.
99   XlaComputation body;
100   {
101     XlaBuilder builder("body");
102     auto prev = Parameter(&builder, 0, result_shape, "prev");
103     auto input = ConstantR0<int64_t>(&builder, 1);
104     Add(input, prev);
105     body = builder.Build().value();
106   }
107 
108   // Create a While node with computations for the condition and the body.
109   XlaBuilder builder(TestName());
110   auto init = ConstantR0<int64_t>(&builder, 0);
111   While(condition, body, init);
112 
113   ComputeAndCompareR0<int64_t>(&builder, 5, {});
114 }
115 
// Tests a while node whose initial value is computed rather than a literal
// constant: init = reduce_sum([1, 1]) = 2, so the loop runs 3 times (2 -> 5).
//
// int32_t result = sum({1, 1});
// while (result < 5) {
//   result = result + 1;
// }
XLA_TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
  auto result_shape = ShapeUtil::MakeShape(S32, {});

  // Create a computation for the condition: continue while 5 > prev.
  XlaComputation condition;
  {
    XlaBuilder builder("condition");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    Gt(ConstantR0<int32_t>(&builder, 5), prev);
    condition = builder.Build().value();
  }

  // Create a computation for the body: add 1 to the result variable.
  XlaComputation body;
  {
    XlaBuilder builder("body");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto input = ConstantR0<int32_t>(&builder, 1);
    Add(input, prev);
    body = builder.Build().value();
  }

  // Create a While node with computations for the condition and the body.
  // The initial value is the sum of the two-element vector [1, 1], i.e. 2.
  XlaBuilder builder(TestName());
  auto init = Reduce(ConstantR1<int32_t>(&builder, 2, 1),
                     ConstantR0<int32_t>(&builder, 0),
                     CreateScalarAddComputation(S32, &builder), {0});
  While(condition, body, init);

  ComputeAndCompareR0<int32_t>(&builder, 5, {});
}
148 
// Tests a while node whose loop-carried value is a PRED scalar.
//
// NOTE(review): the initial value is (false != true), which is true, so the
// condition (true != state) is already false on entry and the body never
// executes; as written the test only checks that the predicate passes through
// the loop unchanged.
XLA_TEST_F(WhileTest, WhileWithPredicateResult) {
  const Shape scalar_pred = ShapeUtil::MakeShape(PRED, {});

  // Loop condition: keep going while the state is not yet true.
  const XlaComputation condition = [&] {
    XlaBuilder cond_builder("condition");
    XlaOp state = Parameter(&cond_builder, 0, scalar_pred, "prev");
    XlaOp target = ConstantR0<bool>(&cond_builder, true);
    Ne(target, state);
    return cond_builder.Build().value();
  }();

  // Loop body: state | true.
  const XlaComputation body = [&] {
    XlaBuilder body_builder("body");
    XlaOp state = Parameter(&body_builder, 0, scalar_pred, "prev");
    XlaOp always = ConstantR0<bool>(&body_builder, true);
    Or(state, always);
    return body_builder.Build().value();
  }();

  // The initial predicate is computed at runtime rather than given as a
  // literal constant.
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR0<bool>(&builder, false);
  XlaOp rhs = ConstantR0<bool>(&builder, true);
  XlaOp start = Ne(lhs, rhs);
  While(condition, body, start);

  ComputeAndCompareR0<bool>(&builder, true, {});
}
178 
179 // Tests a while node when the result type T is a vector.
180 //
181 // All constants are chosen to produce exact results.
182 // vector<float> result(0);
183 // while (result.sum() < 15.5f) {
184 //   result = result + vector<float>(0);
185 // }
XLA_TEST_F(WhileTest,DISABLED_ON_INTERPRETER (WhileWithEmptyVectorResult))186 XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) {
187   Shape result_shape = ShapeUtil::MakeShape(F32, {0});
188 
189   // Create a computation for the reduction.
190   XlaComputation add;
191   {
192     XlaBuilder builder("add");
193     auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
194     auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
195     Add(x, y);
196     add = builder.Build().value();
197   }
198 
199   // Create a computation for the condition.
200   // Repeat until the sum of the result vector is less than 15.5f.
201   XlaComputation condition;
202   {
203     XlaBuilder builder("condition");
204     auto prev = Parameter(&builder, 0, result_shape, "prev");
205     auto sum = Reduce(prev, ConstantR0<float>(&builder, 0.0f), add,
206                       /*dimensions_to_reduce=*/{0});
207     Gt(ConstantR0<float>(&builder, 15.5f), sum);
208     condition = builder.Build().value();
209   }
210 
211   // Create a computation for the body.
212   // Add a constant vector of 1.f to the result vector.
213   XlaComputation body;
214   {
215     XlaBuilder builder("body");
216     auto prev = Parameter(&builder, 0, result_shape, "prev");
217     auto input = ConstantR1<float>(&builder, {});
218     Add(input, prev);
219     body = builder.Build().value();
220   }
221 
222   // Create a While node with computations for the condition and the body.
223   XlaBuilder builder("while");
224   auto init = ConstantR1<float>(&builder, {});
225   auto result = While(condition, body, init);
226   VLOG(2) << "while = "
227           << ShapeUtil::HumanString(builder.GetShape(result).value());
228 
229   ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.0001));
230 }
231 
232 // Tests a while node when the result type T is a vector.
233 //
234 // All constants are chosen to produce exact results.
235 // vector<float> result(8, 0.0f);
236 // while (result.sum() < 15.5f) {
237 //   result = result + vector<float>(8, 0.125f);
238 // }
XLA_TEST_F(WhileTest,WhileWithVectorResult)239 XLA_TEST_F(WhileTest, WhileWithVectorResult) {
240   Shape result_shape = ShapeUtil::MakeShape(F32, {8});
241 
242   // Create a computation for the reduction.
243   XlaComputation add;
244   {
245     XlaBuilder builder("add");
246     auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
247     auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
248     Add(x, y);
249     add = builder.Build().value();
250   }
251 
252   // Create a computation for the condition.
253   // Repeat until the sum of the result vector is less than 5.5f.
254   XlaComputation condition;
255   {
256     XlaBuilder builder("condition");
257     auto prev = Parameter(&builder, 0, result_shape, "prev");
258     auto sum = Reduce(prev, ConstantR0<float>(&builder, 0.0f), add,
259                       /*dimensions_to_reduce=*/{0});
260     Gt(ConstantR0<float>(&builder, 15.5f), sum);
261     condition = builder.Build().value();
262   }
263 
264   // Create a computation for the body.
265   // Add a constant vector of 1.f to the result vector.
266   XlaComputation body;
267   {
268     XlaBuilder builder("body");
269     auto prev = Parameter(&builder, 0, result_shape, "prev");
270     auto input = ConstantR1<float>(&builder, 8, 0.125f);
271     Add(input, prev);
272     body = builder.Build().value();
273   }
274 
275   // Create a While node with computations for the condition and the body.
276   XlaBuilder builder("while");
277   auto init = ConstantR1<float>(&builder, 8, 0.f);
278   auto result = While(condition, body, init);
279   VLOG(2) << "while = "
280           << ShapeUtil::HumanString(builder.GetShape(result).value());
281 
282   // Individual elements with increase by 1/8 each time through the loop, so
283   // the sum will increase by 1.0.  It will first be >15.5 when the elements
284   // have all reached 2.0.
285   std::vector<float> expected = {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f};
286   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
287 }
288 
289 // Tests a while node when the result type is a vector which is part
290 // of the result tuple.
291 //
292 // All constants are chosen to produce exact results.
293 // vector<float> result(8, 0.0f);
294 // while (result.sum() < 15.5f) {
295 //   result = result + vector<float>(8, 0.125f);
296 // }
297 // tuple = tuple { while }
XLA_TEST_F(WhileTest,WhileWithVectorResultIntoTuple)298 XLA_TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
299   Shape result_shape = ShapeUtil::MakeShape(F32, {8});
300 
301   // Create a computation for the reduction.
302   XlaComputation add;
303   {
304     XlaBuilder builder("add");
305     auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
306     auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
307     Add(x, y);
308     add = builder.Build().value();
309   }
310 
311   // Create a computation for the condition.
312   // Repeat until the sum of the result vector is less than 5.5f.
313   XlaComputation condition;
314   {
315     XlaBuilder builder("condition");
316     auto prev = Parameter(&builder, 0, result_shape, "prev");
317     auto sum = Reduce(prev, ConstantR0<float>(&builder, 0.0f), add,
318                       /*dimensions_to_reduce=*/{0});
319     Gt(ConstantR0<float>(&builder, 15.5f), sum);
320     condition = builder.Build().value();
321   }
322 
323   // Create a computation for the body.
324   // Add a constant vector of 1.f to the result vector.
325   XlaComputation body;
326   {
327     XlaBuilder builder("body");
328     auto prev = Parameter(&builder, 0, result_shape, "prev");
329     auto input = ConstantR1<float>(&builder, 8, 0.125f);
330     Add(input, prev);
331     body = builder.Build().value();
332   }
333 
334   // Create a While node with computations for the condition and the body.
335   XlaBuilder builder("while");
336   auto init = ConstantR1<float>(&builder, 8, 0.f);
337   auto result = While(condition, body, init);
338   VLOG(2) << "while = "
339           << ShapeUtil::HumanString(builder.GetShape(result).value());
340   Tuple(&builder, {result});
341 
342   // Individual elements with increase by 1/8 each time through the loop, so
343   // the sum will increase by 1.0.  It will first be >15.5 when the elements
344   // have all reached 2.0.
345   auto expected_data =
346       LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
347   auto expected = LiteralUtil::MakeTuple({&expected_data});
348   VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
349   ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
350 }
351 
// Tests a while loop whose state is a tuple (counter, w1, w2, w3). Each
// iteration increments the counter and rotates the weights to (w3, w1, w2),
// so after two iterations the weights are ordered (w2, w3, w1).
XLA_TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
  const Shape vec3 = ShapeUtil::MakeShape(F32, {3});
  std::vector<Shape> element_shapes = {ShapeUtil::MakeShape(S32, {}), vec3,
                                       vec3, vec3};
  const Shape state_shape = ShapeUtil::MakeTupleShape(element_shapes);

  // Loop condition: continue while the counter is below kIterations.
  const int kIterations = 2;
  const XlaComputation condition = [&] {
    XlaBuilder cond_builder("condition");
    XlaOp state = Parameter(&cond_builder, 0, state_shape, "prev");
    XlaOp counter = GetTupleElement(state, 0);
    XlaOp limit = ConstantR0<int32_t>(&cond_builder, kIterations);
    Gt(limit, counter);
    return cond_builder.Build().value();
  }();

  // Loop body: counter + 1 and rotate (w1, w2, w3) -> (w3, w1, w2).
  const XlaComputation body = [&] {
    XlaBuilder body_builder("body");
    XlaOp state = Parameter(&body_builder, 0, state_shape, "prev");
    XlaOp counter = GetTupleElement(state, 0);
    std::vector<XlaOp> w;
    for (int64_t i = 1; i <= 3; ++i) {
      w.push_back(GetTupleElement(state, i));
    }
    XlaOp next_counter = Add(counter, ConstantR0<int32_t>(&body_builder, 1));
    Tuple(&body_builder, {next_counter, w[2], w[0], w[1]});
    return body_builder.Build().value();
  }();

  XlaBuilder builder("while");
  XlaOp start = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0),
                                 ConstantR1<float>(&builder, 3, 1.f),
                                 ConstantR1<float>(&builder, 3, 2.f),
                                 ConstantR1<float>(&builder, 3, 3.f)});
  XlaOp result = While(condition, body, start);
  VLOG(2) << "result = "
          << ShapeUtil::HumanString(builder.GetShape(result).value());

  const auto want_counter = LiteralUtil::CreateR0<int32_t>(kIterations);
  const auto ones = LiteralUtil::CreateR1<float>({1.0f, 1.0f, 1.0f});
  const auto twos = LiteralUtil::CreateR1<float>({2.0f, 2.0f, 2.0f});
  const auto threes = LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f});
  // Two rotations of (1, 2, 3) yield (2, 3, 1).
  const auto expected =
      LiteralUtil::MakeTuple({&want_counter, &twos, &threes, &ones});
  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
404 
// Tests that the tuple produced by a while loop can be consumed outside the
// loop. The counter-driven loop rotates the three weight vectors, and the
// weights are then summed; the sum does not depend on the rotation
// (1 + 2 + 3 = 6 in every element).
XLA_TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
  std::vector<Shape> shape_elements = {
      ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}),
      ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})};
  Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);

  // Create a computation for the condition.
  // Repeat for N iterations.
  const int N = 2;
  XlaComputation condition;
  {
    XlaBuilder builder("condition");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto iteration = GetTupleElement(prev, 0);
    Gt(ConstantR0<int32_t>(&builder, N), iteration);
    condition = builder.Build().value();
  }

  // Create a computation for the body.
  // Add 1 to the iteration variable and permute the weights:
  // (w1, w2, w3) -> (w3, w1, w2).
  XlaComputation body;
  {
    XlaBuilder builder("body");
    auto prev = Parameter(&builder, 0, result_shape, "prev");
    auto iteration = GetTupleElement(prev, 0);
    auto w1 = GetTupleElement(prev, 1);
    auto w2 = GetTupleElement(prev, 2);
    auto w3 = GetTupleElement(prev, 3);
    Tuple(&builder,
          {Add(iteration, ConstantR0<int32_t>(&builder, 1)), w3, w1, w2});
    body = builder.Build().value();
  }

  // Create a While node with computations for the condition and the body.
  XlaBuilder builder("while");
  auto init = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0),
                               ConstantR1<float>(&builder, 3, 1.f),
                               ConstantR1<float>(&builder, 3, 2.f),
                               ConstantR1<float>(&builder, 3, 3.f)});
  auto xla_while = While(condition, body, init);

  // Sum the three weight vectors from the loop's final state.
  auto add12 =
      Add(GetTupleElement(xla_while, 1), GetTupleElement(xla_while, 2));
  auto result = Add(add12, GetTupleElement(xla_while, 3));
  VLOG(2) << "result = "
          << ShapeUtil::HumanString(builder.GetShape(result).value());
  std::vector<float> expected = {6.f, 6.f, 6.f};
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
454 
455 // Tests a while node when the result type T is a Tuple.
456 //
457 // tuple<int32_t, vector<float>> result(0, vector<float>(10, 0.0f));
458 // while (get<0>(result) < 5) {
459 //   get<0>(result) = get<0>(result) + 1;
460 //   get<1>(result) = get<1>(result) + vector<float>(10, 1.0f);
461 // }
XLA_TEST_F(WhileTest,WhileWithTupleResult)462 XLA_TEST_F(WhileTest, WhileWithTupleResult) {
463   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
464                                        ShapeUtil::MakeShape(F32, {10})};
465   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
466 
467   // Create a computation for the condition.
468   // Repeat for 5 iterations.
469   XlaComputation condition;
470   {
471     XlaBuilder builder("condition");
472     auto prev = Parameter(&builder, 0, result_shape, "prev");
473     auto iteration = GetTupleElement(prev, 0);
474     Gt(ConstantR0<int32_t>(&builder, 5), iteration);
475     condition = builder.Build().value();
476   }
477 
478   // Create a computation for the body.
479   // Add 1 to the iteration variable and add a constant vector of 1.0f to
480   // the weight variable, both of which are tuple elements.
481   XlaComputation body;
482   {
483     XlaBuilder builder("body");
484     auto prev = Parameter(&builder, 0, result_shape, "prev");
485     auto iteration = GetTupleElement(prev, 0);
486     auto weights = GetTupleElement(prev, 1);
487     auto input = ConstantR1<float>(&builder, 10, 1.f);
488     auto new_weights = Add(weights, input);
489     Tuple(&builder,
490           {Add(iteration, ConstantR0<int32_t>(&builder, 1)), new_weights});
491     body = builder.Build().value();
492   }
493 
494   // Create a While node with computations for the condition and the body.
495   XlaBuilder builder("while");
496   auto init = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0),
497                                ConstantR1<float>(&builder, 10, 0.f)});
498   auto result = While(condition, body, init);
499   VLOG(2) << "while = "
500           << ShapeUtil::HumanString(builder.GetShape(result).value());
501 
502   auto expected_counter = LiteralUtil::CreateR0<int32_t>(5);
503   auto expected_data = LiteralUtil::CreateR1<float>(
504       {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
505   auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
506   VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
507   ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
508 }
509 
// Tests a while loop whose tuple state mixes an S32 counter with a PRED flag.
// The counter drives five iterations; each iteration ORs the flag with true.
XLA_TEST_F(WhileTest, WhileWithPredicateTupleResult) {
  std::vector<Shape> element_shapes;
  element_shapes.push_back(ShapeUtil::MakeShape(S32, {}));
  element_shapes.push_back(ShapeUtil::MakeShape(PRED, {}));
  const Shape state_shape = ShapeUtil::MakeTupleShape(element_shapes);

  // Loop condition: keep iterating while 5 > counter.
  const XlaComputation condition = [&] {
    XlaBuilder cond_builder("condition");
    XlaOp state = Parameter(&cond_builder, 0, state_shape, "prev");
    XlaOp counter = GetTupleElement(state, 0);
    XlaOp limit = ConstantR0<int32_t>(&cond_builder, 5);
    Gt(limit, counter);
    return cond_builder.Build().value();
  }();

  // Loop body: (counter, flag) -> (counter + 1, flag | true).
  const XlaComputation body = [&] {
    XlaBuilder body_builder("body");
    XlaOp state = Parameter(&body_builder, 0, state_shape, "prev");
    XlaOp counter = GetTupleElement(state, 0);
    XlaOp flag = GetTupleElement(state, 1);
    XlaOp next_flag = Or(flag, ConstantR0<bool>(&body_builder, true));
    XlaOp next_counter = Add(counter, ConstantR0<int32_t>(&body_builder, 1));
    Tuple(&body_builder, {next_counter, next_flag});
    return body_builder.Build().value();
  }();

  // The initial flag is computed at runtime as (false != true).
  XlaBuilder builder("while");
  XlaOp zero = ConstantR0<int32_t>(&builder, 0);
  XlaOp lhs = ConstantR0<bool>(&builder, false);
  XlaOp rhs = ConstantR0<bool>(&builder, true);
  XlaOp start = Tuple(&builder, {zero, Ne(lhs, rhs)});
  XlaOp result = While(condition, body, start);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(builder.GetShape(result).value());

  const auto want_counter = LiteralUtil::CreateR0<int32_t>(5);
  const auto want_flag = LiteralUtil::CreateR0<bool>(true);
  const auto expected = LiteralUtil::MakeTuple({&want_counter, &want_flag});
  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0));
}
555 
// Tests a while loop over a (counter, value) tuple in which the body replaces
// the second element with the constant 7 on every iteration. Because 7 is
// also the initial value, the second element is expected to stay 7.
XLA_TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});
  std::vector<Shape> element_shapes(2, scalar_s32);
  const Shape state_shape = ShapeUtil::MakeTupleShape(element_shapes);

  // Loop condition: keep iterating while 5 > counter.
  const XlaComputation condition = [&] {
    XlaBuilder cond_builder("condition");
    XlaOp state = Parameter(&cond_builder, 0, state_shape, "prev");
    XlaOp counter = GetTupleElement(state, 0);
    XlaOp limit = ConstantR0<int32_t>(&cond_builder, 5);
    Gt(limit, counter);
    return cond_builder.Build().value();
  }();

  // Loop body: (counter, _) -> (counter + 1, 7).
  const XlaComputation body = [&] {
    XlaBuilder body_builder("body");
    XlaOp state = Parameter(&body_builder, 0, state_shape, "prev");
    XlaOp counter = GetTupleElement(state, 0);
    XlaOp next_counter = Add(counter, ConstantR0<int32_t>(&body_builder, 1));
    XlaOp seven = ConstantR0<int32_t>(&body_builder, 7);
    Tuple(&body_builder, {next_counter, seven});
    return body_builder.Build().value();
  }();

  XlaBuilder builder("while");
  XlaOp zero = ConstantR0<int32_t>(&builder, 0);
  XlaOp seven = ConstantR0<int32_t>(&builder, 7);
  XlaOp start = Tuple(&builder, {zero, seven});
  XlaOp result = While(condition, body, start);
  VLOG(2) << "while = "
          << ShapeUtil::HumanString(builder.GetShape(result).value());

  const auto want_counter = LiteralUtil::CreateR0<int32_t>(5);
  const auto want_value = LiteralUtil::CreateR0<int32_t>(7);
  const auto expected = LiteralUtil::MakeTuple({&want_counter, &want_value});
  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
}
599 
600 // Tests two while nodes when the result type T is a Tuple and the second
601 // while node uses the result of the first while node which is used in two
602 // nodes.
603 // tuple<int32_t, vector<float>> w0(0, vector<float>(10, 0.0f));
604 // w0 = while (get<0>(w0) < c1) {
605 //        get<0>(w0) = get<0>(w0) + 1;
606 //        get<1>(w0) = get<1>(w0) + vector<float>(10, 1.0f);
607 //      }
608 // tuple<int32_t, vector<float>> w1(get<0>(w0), get<1>(w0));
609 // w1 = while (get<0>(w1) < c2) {
610 //        get<0>(w1) = get<0>(w1) + 1;
611 //        get<1>(w1) = get<1>(w1) + vector<float>(10, 1.0f);
612 //      }
613 // result = get<1>(w0) + get<1>(w1)
XLA_TEST_F(WhileTest,TwoWhileWithTupleResult)614 XLA_TEST_F(WhileTest, TwoWhileWithTupleResult) {
615   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
616                                        ShapeUtil::MakeShape(F32, {10})};
617   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
618 
619   // Create a computation for the condition.
620   // Repeat for 5 iterations.
621   XlaComputation condition;
622   const int c1 = 5;
623   {
624     XlaBuilder builder("condition");
625     auto prev = Parameter(&builder, 0, result_shape, "prev");
626     auto iteration = GetTupleElement(prev, 0);
627     Lt(iteration, ConstantR0<int32_t>(&builder, c1));
628     TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
629   }
630 
631   XlaComputation condition2;
632   const int c2 = 7;
633   {
634     XlaBuilder builder("condition2");
635     auto prev = Parameter(&builder, 0, result_shape, "prev");
636     auto iteration = GetTupleElement(prev, 0);
637     Lt(iteration, ConstantR0<int32_t>(&builder, c2));
638     TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build());
639   }
640 
641   // Create a computation for the body.
642   // Add 1 to the iteration variable and add a constant vector of 1.0f to
643   // the weight variable, both of which are tuple elements.
644   XlaComputation body;
645   {
646     XlaBuilder builder("body");
647     auto prev = Parameter(&builder, 0, result_shape, "prev");
648     auto iteration = GetTupleElement(prev, 0);
649     auto weights = GetTupleElement(prev, 1);
650     auto input = ConstantR1<float>(&builder, 10, 1.f);
651     auto new_weights = Add(weights, input);
652     Tuple(&builder,
653           {Add(iteration, ConstantR0<int32_t>(&builder, 1)), new_weights});
654     TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
655   }
656 
657   XlaComputation body2;
658   {
659     XlaBuilder builder("body");
660     auto prev = Parameter(&builder, 0, result_shape, "prev");
661     auto iteration = GetTupleElement(prev, 0);
662     auto weights = GetTupleElement(prev, 1);
663     auto input = ConstantR1<float>(&builder, 10, 1.f);
664     auto new_weights = Add(weights, input);
665     Tuple(&builder,
666           {Add(iteration, ConstantR0<int32_t>(&builder, 1)), new_weights});
667     TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build());
668   }
669 
670   // Create a While node with computations for the condition and the body.
671   XlaBuilder builder("while");
672   auto init = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0),
673                                ConstantR1<float>(&builder, 10, 0.f)});
674   auto while1 = While(condition, body, init);
675 
676   auto while2 = While(condition2, body2, while1);
677 
678   auto while_result1 = GetTupleElement(while1, 1);
679   auto while_result2 = GetTupleElement(while2, 1);
680   VLOG(2) << "while_result2 = "
681           << ShapeUtil::HumanString(builder.GetShape(while_result2).value());
682   auto result = Add(while_result1, while_result2);
683   VLOG(2) << "result = "
684           << ShapeUtil::HumanString(builder.GetShape(result).value());
685   const float sum = c1 + c2;
686   std::vector<float> expected(10, sum);
687   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
688 }
689 
690 // Test while nodes that share the while body computation.
XLA_TEST_F(WhileTest,TwoWhileLoopsAndSharedBody)691 XLA_TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
692   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
693                                        ShapeUtil::MakeShape(F32, {10})};
694   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
695 
696   // Create a computation for the condition.
697   // Repeat for 5 iterations.
698   XlaComputation condition;
699   const int c1 = 5;
700   {
701     XlaBuilder builder("condition");
702     auto prev = Parameter(&builder, 0, result_shape, "prev");
703     auto iteration = GetTupleElement(prev, 0);
704     Lt(iteration, ConstantR0<int32_t>(&builder, c1));
705     TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
706   }
707 
708   XlaComputation condition2;
709   const int c2 = 7;
710   {
711     XlaBuilder builder("condition2");
712     auto prev = Parameter(&builder, 0, result_shape, "prev");
713     auto iteration = GetTupleElement(prev, 0);
714     Lt(iteration, ConstantR0<int32_t>(&builder, c2));
715     TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build());
716   }
717 
718   // Create a computation for the body.
719   // Add 1 to the iteration variable and add a constant vector of 1.0f to
720   // the weight variable, both of which are tuple elements.
721   XlaComputation body;
722   {
723     XlaBuilder builder("body");
724     auto prev = Parameter(&builder, 0, result_shape, "prev");
725     auto iteration = GetTupleElement(prev, 0);
726     auto weights = GetTupleElement(prev, 1);
727     auto input = ConstantR1<float>(&builder, 10, 1.f);
728     auto new_weights = Add(weights, input);
729     Tuple(&builder,
730           {Add(iteration, ConstantR0<int32_t>(&builder, 1)), new_weights});
731     TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
732   }
733 
734   // Create a While node with computations for the condition and the body.
735   XlaBuilder builder("while");
736   auto init = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0),
737                                ConstantR1<float>(&builder, 10, 0.f)});
738   auto while1 = While(condition, body, init);
739 
740   auto while2 = While(condition2, body, while1);
741 
742   auto while_result1 = GetTupleElement(while1, 1);
743   auto while_result2 = GetTupleElement(while2, 1);
744   VLOG(2) << "while_result2 = "
745           << ShapeUtil::HumanString(builder.GetShape(while_result2).value());
746   auto result = Add(while_result1, while_result2);
747   VLOG(2) << "result = "
748           << ShapeUtil::HumanString(builder.GetShape(result).value());
749   const float sum = c1 + c2;
750   std::vector<float> expected(10, sum);
751   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
752 }
753 
XLA_TEST_F(WhileTest,WhileLoopsWithSharedBodyAndInit)754 XLA_TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) {
755   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
756                                        ShapeUtil::MakeShape(F32, {10})};
757   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
758 
759   // Create a computation for the condition.
760   // Repeat for 5 iterations.
761   XlaComputation condition;
762   const int c1 = 5;
763   {
764     XlaBuilder builder("condition");
765     auto prev = Parameter(&builder, 0, result_shape, "prev");
766     auto iteration = GetTupleElement(prev, 0);
767     Lt(iteration, ConstantR0<int32_t>(&builder, c1));
768     TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
769   }
770 
771   XlaComputation condition2;
772   const int c2 = 7;
773   {
774     XlaBuilder builder("condition2");
775     auto prev = Parameter(&builder, 0, result_shape, "prev");
776     auto iteration = GetTupleElement(prev, 0);
777     Lt(iteration, ConstantR0<int32_t>(&builder, c2));
778     TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build());
779   }
780 
781   // Create a computation for the body.
782   // Add 1 to the iteration variable and add a constant vector of 1.0f to
783   // the weight variable, both of which are tuple elements.
784   XlaComputation body;
785   {
786     XlaBuilder builder("body");
787     auto prev = Parameter(&builder, 0, result_shape, "prev");
788     auto iteration = GetTupleElement(prev, 0);
789     auto weights = GetTupleElement(prev, 1);
790     auto input = ConstantR1<float>(&builder, 10, 1.f);
791     auto new_weights = Add(weights, input);
792     Tuple(&builder,
793           {Add(iteration, ConstantR0<int32_t>(&builder, 1)), new_weights});
794     TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
795   }
796 
797   // Create a While node with computations for the condition and the body.
798   XlaBuilder builder("while");
799   auto init = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0),
800                                ConstantR1<float>(&builder, 10, 0.f)});
801   auto while1 = While(condition, body, init);
802   auto while2 = While(condition2, body, init);
803 
804   auto while_result1 = GetTupleElement(while1, 1);
805   auto while_result2 = GetTupleElement(while2, 1);
806   VLOG(2) << "while_result2 = "
807           << ShapeUtil::HumanString(builder.GetShape(while_result2).value());
808   auto result = Add(while_result1, while_result2);
809   VLOG(2) << "result = "
810           << ShapeUtil::HumanString(builder.GetShape(result).value());
811   const float sum = c1 + c2;
812   std::vector<float> expected(10, sum);
813   ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
814 }
815 
816 // WhileTest that uses DynamicUpdateSlice instruction in body computation.
817 // Loop state tuple element 1 has as its single user operand(0) of
818 // DynamicUpdateSlice, which will trigger in-place dynamic slice update on GPU.
XLA_TEST_F(WhileTest,WhileWithDynamicUpdateSlice)819 XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
820   std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
821                                        ShapeUtil::MakeShape(F32, {10})};
822   Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
823 
824   // Create a computation for the condition.
825   // Repeat for 5 iterations.
826   XlaComputation condition;
827   {
828     XlaBuilder builder("condition");
829     auto prev = Parameter(&builder, 0, result_shape, "prev");
830     auto iteration = GetTupleElement(prev, 0);
831     Gt(ConstantR0<int32_t>(&builder, 5), iteration);
832     condition = builder.Build().value();
833   }
834 
835   // Create a computation for the body.
836   // Add 1 to the iteration variable and add a constant vector of 1.0f to
837   // the weight variable, both of which are tuple elements.
838   XlaComputation body;
839   {
840     XlaBuilder builder("body");
841     auto prev = Parameter(&builder, 0, result_shape, "prev");
842     // TupleElement 0
843     auto iteration = GetTupleElement(prev, 0);
844     auto out0 = Add(iteration, ConstantR0<int32_t>(&builder, 1));
845     // TupleElement 1
846     auto input = GetTupleElement(prev, 1);
847     // Update.
848     auto update = ConvertElementType(Broadcast(out0, {2}), F32);
849     // Starts = iteration * 2;
850     auto starts = Mul(iteration, ConstantR0<int32_t>(&builder, 2));
851     // UpdateSlice.
852     auto out1 = DynamicUpdateSlice(input, update, {starts});
853 
854     Tuple(&builder, {out0, out1});
855     body = builder.Build().value();
856   }
857 
858   // Create a While node with computations for the condition and the body.
859   XlaBuilder builder("while");
860   auto init = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0),
861                                ConstantR1<float>(&builder, 10, 0.f)});
862   auto result = While(condition, body, init);
863   VLOG(2) << "while = "
864           << ShapeUtil::HumanString(builder.GetShape(result).value());
865 
866   auto expected_counter = LiteralUtil::CreateR0<int32_t>(5);
867   auto expected_data = LiteralUtil::CreateR1<float>(
868       {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f});
869   auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
870   VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
871   ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
872 }
873 
874 // Tests a while node when the result type T is a vector of S32.
875 //
876 // int32_t result = (0, 0, 0, 0, 0, 0);
877 // while (result[0] < count) {
878 //   result += (1, U[0, 100], U[0, 100], U[0, 100], U[0, 100], U[0, 100]);
879 // }
880 //
881 // This test misuses a vector WhileTest.WhileLoopsWithSharedBodyto represent a
882 // pair:
883 //   ((iteration, (random vector))).
884 //
885 // Note: this test currently only tests generating random values within a loop.
886 // Per backend the values generated can be different as the different backends
887 // use different random number generators.
888 // TODO(b/32240857): Extend test to verify outputs.
XLA_TEST_F(WhileTest,WhileWithPrngScalarResult)889 XLA_TEST_F(WhileTest, WhileWithPrngScalarResult) {
890   auto v6s32 = ShapeUtil::MakeShape(S32, {6});
891 
892   // Create a computation for the condition: repeat for count iterations.
893   auto build_condition = [this, v6s32](int count) {
894     XlaBuilder builder(TestName());
895     auto prev = Reshape(
896         Slice(Parameter(&builder, 0, v6s32, "prev"), {0}, {1}, {1}), {0}, {});
897     Gt(ConstantR0<int32_t>(&builder, count), prev);
898     return builder.Build().value();
899   };
900 
901   // Create a computation for the body: add 1 to the result variable.
902   XlaComputation body;
903   {
904     XlaBuilder builder("body");
905     auto prev = Parameter(&builder, 0, v6s32, "prev");
906     auto inc = ConcatInDim(&builder,
907                            {ConstantR1<int32_t>(&builder, {1}),
908                             RngUniform(ConstantR0<int32_t>(&builder, 0),
909                                        ConstantR0<int32_t>(&builder, 100),
910                                        ShapeUtil::MakeShape(S32, {5}))},
911                            0);
912     Add(inc, prev);
913     body = builder.Build().value();
914   }
915 
916   // Create a While node with computations for the condition and the body.
917   auto while_loop = [this, &body, build_condition](int count) {
918     XlaBuilder builder(TestName());
919     auto init = ConstantR1<int32_t>(&builder, {0, 0, 0, 0, 0, 0});
920     While(build_condition(count), body, init);
921     return builder.Build();
922   };
923 
924   for (int i = 1; i < 4; ++i) {
925     TF_ASSERT_OK_AND_ASSIGN(auto computation, while_loop(i));
926 
927     ExecutionOptions execution_options = execution_options_;
928     execution_options.set_seed(65);
929     TF_ASSERT_OK_AND_ASSIGN(
930         auto result,
931         client_->ExecuteAndTransfer(computation, {}, &execution_options));
932   }
933 }
934 
XLA_TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) {
  // Both tuple components are f32[2] vectors.
  const Shape vec2 = ShapeUtil::MakeShape(F32, {2});

  // Loop init: (param, [1, 1]).
  XlaBuilder outer("outer");
  XlaOp param = Parameter(&outer, 0, vec2, "param");
  XlaOp ones = ConstantR1<float>(&outer, {1, 1});
  XlaOp loop_init = Tuple(&outer, {param, ones});

  TF_ASSERT_OK_AND_ASSIGN(Shape tuple_shape, outer.GetShape(loop_init));

  // Condition: continue while any element of component 0 equals 42.
  XlaBuilder cond("cond");
  {
    XlaOp state = Parameter(&cond, 0, tuple_shape, "t");
    XlaOp first = GetTupleElement(state, 0);
    XlaOp sentinel = ConstantR1<float>(&cond, {42, 42});
    Any(Eq(first, sentinel));
  }

  // Body: replace both components with component 1.
  XlaBuilder body("body");
  {
    XlaOp state = Parameter(&body, 0, tuple_shape, "t");
    XlaOp second = GetTupleElement(state, 1);
    Tuple(&body, {second, second});
  }

  TF_ASSERT_OK_AND_ASSIGN(XlaComputation cond_computation, cond.Build());
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation body_computation, body.Build());
  While(cond_computation, body_computation, loop_init);

  // With param = [42, 42] the body runs once, yielding ([1, 1], [1, 1]).
  Literal ones_literal = LiteralUtil::CreateR1<float>({1, 1});
  Literal expected = LiteralUtil::MakeTuple({&ones_literal, &ones_literal});
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> param_data,
      client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
  ComputeAndCompareTuple(&outer, expected, {param_data.get()},
                         ErrorSpec(1e-6));
}
966 
XLA_TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) {
  // The loop state is a single f32[2] vector.
  const Shape vec2 = ShapeUtil::MakeShape(F32, {2});

  XlaBuilder outer("outer");
  XlaOp param = Parameter(&outer, 0, vec2, "param");

  // Condition: continue while any element equals 42.
  XlaBuilder cond("cond");
  {
    XlaOp state = Parameter(&cond, 0, vec2, "t");
    Any(Eq(state, ConstantR1<float>(&cond, {42, 42})));
  }

  // Body: ignore the incoming state and return a broadcast of 1.0.
  XlaBuilder body("body");
  {
    Parameter(&body, 0, vec2, "t");
    XlaOp one = ConstantR0<float>(&body, 1.0);
    Broadcast(one, {2});
  }

  TF_ASSERT_OK_AND_ASSIGN(XlaComputation cond_computation, cond.Build());
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation body_computation, body.Build());
  While(cond_computation, body_computation, param);

  // Starting from [42, 42] the body runs once and produces [1, 1].
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> param_data,
      client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
  ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {param_data.get()},
                             ErrorSpec(1e-6));
}
991 
XLA_TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) {
  // The loop state is a single f32 scalar.
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  XlaBuilder outer("outer");
  XlaOp param = Parameter(&outer, 0, f32_scalar, "param");

  // Condition: continue while the state equals 42.
  XlaBuilder cond("cond");
  XlaOp cond_state = Parameter(&cond, 0, f32_scalar, "t");
  XlaOp forty_two = ConstantR0<float>(&cond, 42);
  Eq(cond_state, forty_two);

  // Body: pack (x, x + 1) into a tuple, then return element 1, i.e. x + 1.
  XlaBuilder body("body");
  XlaOp body_state = Parameter(&body, 0, f32_scalar, "t");
  XlaOp incremented = Add(body_state, ConstantR0<float>(&body, 1));
  XlaOp packed = Tuple(&body, {body_state, incremented});
  GetTupleElement(packed, 1);

  TF_ASSERT_OK_AND_ASSIGN(XlaComputation cond_computation, cond.Build());
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation body_computation, body.Build());
  While(cond_computation, body_computation, param);

  // Starting at 42 the body runs once, producing 43, which ends the loop.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> param_data,
      client_->TransferToServer(LiteralUtil::CreateR0<float>(42)));
  ComputeAndCompareR0<float>(&outer, 43.0f, {param_data.get()},
                             ErrorSpec(1e-6));
}
1017 
1018 // Tests loop where the init value comes from two sources (constant and
1019 // parameter).
1020 //
1021 // int32_t result = (0, 1);
1022 // while (result[0] + result[1] < 30) {
1023 //   result[0] = result[0] + 1;
1024 //   result[1] = result[1] + 1;
1025 // }
XLA_TEST_F(WhileTest,WhileWithMixedTupleElements)1026 XLA_TEST_F(WhileTest, WhileWithMixedTupleElements) {
1027   auto result_shape = ShapeUtil::MakeTupleShape(
1028       {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})});
1029 
1030   XlaBuilder outer("outer");
1031   auto p =
1032       Tuple(&outer, {ConstantR0<int32_t>(&outer, 0),
1033                      Parameter(&outer, 0, ShapeUtil::MakeShape(S32, {}), "t")});
1034 
1035   XlaBuilder cond("cond");
1036   auto params = Parameter(&cond, 0, result_shape, "prev");
1037   auto cond_t = Add(GetTupleElement(params, 1), GetTupleElement(params, 0));
1038   Lt(cond_t, ConstantR0<int32_t>(&cond, 30));
1039 
1040   XlaBuilder body("body");
1041   auto body_t = Parameter(&body, 0, result_shape, "t");
1042 
1043   Tuple(&body,
1044         {Add(GetTupleElement(body_t, 0), ConstantR0<int32_t>(&body, 1)),
1045          Add(GetTupleElement(body_t, 1), ConstantR0<int32_t>(&body, 1))});
1046 
1047   TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
1048   TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
1049   While(cond_computation, body_computation, p);
1050 
1051   TF_ASSERT_OK_AND_ASSIGN(
1052       std::unique_ptr<GlobalData> parameter_data,
1053       client_->TransferToServer(LiteralUtil::CreateR0<int32_t>(1)));
1054 
1055   auto add1 = LiteralUtil::CreateR0<int32_t>(15);
1056   auto add2 = LiteralUtil::CreateR0<int32_t>(16);
1057   auto expected = LiteralUtil::MakeTuple({&add1, &add2});
1058   ComputeAndCompareTuple(&outer, expected, {parameter_data.get()},
1059                          ErrorSpec(1e-6));
1060 }
1061 
1062 // Tests nested while loops.
1063 //
1064 // int32_t result = 0;
1065 // while (result < 30) {
1066 //   int i = 0;
1067 //   while (i < 7) {
1068 //     result = result + 2;
1069 //     i = i + 1;
1070 //   }
1071 // }
XLA_TEST_F(WhileTest,NestedWhileWithScalarResult)1072 XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
1073   auto outer_result_shape = ShapeUtil::MakeShape(S32, {});
1074   auto inner_result_shape = ShapeUtil::MakeTupleShape(
1075       {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})});
1076 
1077   XlaComputation inner_condition;
1078   {
1079     XlaBuilder builder("inner_condition");
1080     auto params = Parameter(&builder, 0, inner_result_shape, "prev");
1081     auto i = GetTupleElement(params, 0);
1082     Lt(i, ConstantR0<int32_t>(&builder, 7));
1083     inner_condition = builder.Build().value();
1084   }
1085 
1086   // Creates a computation for the outer loop condition:
1087   // repeat while result < 30.
1088   XlaComputation outer_condition;
1089   {
1090     XlaBuilder builder("outer_condition");
1091     auto prev = Parameter(&builder, 0, outer_result_shape, "prev");
1092     Lt(prev, ConstantR0<int32_t>(&builder, 30));
1093     outer_condition = builder.Build().value();
1094   }
1095 
1096   // Creates a computation for the inner loop body: add 1 to `i`, and add 2 to
1097   // `result`.
1098   XlaComputation inner_body;
1099   {
1100     XlaBuilder builder("inner_body");
1101     auto params = Parameter(&builder, 0, inner_result_shape, "prev");
1102     auto i = GetTupleElement(params, 0);
1103     auto result = GetTupleElement(params, 1);
1104     i = Add(ConstantR0<int32_t>(&builder, 1), i);
1105     result = Add(ConstantR0<int32_t>(&builder, 2), result);
1106     Tuple(&builder, {i, result});
1107     inner_body = builder.Build().value();
1108   }
1109 
1110   // Creates a computation for the outer loop: run the inner loop with i = 0.
1111   XlaComputation outer_body;
1112   {
1113     XlaBuilder builder("outer_body");
1114     auto prev = Parameter(&builder, 0, outer_result_shape, "prev");
1115     auto init = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0), prev});
1116     auto result = While(inner_condition, inner_body, init);
1117     GetTupleElement(result, 1);
1118     outer_body = builder.Build().value();
1119   }
1120 
1121   // Create a While node with computations for the condition and the body.
1122   XlaBuilder builder(TestName());
1123   auto init = ConstantR0<int32_t>(&builder, 0);
1124   While(outer_condition, outer_body, init);
1125 
1126   ComputeAndCompareR0<int32_t>(&builder, 42, {});
1127 }
1128 
1129 // Tests a while node when the result type T is S32.
1130 // f = lambda result: tuple({result < 5})
1131 // int32_t result = 0;
1132 // while (f(result).get<0>()) {
1133 //   result = result + 1;
1134 // }
XLA_TEST_F(WhileTest,WhileWithCallInsideCondition)1135 XLA_TEST_F(WhileTest, WhileWithCallInsideCondition) {
1136   auto result_shape = ShapeUtil::MakeShape(S32, {});
1137 
1138   // Create a computation for the condition: repeat for 5 iterations.
1139   XlaComputation condition_callee;
1140   {
1141     XlaBuilder builder("condition_callee");
1142     auto prev = Parameter(&builder, 0, result_shape, "prev");
1143     Tuple(&builder, {Gt(ConstantR0<int32_t>(&builder, 5), prev)});
1144 
1145     condition_callee = builder.Build().value();
1146   }
1147 
1148   XlaComputation condition;
1149   {
1150     XlaBuilder builder("condition");
1151     auto prev = Parameter(&builder, 0, result_shape, "prev");
1152     auto result = Call(&builder, condition_callee, {prev});
1153     GetTupleElement(result, 0);
1154     condition = builder.Build().value();
1155   }
1156 
1157   // Create a computation for the body: add 1 to the result variable.
1158   XlaComputation body;
1159   {
1160     XlaBuilder builder("body");
1161     auto prev = Parameter(&builder, 0, result_shape, "prev");
1162     auto input = ConstantR0<int32_t>(&builder, 1);
1163     Add(input, prev);
1164     body = builder.Build().value();
1165   }
1166 
1167   // Create a While node with computations for the condition and the body.
1168   XlaBuilder builder(TestName());
1169   auto init = ConstantR0<int32_t>(&builder, 0);
1170   While(condition, body, init);
1171 
1172   ComputeAndCompareR0<int32_t>(&builder, 5, {});
1173 }
1174 
XLA_TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
  const Shape mat_f32 = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape idx_s32 = ShapeUtil::MakeShape(S32, {});
  // Loop state: (counter, lhs, rhs, output). lhs and rhs are passed through
  // unchanged, so Tanh(Dot(lhs, rhs)) computed in the body is loop-invariant.
  const Shape loop_state_shape =
      ShapeUtil::MakeTupleShape({idx_s32, mat_f32, mat_f32, mat_f32});

  // Condition: run while the counter is below 5.
  XlaComputation condition;
  {
    XlaBuilder cond_builder("condition");
    XlaOp loop_state = Parameter(&cond_builder, 0, loop_state_shape, "state");
    XlaOp limit = ConstantR0<int32_t>(&cond_builder, 5);
    Gt(limit, GetTupleElement(loop_state, 0));
    TF_ASSERT_OK_AND_ASSIGN(condition, cond_builder.Build());
  }

  // Body: recompute Tanh(lhs . rhs) into the output slot and bump the counter.
  XlaComputation body;
  {
    XlaBuilder body_builder("body");
    XlaOp loop_state = Parameter(&body_builder, 0, loop_state_shape, "state");
    XlaOp step = GetTupleElement(loop_state, 0);
    XlaOp lhs = GetTupleElement(loop_state, 1);
    XlaOp rhs = GetTupleElement(loop_state, 2);
    XlaOp activation = Tanh(Dot(lhs, rhs));
    XlaOp next_step = Add(step, ConstantR0<int32_t>(&body_builder, 1));
    Tuple(&body_builder, {next_step, lhs, rhs, activation});
    TF_ASSERT_OK_AND_ASSIGN(body, body_builder.Build());
  }

  XlaBuilder builder(TestName());
  XlaOp matrix = Parameter(&builder, 0, mat_f32, "matrix");
  XlaOp loop_init =
      Tuple(&builder, {ConstantR0<int32_t>(&builder, 0), matrix, matrix, matrix});
  XlaOp loop = While(condition, body, loop_init);
  GetTupleElement(loop, 3);

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> matrix_data,
      client_->TransferToServer(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {-1.0, -2.0}})));

  // [[1, 2], [-1, -2]] squared is [[-1, -2], [1, 2]]; tanh of that follows.
  ComputeAndCompareR2<float>(
      &builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}},
      {matrix_data.get()}, ErrorSpec(4e-5));
}
1218 
XLA_TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) {
  const Shape counter_shape = ShapeUtil::MakeShape(S32, {});

  // Condition: ignores the loop state and reads the continue flag from infeed.
  XlaComputation condition;
  {
    XlaBuilder cond_builder("condition");
    Parameter(&cond_builder, 0, counter_shape, "state");
    Infeed(&cond_builder, ShapeUtil::MakeShape(PRED, {}));
    TF_ASSERT_OK_AND_ASSIGN(condition, cond_builder.Build());
  }

  // Body: increments the counter by one.
  XlaComputation body;
  {
    XlaBuilder body_builder("body");
    XlaOp counter = Parameter(&body_builder, 0, counter_shape, "state");
    Add(counter, ConstantR0<int32_t>(&body_builder, 1));
    TF_ASSERT_OK_AND_ASSIGN(body, body_builder.Build());
  }

  XlaBuilder builder(TestName());
  XlaOp zero = ConstantR0<int32_t>(&builder, 0);
  While(condition, body, zero);

  // Two `true` flags let the body run twice; the trailing `false` stops it.
  for (const bool keep_going : {true, true, false}) {
    TF_ASSERT_OK(
        client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(keep_going)));
  }

  ComputeAndCompareR0<int32_t>(&builder, 2, {});
}
1247 
BM_WhileLoop(::testing::benchmark::State & state)1248 void BM_WhileLoop(::testing::benchmark::State& state) {
1249   // Benchmark a simple kernel to measure while loop overheads.
1250 
1251   se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
1252   auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
1253   se::StreamExecutorMemoryAllocator allocator(platform, executors);
1254   LocalClient* client =
1255       ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
1256 
1257   const int64_t seq_len = 100;
1258   Shape loop_state_shape = ShapeUtil::MakeTupleShape(
1259       {ShapeUtil::MakeShape(S32, {}),
1260        ShapeUtil::MakeShape(F32, {seq_len, 1024, 1024})});
1261 
1262   // Create while condition computation with 'loop_limit'.
1263   const int32_t loop_limit = 100;
1264   XlaComputation condition;
1265   {
1266     XlaBuilder builder("condition");
1267     auto prev = Parameter(&builder, 0, loop_state_shape, "prev");
1268     auto iteration = GetTupleElement(prev, 0);
1269     Lt(iteration, ConstantR0<int32_t>(&builder, loop_limit));
1270     condition = builder.Build().value();
1271   }
1272 
1273   // Create while body computation with unit loop increment.
1274   XlaComputation body;
1275   {
1276     XlaBuilder builder("body");
1277     auto prev = Parameter(&builder, 0, loop_state_shape, "prev");
1278     // TupleElement 0
1279     auto iteration = GetTupleElement(prev, 0);
1280     auto out0 = Add(iteration, ConstantR0<int32_t>(&builder, 1));
1281     // TupleElement 1
1282     auto input = GetTupleElement(prev, 1);
1283     // Update.
1284     auto one = ConstantR0<float>(&builder, 1.0);
1285     auto update = Broadcast(one, {1, 1024, 1024});
1286     // Starts = iteration * 2;
1287     auto zero = ConstantR0<int32_t>(&builder, 0);
1288     // UpdateSlice.
1289     auto out1 = DynamicUpdateSlice(input, update, {zero, zero, zero});
1290     Tuple(&builder, {out0, out1});
1291     body = builder.Build().value();
1292   }
1293 
1294   // Create a While instruction.
1295   XlaBuilder builder("while");
1296   auto zero = ConstantR0<float>(&builder, 0.0);
1297   auto input = Broadcast(zero, {seq_len, 1024, 1024});
1298   auto init = Tuple(&builder, {ConstantR0<int32_t>(&builder, 0), input});
1299   While(condition, body, init);
1300   auto computation = builder.Build().value();
1301 
1302   TF_ASSERT_OK_AND_ASSIGN(
1303       auto executables,
1304       client->Compile(computation, {}, ExecutableBuildOptions()));
1305   auto executable = std::move(executables[0]);
1306 
1307   // Run some warm-up executions.
1308   ExecutableRunOptions options;
1309   options.set_allocator(&allocator);
1310   const int kWarmups = 2;
1311   for (int i = 0; i < kWarmups; ++i) {
1312     auto result =
1313         executable->Run(absl::Span<const ShapedBuffer* const>(), options);
1314     ASSERT_TRUE(result.ok());
1315   }
1316 
1317   // Run benchmark.
1318   for (auto s : state) {
1319     auto result =
1320         executable->Run(absl::Span<const ShapedBuffer* const>(), options);
1321     ASSERT_TRUE(result.ok());
1322   }
1323 }
1324 
1325 BENCHMARK(BM_WhileLoop);
1326 }  // namespace
1327 }  // namespace xla
1328