• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <random>
17 #include "tensorflow/compiler/xla/client/xla_builder.h"
18 #include "tensorflow/compiler/xla/client/xla_computation.h"
19 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
20 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
21 #include "tensorflow/compiler/xla/tests/test_macros.h"
22 
23 namespace xla {
24 namespace {
25 
26 class ConditionalOpTest : public ClientLibraryTestBase {
27  protected:
CreateR0ConstantComputation(float value)28   XlaComputation CreateR0ConstantComputation(float value) {
29     XlaBuilder builder("Constant");
30     Parameter(&builder, 0, empty_tuple_, "tuple");
31     ConstantR0<float>(&builder, value);
32     auto build_status = builder.Build();
33     EXPECT_IS_OK(build_status.status());
34     return build_status.ConsumeValueOrDie();
35   }
36 
CreateR0IdentityComputation()37   XlaComputation CreateR0IdentityComputation() {
38     XlaBuilder builder("Identity");
39     Parameter(&builder, 0, r0f32_, "x");
40     auto build_status = builder.Build();
41     EXPECT_IS_OK(build_status.status());
42     return build_status.ConsumeValueOrDie();
43   }
44 
CreateCeilComputation(const Shape & shape)45   XlaComputation CreateCeilComputation(const Shape& shape) {
46     XlaBuilder builder("Ceil");
47     auto param = Parameter(&builder, 0, shape, "param");
48     Ceil(param);
49     auto build_status = builder.Build();
50     EXPECT_IS_OK(build_status.status());
51     return build_status.ConsumeValueOrDie();
52   }
53 
CreateR0CeilComputation()54   XlaComputation CreateR0CeilComputation() {
55     return CreateCeilComputation(r0f32_);
56   }
57 
CreateR1CeilComputation()58   XlaComputation CreateR1CeilComputation() {
59     return CreateCeilComputation(r1s2f32_);
60   }
61 
CreateFloorComputation(const Shape & shape)62   XlaComputation CreateFloorComputation(const Shape& shape) {
63     XlaBuilder builder("Floor");
64     auto param = Parameter(&builder, 0, shape, "param");
65     Floor(param);
66     auto build_status = builder.Build();
67     EXPECT_IS_OK(build_status.status());
68     return build_status.ConsumeValueOrDie();
69   }
70 
CreateR0FloorComputation()71   XlaComputation CreateR0FloorComputation() {
72     return CreateFloorComputation(r0f32_);
73   }
74 
CreateR1FloorComputation()75   XlaComputation CreateR1FloorComputation() {
76     return CreateFloorComputation(r1s2f32_);
77   }
78 
CreateTupleCeilComputation(const string & computation_name,const Shape & tuple_shape)79   XlaComputation CreateTupleCeilComputation(const string& computation_name,
80                                             const Shape& tuple_shape) {
81     XlaBuilder builder(computation_name);
82     auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
83     auto x = GetTupleElement(tuple, 0);
84     auto y = GetTupleElement(tuple, 1);
85     auto x_ceil = Ceil(x);
86     auto y_ceil = Ceil(y);
87     Tuple(&builder, {x_ceil, y_ceil});
88     auto build_status = builder.Build();
89     EXPECT_IS_OK(build_status.status());
90     return build_status.ConsumeValueOrDie();
91   }
92 
CreateR0TupleCeilComputation()93   XlaComputation CreateR0TupleCeilComputation() {
94     return CreateTupleCeilComputation("CeilR0", tuple_2_r0f32_);
95   }
96 
CreateR1TupleCeilComputation()97   XlaComputation CreateR1TupleCeilComputation() {
98     return CreateTupleCeilComputation("CeilR1", tuple_2_r1s2f32_);
99   }
100 
CreateTupleFloorComputation(const string & computation_name,const Shape & tuple_shape)101   XlaComputation CreateTupleFloorComputation(const string& computation_name,
102                                              const Shape& tuple_shape) {
103     XlaBuilder builder(computation_name);
104     auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
105     auto x = GetTupleElement(tuple, 0);
106     auto y = GetTupleElement(tuple, 1);
107     auto x_floor = Floor(x);
108     auto y_floor = Floor(y);
109     Tuple(&builder, {x_floor, y_floor});
110     auto build_status = builder.Build();
111     EXPECT_IS_OK(build_status.status());
112     return build_status.ConsumeValueOrDie();
113   }
114 
CreateR0TupleFloorComputation()115   XlaComputation CreateR0TupleFloorComputation() {
116     return CreateTupleFloorComputation("FloorR0", tuple_2_r0f32_);
117   }
118 
CreateR1TupleFloorComputation()119   XlaComputation CreateR1TupleFloorComputation() {
120     return CreateTupleFloorComputation("FloorR1", tuple_2_r1s2f32_);
121   }
122 
CreateTupleAddComputation(const string & computation_name,const Shape & tuple_shape)123   XlaComputation CreateTupleAddComputation(const string& computation_name,
124                                            const Shape& tuple_shape) {
125     XlaBuilder builder(computation_name);
126     auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
127     auto x = GetTupleElement(tuple, 0);
128     auto y = GetTupleElement(tuple, 1);
129     Add(x, y);
130     auto build_status = builder.Build();
131     EXPECT_IS_OK(build_status.status());
132     return build_status.ConsumeValueOrDie();
133   }
134 
CreateR0TupleAddComputation()135   XlaComputation CreateR0TupleAddComputation() {
136     return CreateTupleAddComputation("AddR0", tuple_2_r0f32_);
137   }
138 
CreateR1TupleAddComputation()139   XlaComputation CreateR1TupleAddComputation() {
140     return CreateTupleAddComputation("AddR1", tuple_2_r1s2f32_);
141   }
142 
CreateTupleSubComputation(const string & computation_name,const Shape & tuple_shape)143   XlaComputation CreateTupleSubComputation(const string& computation_name,
144                                            const Shape& tuple_shape) {
145     XlaBuilder builder(computation_name);
146     auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
147     auto x = GetTupleElement(tuple, 0);
148     auto y = GetTupleElement(tuple, 1);
149     Sub(x, y);
150     auto build_status = builder.Build();
151     EXPECT_IS_OK(build_status.status());
152     return build_status.ConsumeValueOrDie();
153   }
154 
CreateR0TupleSubComputation()155   XlaComputation CreateR0TupleSubComputation() {
156     return CreateTupleSubComputation("SubR0", tuple_2_r0f32_);
157   }
158 
CreateR1TupleSubComputation()159   XlaComputation CreateR1TupleSubComputation() {
160     return CreateTupleSubComputation("SubR1", tuple_2_r1s2f32_);
161   }
162 
163   Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
164   Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2});
165   Shape tuple_2_r0f32_ = ShapeUtil::MakeTupleShape(
166       {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})});
167   Shape tuple_2_r1s2f32_ = ShapeUtil::MakeTupleShape(
168       {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(F32, {2})});
169   Shape empty_tuple_ = ShapeUtil::MakeTupleShape({});
170   ErrorSpec error_spec_{0.001};
171 };
172 
173 // Test fixture to run indexed conditional (switch/case) tests with varying
174 // number of branches.
175 class CaseOpTest : public ConditionalOpTest,
176                    public ::testing::WithParamInterface<int> {};
177 
178 // Test true and false computations that do not take any parameters.
XLA_TEST_F(ConditionalOpTest,Parameters0)179 XLA_TEST_F(ConditionalOpTest, Parameters0) {
180   XlaBuilder builder(TestName());
181   XlaOp pred;
182   auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
183   auto operands = Tuple(&builder, {});
184   auto true_computation = CreateR0ConstantComputation(56.0f);
185   auto false_computation = CreateR0ConstantComputation(12.0f);
186   Conditional(pred, operands, true_computation, operands, false_computation);
187 
188   ComputeAndCompareR0<float>(&builder, 56.0f, {pred_arg.get()}, error_spec_);
189 }
190 
191 // Test branch computations that do not take any parameters.
XLA_TEST_P(CaseOpTest,Parameters0)192 XLA_TEST_P(CaseOpTest, Parameters0) {
193   int num_branches = GetParam();
194   for (int bi = -1; bi <= num_branches; ++bi) {
195     SCOPED_TRACE(bi);
196     XlaBuilder builder(TestName());
197     XlaOp branch_index;
198     auto branch_index_arg = CreateR0Parameter<int32>(bi, 0, "branch_index_arg",
199                                                      &builder, &branch_index);
200     auto operand = Tuple(&builder, {});
201 
202     std::vector<XlaOp> operands(num_branches, operand);
203     std::vector<XlaComputation> branches;
204     branches.reserve(num_branches);
205     std::vector<const XlaComputation*> branches_p(num_branches);
206     for (int i = 0; i < num_branches; ++i) {
207       branches.emplace_back(
208           CreateR0ConstantComputation(static_cast<float>(i) * 10));
209       branches_p[i] = &branches[i];
210     }
211     Conditional(branch_index, branches_p, operands);
212 
213     float expected = 10 * static_cast<float>((bi < 0 || bi >= num_branches)
214                                                  ? num_branches - 1
215                                                  : bi);
216     ComputeAndCompareR0<float>(&builder, expected, {branch_index_arg.get()},
217                                error_spec_);
218   }
219 }
220 
221 // Test true and false computations that take in 1 parameter.
XLA_TEST_F(ConditionalOpTest,Parameters1)222 XLA_TEST_F(ConditionalOpTest, Parameters1) {
223   XlaBuilder builder(TestName());
224   XlaOp pred;
225   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
226   auto operand1 = ConstantR0<float>(&builder, 56.0f);
227   auto operand2 = ConstantR0<float>(&builder, 12.0f);
228   auto identity = CreateR0IdentityComputation();
229   Conditional(pred, operand1, identity, operand2, identity);
230 
231   ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
232 }
233 
234 // Test branch computations that take in 1 parameter.
XLA_TEST_P(CaseOpTest,Parameters1)235 XLA_TEST_P(CaseOpTest, Parameters1) {
236   int num_branches = GetParam();
237   for (int bi = -1; bi <= num_branches; ++bi) {
238     SCOPED_TRACE(bi);
239     XlaBuilder builder(TestName());
240     XlaOp branch_index;
241     auto branch_index_arg = CreateR0Parameter<int32>(bi, 0, "branch_index_arg",
242                                                      &builder, &branch_index);
243 
244     auto make_branch = [&builder, this](int i) {
245       auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
246       Add(ConstantR0<float>(sb.get(), static_cast<float>(i)),
247           Parameter(sb.get(), 0, r0f32_, "p0"));
248       return sb->BuildAndNoteError();
249     };
250     std::vector<XlaComputation> branches;
251     branches.reserve(num_branches);
252     std::vector<const XlaComputation*> branches_p(num_branches);
253     std::vector<XlaOp> operands;
254     operands.reserve(num_branches);
255     std::vector<float> expecteds(num_branches);
256     for (int i = 0; i < num_branches; ++i) {
257       branches.emplace_back(make_branch(i));
258       branches_p[i] = &branches[i];
259       auto fi = static_cast<float>(i);
260       operands.emplace_back(ConstantR0<float>(&builder, 10 * fi + 7));
261       expecteds[i] = 10 * fi + 7 + fi;
262     }
263 
264     Conditional(branch_index, branches_p, operands);
265     float expected = (bi < 0 || bi >= num_branches)
266                          ? expecteds[num_branches - 1]
267                          : expecteds[bi];
268     ComputeAndCompareR0<float>(&builder, expected, {branch_index_arg.get()},
269                                error_spec_);
270   }
271 }
272 
273 // Test conditional with two different computations in the true and false cases
274 // that take in different arguments.
XLA_TEST_F(ConditionalOpTest,DiffComputationsDiffArgs)275 XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) {
276   XlaBuilder builder(TestName());
277   XlaOp pred;
278   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
279   auto operand1 = ConstantR0<float>(&builder, 56.4f);
280   auto operand2 = ConstantR0<float>(&builder, 12.6f);
281   Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
282               CreateR0FloorComputation());
283 
284   ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
285 }
286 
287 // Test conditional with two different computations in the true and false cases
288 // that take in the same arguments.
XLA_TEST_F(ConditionalOpTest,DiffComputationsSameArg)289 XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) {
290   XlaBuilder builder(TestName());
291   XlaOp pred;
292   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
293   auto operand = ConstantR0<float>(&builder, 12.6f);
294   Conditional(pred, operand, CreateR0CeilComputation(), operand,
295               CreateR0FloorComputation());
296 
297   ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
298 }
299 
300 // Test conditional with the same computation in the true and false cases but
301 // take in different arguments.
XLA_TEST_F(ConditionalOpTest,SameComputationDiffArgs)302 XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) {
303   XlaBuilder builder(TestName());
304   XlaOp pred;
305   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
306   auto operand1 = ConstantR0<float>(&builder, 56.4f);
307   auto operand2 = ConstantR0<float>(&builder, 12.6f);
308   auto floor = CreateR0FloorComputation();
309   Conditional(pred, operand1, floor, operand2, floor);
310 
311   ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
312 }
313 
314 // Test conditional with the same computation in the true and false cases that
315 // take in the same arguments.
XLA_TEST_F(ConditionalOpTest,SameComputationSameArg)316 XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) {
317   XlaBuilder builder(TestName());
318   XlaOp pred;
319   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
320   auto operand = ConstantR0<float>(&builder, 12.6f);
321   auto floor = CreateR0FloorComputation();
322   Conditional(pred, operand, floor, operand, floor);
323 
324   ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
325 }
326 
327 // Test conditional with different instances of the same computation in the true
328 // and false cases.
XLA_TEST_F(ConditionalOpTest,SameComputationDiffInstances)329 XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) {
330   XlaBuilder builder(TestName());
331   XlaOp pred;
332   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
333   auto operand1 = ConstantR0<float>(&builder, 56.4f);
334   auto operand2 = ConstantR0<float>(&builder, 12.6f);
335   Conditional(pred, operand1, CreateR0FloorComputation(), operand2,
336               CreateR0FloorComputation());
337 
338   ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
339 }
340 
341 // Test the case when a call invokes a computation that contains a conditional.
XLA_TEST_F(ConditionalOpTest,ConditionalWithCall)342 XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) {
343   Shape r0bool = ShapeUtil::MakeShape(PRED, {});
344   XlaBuilder inner_builder(TestName() + ".inner_conditional");
345   auto pred_cond = Parameter(&inner_builder, 0, r0bool, "param0");
346   auto true_operand = Parameter(&inner_builder, 1, r0f32_, "param1");
347   auto false_operand = Parameter(&inner_builder, 2, r0f32_, "param2");
348   Conditional(pred_cond, true_operand, CreateR0CeilComputation(), false_operand,
349               CreateR0FloorComputation());
350   auto inner_builder_result = inner_builder.Build();
351 
352   XlaBuilder builder(TestName());
353   XlaOp pred;
354   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
355   auto operand1 = ConstantR0<float>(&builder, 56.4f);
356   auto operand2 = ConstantR0<float>(&builder, 12.6f);
357   Call(&builder, inner_builder_result.ConsumeValueOrDie(),
358        {pred, operand1, operand2});
359 
360   ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
361 }
362 
363 // Test true and false computations that take in 2 parameters and predicate is
364 // true.
XLA_TEST_F(ConditionalOpTest,Parameters2TrueBranch)365 XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) {
366   XlaBuilder builder(TestName());
367   XlaOp pred;
368   auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
369   auto operand1 = ConstantR0<float>(&builder, 56.0f);
370   auto operand2 = ConstantR0<float>(&builder, 12.0f);
371   auto operands = Tuple(&builder, {operand1, operand2});
372   Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
373               CreateR0TupleSubComputation());
374 
375   ComputeAndCompareR0<float>(&builder, 68.0f, {pred_arg.get()}, error_spec_);
376 }
377 
378 // Test true and false computations that take in 2 parameters and predicate is
379 // false.
XLA_TEST_F(ConditionalOpTest,Parameters2FalseBranch)380 XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) {
381   XlaBuilder builder(TestName());
382   XlaOp pred;
383   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
384   auto operand1 = ConstantR0<float>(&builder, 56.0f);
385   auto operand2 = ConstantR0<float>(&builder, 12.0f);
386   auto operands = Tuple(&builder, {operand1, operand2});
387   Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
388               CreateR0TupleSubComputation());
389 
390   ComputeAndCompareR0<float>(&builder, 44.0f, {pred_arg.get()}, error_spec_);
391 }
392 
393 // Test true and false computations that take in 2 array parameters and
394 // predicate is true.
XLA_TEST_F(ConditionalOpTest,Parameters2ArrayTrueBranch)395 XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
396   XlaBuilder builder(TestName());
397   XlaOp pred;
398   auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
399   auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
400   auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
401   auto operands = Tuple(&builder, {operand1, operand2});
402   Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
403               CreateR1TupleSubComputation());
404 
405   ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {pred_arg.get()},
406                              error_spec_);
407 }
408 
409 // Test branch computations that take in 2 array parameters.
XLA_TEST_P(CaseOpTest,Parameters2Array)410 XLA_TEST_P(CaseOpTest, Parameters2Array) {
411   int num_branches = GetParam();
412   for (int bi = -1; bi <= num_branches; ++bi) {
413     SCOPED_TRACE(bi);
414     XlaBuilder builder(TestName());
415     XlaOp branch_index;
416     auto branch_index_arg =
417         CreateR0Parameter<int32>(bi, 0, "pred", &builder, &branch_index);
418     auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
419     auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
420     auto operands = Tuple(&builder, {operand1, operand2});
421     auto make_branch = [&builder, this](int i) {
422       auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
423       auto p = Parameter(sb.get(), 0, tuple_2_r1s2f32_, "p0");
424       Add(Mul(ConstantR0<float>(sb.get(), static_cast<float>(i)),
425               GetTupleElement(p, 0)),
426           GetTupleElement(p, 1));
427       return sb->BuildAndNoteError();
428     };
429     std::vector<XlaComputation> branches;
430     branches.reserve(num_branches);
431     std::vector<const XlaComputation*> branches_p(num_branches);
432     for (int i = 0; i < num_branches; ++i) {
433       branches.emplace_back(make_branch(i));
434       branches_p[i] = &branches[i];
435     }
436     Conditional(branch_index, branches_p,
437                 std::vector<XlaOp>(num_branches, operands));
438     auto modified_bi = static_cast<float>(
439         (bi < 0 || bi >= num_branches) ? num_branches - 1 : bi);
440     ComputeAndCompareR1<float>(
441         &builder, {24.0f * modified_bi + 10, 56.0f * modified_bi + 11},
442         {branch_index_arg.get()}, error_spec_);
443   }
444 }
445 
446 INSTANTIATE_TEST_SUITE_P(CaseOpTest_Instantiation, CaseOpTest,
447                          ::testing::Values(1, 2, 3, 4, 5));
448 
449 // Test true and false computations that take in 2 array parameters and
450 // predicate is false.
XLA_TEST_F(ConditionalOpTest,Parameters2ArrayFalseBranch)451 XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) {
452   XlaBuilder builder(TestName());
453   XlaOp pred;
454   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
455   auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
456   auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
457   auto operands = Tuple(&builder, {operand1, operand2});
458   Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
459               CreateR1TupleSubComputation());
460 
461   ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {pred_arg.get()},
462                              error_spec_);
463 }
464 
465 // Test true and false computations that return a tuple of scalars.
XLA_TEST_F(ConditionalOpTest,ReturnTupleOfScalars)466 XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
467   XlaBuilder builder(TestName());
468   XlaOp pred;
469   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
470   auto operands = Tuple(&builder, {ConstantR0<float>(&builder, 12.2f),
471                                    ConstantR0<float>(&builder, 25.6f)});
472   Conditional(pred, operands, CreateR0TupleCeilComputation(), operands,
473               CreateR0TupleFloorComputation());
474 
475   ComputeAndCompareTuple(
476       &builder,
477       LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<float>(12.0f),
478                                         LiteralUtil::CreateR0<float>(25.0f)}),
479       {pred_arg.get()}, error_spec_);
480 }
481 
482 // Test true and false computations that return a tuple of arrays.
XLA_TEST_F(ConditionalOpTest,ReturnTupleOfArrays)483 XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
484   XlaBuilder builder(TestName());
485   XlaOp pred;
486   auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
487   auto operands =
488       Tuple(&builder, {ConstantR1<float>(&builder, {12.2f, 15.8f}),
489                        ConstantR1<float>(&builder, {25.6f, 29.2f})});
490   Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
491               CreateR1TupleFloorComputation());
492 
493   ComputeAndCompareTuple(&builder,
494                          LiteralUtil::MakeTupleFromSlices(
495                              {LiteralUtil::CreateR1<float>({13.0f, 16.0f}),
496                               LiteralUtil::CreateR1<float>({26.0f, 30.0f})}),
497                          {pred_arg.get()}, error_spec_);
498 }
499 
500 // Test true and false computations that return a tuple of a predicate, a
501 // scalar, and an array.
XLA_TEST_F(ConditionalOpTest,ReturnTupleofPredicateScalarArray)502 XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
503   XlaBuilder true_builder(TestName() + ".true");
504   {
505     Parameter(&true_builder, 0, empty_tuple_, "tuple");
506     auto true_pred = ConstantR0<bool>(&true_builder, true);
507     auto true_scalar = ConstantR0<float>(&true_builder, 12.2f);
508     auto true_array = ConstantR1<float>(&true_builder, {12.8f, 14.6f});
509     Tuple(&true_builder, {true_pred, true_scalar, true_array});
510   }
511   auto true_builder_result = true_builder.Build();
512   EXPECT_IS_OK(true_builder_result.status());
513 
514   XlaBuilder false_builder(TestName() + ".false");
515   {
516     Parameter(&false_builder, 0, empty_tuple_, "tuple");
517     auto false_pred = ConstantR0<bool>(&false_builder, false);
518     auto false_scalar = ConstantR0<float>(&false_builder, 25.6f);
519     auto false_array = ConstantR1<float>(&false_builder, {26.4f, 32.6f});
520     Tuple(&false_builder, {false_pred, false_scalar, false_array});
521   }
522   auto false_builder_result = false_builder.Build();
523   EXPECT_IS_OK(false_builder_result.status());
524 
525   XlaBuilder builder(TestName());
526   XlaOp pred;
527   auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
528   auto operands = Tuple(&builder, {});
529   Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
530               false_builder_result.ConsumeValueOrDie());
531 
532   ComputeAndCompareTuple(&builder,
533                          LiteralUtil::MakeTupleFromSlices(
534                              {LiteralUtil::CreateR0<bool>(true),
535                               LiteralUtil::CreateR0<float>(12.2f),
536                               LiteralUtil::CreateR1<float>({12.8f, 14.6f})}),
537                          {pred_arg.get()}, error_spec_);
538 }
539 
540 // Test true and false computations that return a nested tuple.
XLA_TEST_F(ConditionalOpTest,ReturnNestedTuple)541 XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
542   XlaBuilder true_builder(TestName() + ".true");
543   {
544     Parameter(&true_builder, 0, empty_tuple_, "tuple");
545     auto true_constant1 = ConstantR0<float>(&true_builder, 12.2f);
546     auto true_constant2 = ConstantR1<float>(&true_builder, {12.8f, 14.6f});
547     auto true_constant3 = ConstantR1<float>(&true_builder, {25.4f, 29.8f});
548     auto true_constant4 = ConstantR0<float>(&true_builder, 35.6f);
549     Tuple(&true_builder,
550           {Tuple(&true_builder, {true_constant1, true_constant2}),
551            Tuple(&true_builder, {true_constant3, true_constant4})});
552   }
553   auto true_builder_result = true_builder.Build();
554   EXPECT_IS_OK(true_builder_result.status());
555 
556   XlaBuilder false_builder(TestName() + ".false");
557   {
558     Parameter(&false_builder, 0, empty_tuple_, "tuple");
559     auto false_constant1 = ConstantR0<float>(&false_builder, 46.6f);
560     auto false_constant2 = ConstantR1<float>(&false_builder, {54.4f, 58.4f});
561     auto false_constant3 = ConstantR1<float>(&false_builder, {62.1f, 67.4f});
562     auto false_constant4 = ConstantR0<float>(&false_builder, 9.3f);
563     Tuple(&false_builder,
564           {Tuple(&false_builder, {false_constant1, false_constant2}),
565            Tuple(&false_builder, {false_constant3, false_constant4})});
566   }
567   auto false_builder_result = false_builder.Build();
568   EXPECT_IS_OK(false_builder_result.status());
569 
570   XlaBuilder builder(TestName());
571   XlaOp pred;
572   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
573   auto operands = Tuple(&builder, {});
574   Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
575               false_builder_result.ConsumeValueOrDie());
576 
577   ComputeAndCompareTuple(
578       &builder,
579       LiteralUtil::MakeTupleFromSlices(
580           {LiteralUtil::MakeTupleFromSlices(
581                {LiteralUtil::CreateR0<float>(46.6f),
582                 LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
583            LiteralUtil::MakeTupleFromSlices(
584                {LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
585                 LiteralUtil::CreateR0<float>(9.3f)})}),
586       {pred_arg.get()}, error_spec_);
587 }
588 
589 // Test conditional that takes in scalar operands in the form of external
590 // params.
XLA_TEST_F(ConditionalOpTest,ScalarOperandsFromExternalParams)591 XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) {
592   Shape r0bool = ShapeUtil::MakeShape(PRED, {});
593   XlaBuilder builder(TestName());
594 
595   XlaOp pred, operand1, operand2;
596   auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
597   auto operand1_param =
598       CreateR0Parameter<float>(56.3f, 1, "operand1", &builder, &operand1);
599   auto operand2_param =
600       CreateR0Parameter<float>(12.7f, 2, "operand2", &builder, &operand2);
601   Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
602               CreateR0FloorComputation());
603 
604   ComputeAndCompareR0<float>(
605       &builder, 57.0f,
606       {pred_arg.get(), operand1_param.get(), operand2_param.get()},
607       error_spec_);
608 }
609 
610 // Test conditional that takes in array operands in the form of external params.
XLA_TEST_F(ConditionalOpTest,ArrayOperandsFromExternalParams)611 XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) {
612   Shape r0bool = ShapeUtil::MakeShape(PRED, {});
613   XlaBuilder builder(TestName());
614 
615   XlaOp pred, operand1, operand2;
616   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
617   auto operand1_param = CreateR1Parameter<float>({24.3f, 56.7f}, 1, "operand1",
618                                                  &builder, &operand1);
619   auto operand2_param = CreateR1Parameter<float>({10.2f, 11.6f}, 2, "operand2",
620                                                  &builder, &operand2);
621   Conditional(pred, operand1, CreateR1CeilComputation(), operand2,
622               CreateR1FloorComputation());
623 
624   ComputeAndCompareR1<float>(
625       &builder, {10.0f, 11.0f},
626       {pred_arg.get(), operand1_param.get(), operand2_param.get()},
627       error_spec_);
628 }
629 
630 // Test the case where one conditional is nested within another.
XLA_TEST_F(ConditionalOpTest,NestedConditionals)631 XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
632   XlaBuilder inner_builder(TestName() + ".inner_conditional");
633   {
634     Shape r0bool = ShapeUtil::MakeShape(PRED, {});
635     Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
636     auto param0 = Parameter(&inner_builder, 0, tuple_shape, "param0");
637     auto pred_cond = GetTupleElement(param0, 0);
638     auto true_operand = GetTupleElement(param0, 1);
639     auto false_operand = GetTupleElement(param0, 2);
640     Conditional(pred_cond, true_operand, CreateR0CeilComputation(),
641                 false_operand, CreateR0FloorComputation());
642   }
643   auto inner_builder_result = inner_builder.Build();
644   EXPECT_IS_OK(inner_builder_result.status());
645 
646   XlaBuilder builder(TestName());
647   XlaOp pred1, pred2;
648   auto pred1_arg = CreateR0Parameter<bool>(true, 0, "pred1", &builder, &pred1);
649   auto pred2_arg = CreateR0Parameter<bool>(false, 1, "pred2", &builder, &pred2);
650   auto operand1 = ConstantR0<float>(&builder, 1.1f);
651   auto operand2 = ConstantR0<float>(&builder, 12.2f);
652   auto operand3 = ConstantR0<float>(&builder, 43.3f);
653   auto tuple_operand = Tuple(&builder, {pred2, operand1, operand2});
654   Conditional(pred1, tuple_operand, inner_builder_result.ConsumeValueOrDie(),
655               operand3, CreateR0IdentityComputation());
656 
657   ComputeAndCompareR0<float>(&builder, 12.0f,
658                              {pred1_arg.get(), pred2_arg.get()}, error_spec_);
659 }
660 
// A Ceil/Floor conditional lives inside a computation that is invoked via
// Call with a packed (pred, true_value, false_value) tuple; a false predicate
// should give floor(12.2).
XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
  XlaBuilder inner_builder(TestName() + ".inner_conditional");
  {
    const Shape r0bool = ShapeUtil::MakeShape(PRED, {});
    const Shape packed_shape =
        ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
    XlaOp packed = Parameter(&inner_builder, 0, packed_shape, "param0");
    XlaOp inner_pred = GetTupleElement(packed, 0);
    XlaOp inner_true_value = GetTupleElement(packed, 1);
    XlaOp inner_false_value = GetTupleElement(packed, 2);
    Conditional(inner_pred, inner_true_value, CreateR0CeilComputation(),
                inner_false_value, CreateR0FloorComputation());
  }
  auto inner = inner_builder.Build();
  EXPECT_IS_OK(inner.status());

  XlaBuilder builder(TestName());
  XlaOp predicate;
  auto predicate_literal =
      CreateR0Parameter<bool>(false, 0, "pred", &builder, &predicate);
  XlaOp value_a = ConstantR0<float>(&builder, 1.1f);
  XlaOp value_b = ConstantR0<float>(&builder, 12.2f);
  XlaOp packed_args = Tuple(&builder, {predicate, value_a, value_b});
  Call(&builder, inner.ConsumeValueOrDie(), {packed_args});

  ComputeAndCompareR0<float>(&builder, 12.0f, {predicate_literal.get()},
                             error_spec_);
}
686 
687 // Test a mismatch in the shape of the true operand and true computation.
XLA_TEST_F(ConditionalOpTest,ShapeMismatch)688 XLA_TEST_F(ConditionalOpTest, ShapeMismatch) {
689   XlaBuilder builder(TestName());
690   auto pred = ConstantR0<bool>(&builder, true);
691   auto operand1 = ConstantR0<float>(&builder, 56.0f);
692   auto operand2 = ConstantR0<float>(&builder, 12.0f);
693   auto operands = Tuple(&builder, {operand1, operand2});
694   Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
695               CreateR0TupleSubComputation());
696 
697   auto result = builder.Build();
698   EXPECT_FALSE(result.ok());
699   EXPECT_THAT(result.status().error_message(),
700               ::testing::HasSubstr("operand 0 must match the shape of the "
701                                    "only parameter of branch computation 0"));
702 }
703 
// Two conditionals run back to back on an (x, y) pair.
//
// The first forwards the pair when x < y and swaps it otherwise. The second
// swaps when x >= y and forwards otherwise. Each branch of the first
// conditional is therefore undone by the second, and the final tuple must
// equal the original (x, y) for both orderings of the inputs.
XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
  const Shape pair_shape = ShapeUtil::MakeTupleShape({r0f32_, r0f32_});

  // Builds a computation that takes a (f32, f32) tuple and returns it either
  // unchanged (swap == false) or with its two elements exchanged.
  auto build_pair_computation = [&](const char* suffix,
                                    const char* param_name, bool swap) {
    XlaBuilder b(TestName() + suffix);
    auto pair = Parameter(&b, 0, pair_shape, param_name);
    auto first = GetTupleElement(pair, 0);
    auto second = GetTupleElement(pair, 1);
    if (swap) {
      Tuple(&b, {second, first});
    } else {
      Tuple(&b, {first, second});
    }
    return b.Build().ConsumeValueOrDie();
  };

  const XlaComputation swapper =
      build_pair_computation(".swapper", "sp0", /*swap=*/true);
  const XlaComputation forwarder =
      build_pair_computation(".forwarder", "fp0", /*swap=*/false);

  XlaComputation round_trip;
  {
    XlaBuilder b(TestName() + ".main");
    auto pair = Parameter(&b, 0, pair_shape, "mp0");
    auto x = GetTupleElement(pair, 0);
    auto y = GetTupleElement(pair, 1);
    // First pass: keep the pair when x < y, swap it otherwise.
    auto staged = Conditional(Lt(x, y), pair, forwarder, pair, swapper);
    // Second pass: swap back when x >= y, keep it otherwise.
    Conditional(Ge(x, y), staged, swapper, staged, forwarder);
    round_trip = b.Build().ConsumeValueOrDie();
  }

  // Feeds (a, b) through the round-trip computation and expects (a, b) back.
  auto expect_round_trip = [&](float a, float b) {
    XlaBuilder builder(TestName());
    XlaOp x_op, y_op;
    auto x_data = CreateR0Parameter<float>(a, 0, "x", &builder, &x_op);
    auto y_data = CreateR0Parameter<float>(b, 1, "y", &builder, &y_op);
    Call(&builder, round_trip, {Tuple(&builder, {x_op, y_op})});

    auto expected = LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)});
    ComputeAndCompareTuple(&builder, expected, {x_data.get(), y_data.get()},
                           error_spec_);
  };

  expect_round_trip(3.11f, 9.4f);
  expect_round_trip(11.24f, 5.55f);
}
755 
756 // Test conditional that duplicates tuple elements in the then and else
757 // computations. This is a regression test for b/112550242.
XLA_TEST_F(ConditionalOpTest,DuplicateElementsConditional)758 XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
759   const Shape scalar = ShapeUtil::MakeShape(S32, {});
760   const Shape tuple2 = ShapeUtil::MakeTupleShape({scalar, scalar});
761   XlaComputation then_comp;
762   {
763     XlaBuilder builder(TestName() + ".then");
764     auto p = Parameter(&builder, 0, tuple2, "then.p");
765     auto e0 = GetTupleElement(p, 0);
766     auto e1 = GetTupleElement(p, 1);
767     Tuple(&builder, {e0, e1, e0});
768     then_comp = builder.Build().ConsumeValueOrDie();
769   }
770   XlaComputation else_comp;
771   {
772     XlaBuilder builder(TestName() + ".else");
773     auto p = Parameter(&builder, 0, tuple2, "else.p");
774     auto e0 = GetTupleElement(p, 0);
775     auto e1 = GetTupleElement(p, 1);
776     Tuple(&builder, {e0, e1, e1});
777     else_comp = builder.Build().ConsumeValueOrDie();
778   }
779 
780   {
781     // Pred is true case.
782     std::vector<Literal> args;
783     args.push_back(
784         LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
785                                           LiteralUtil::CreateR0<int32>(-42)}));
786     args.push_back(LiteralUtil::CreateR0<bool>(true));
787     XlaBuilder builder(TestName() + ".main");
788     auto p = Parameter(&builder, 0, tuple2, "p0");
789     auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
790     Conditional(p_pred, p, then_comp, p, else_comp);
791     ComputeAndCompare(&builder, args);
792   }
793   {
794     // Pred is false case.
795     std::vector<Literal> args;
796     args.push_back(
797         LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
798                                           LiteralUtil::CreateR0<int32>(-42)}));
799     args.push_back(LiteralUtil::CreateR0<bool>(false));
800     XlaBuilder builder(TestName() + ".main");
801     auto p = Parameter(&builder, 0, tuple2, "p0");
802     auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
803     Conditional(p_pred, p, then_comp, p, else_comp);
804     ComputeAndCompare(&builder, args);
805   }
806 }
807 
808 }  // namespace
809 }  // namespace xla
810