1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <random>
17 #include "tensorflow/compiler/xla/client/xla_builder.h"
18 #include "tensorflow/compiler/xla/client/xla_computation.h"
19 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
20 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
21 #include "tensorflow/compiler/xla/tests/test_macros.h"
22
23 namespace xla {
24 namespace {
25
26 class ConditionalOpTest : public ClientLibraryTestBase {
27 protected:
CreateR0ConstantComputation(float value)28 XlaComputation CreateR0ConstantComputation(float value) {
29 XlaBuilder builder("Constant");
30 Parameter(&builder, 0, empty_tuple_, "tuple");
31 ConstantR0<float>(&builder, value);
32 auto build_status = builder.Build();
33 EXPECT_IS_OK(build_status.status());
34 return build_status.ConsumeValueOrDie();
35 }
36
CreateR0IdentityComputation()37 XlaComputation CreateR0IdentityComputation() {
38 XlaBuilder builder("Identity");
39 Parameter(&builder, 0, r0f32_, "x");
40 auto build_status = builder.Build();
41 EXPECT_IS_OK(build_status.status());
42 return build_status.ConsumeValueOrDie();
43 }
44
CreateCeilComputation(const Shape & shape)45 XlaComputation CreateCeilComputation(const Shape& shape) {
46 XlaBuilder builder("Ceil");
47 auto param = Parameter(&builder, 0, shape, "param");
48 Ceil(param);
49 auto build_status = builder.Build();
50 EXPECT_IS_OK(build_status.status());
51 return build_status.ConsumeValueOrDie();
52 }
53
CreateR0CeilComputation()54 XlaComputation CreateR0CeilComputation() {
55 return CreateCeilComputation(r0f32_);
56 }
57
CreateR1CeilComputation()58 XlaComputation CreateR1CeilComputation() {
59 return CreateCeilComputation(r1s2f32_);
60 }
61
CreateFloorComputation(const Shape & shape)62 XlaComputation CreateFloorComputation(const Shape& shape) {
63 XlaBuilder builder("Floor");
64 auto param = Parameter(&builder, 0, shape, "param");
65 Floor(param);
66 auto build_status = builder.Build();
67 EXPECT_IS_OK(build_status.status());
68 return build_status.ConsumeValueOrDie();
69 }
70
CreateR0FloorComputation()71 XlaComputation CreateR0FloorComputation() {
72 return CreateFloorComputation(r0f32_);
73 }
74
CreateR1FloorComputation()75 XlaComputation CreateR1FloorComputation() {
76 return CreateFloorComputation(r1s2f32_);
77 }
78
CreateTupleCeilComputation(const string & computation_name,const Shape & tuple_shape)79 XlaComputation CreateTupleCeilComputation(const string& computation_name,
80 const Shape& tuple_shape) {
81 XlaBuilder builder(computation_name);
82 auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
83 auto x = GetTupleElement(tuple, 0);
84 auto y = GetTupleElement(tuple, 1);
85 auto x_ceil = Ceil(x);
86 auto y_ceil = Ceil(y);
87 Tuple(&builder, {x_ceil, y_ceil});
88 auto build_status = builder.Build();
89 EXPECT_IS_OK(build_status.status());
90 return build_status.ConsumeValueOrDie();
91 }
92
CreateR0TupleCeilComputation()93 XlaComputation CreateR0TupleCeilComputation() {
94 return CreateTupleCeilComputation("CeilR0", tuple_2_r0f32_);
95 }
96
CreateR1TupleCeilComputation()97 XlaComputation CreateR1TupleCeilComputation() {
98 return CreateTupleCeilComputation("CeilR1", tuple_2_r1s2f32_);
99 }
100
CreateTupleFloorComputation(const string & computation_name,const Shape & tuple_shape)101 XlaComputation CreateTupleFloorComputation(const string& computation_name,
102 const Shape& tuple_shape) {
103 XlaBuilder builder(computation_name);
104 auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
105 auto x = GetTupleElement(tuple, 0);
106 auto y = GetTupleElement(tuple, 1);
107 auto x_floor = Floor(x);
108 auto y_floor = Floor(y);
109 Tuple(&builder, {x_floor, y_floor});
110 auto build_status = builder.Build();
111 EXPECT_IS_OK(build_status.status());
112 return build_status.ConsumeValueOrDie();
113 }
114
CreateR0TupleFloorComputation()115 XlaComputation CreateR0TupleFloorComputation() {
116 return CreateTupleFloorComputation("FloorR0", tuple_2_r0f32_);
117 }
118
CreateR1TupleFloorComputation()119 XlaComputation CreateR1TupleFloorComputation() {
120 return CreateTupleFloorComputation("FloorR1", tuple_2_r1s2f32_);
121 }
122
CreateTupleAddComputation(const string & computation_name,const Shape & tuple_shape)123 XlaComputation CreateTupleAddComputation(const string& computation_name,
124 const Shape& tuple_shape) {
125 XlaBuilder builder(computation_name);
126 auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
127 auto x = GetTupleElement(tuple, 0);
128 auto y = GetTupleElement(tuple, 1);
129 Add(x, y);
130 auto build_status = builder.Build();
131 EXPECT_IS_OK(build_status.status());
132 return build_status.ConsumeValueOrDie();
133 }
134
CreateR0TupleAddComputation()135 XlaComputation CreateR0TupleAddComputation() {
136 return CreateTupleAddComputation("AddR0", tuple_2_r0f32_);
137 }
138
CreateR1TupleAddComputation()139 XlaComputation CreateR1TupleAddComputation() {
140 return CreateTupleAddComputation("AddR1", tuple_2_r1s2f32_);
141 }
142
CreateTupleSubComputation(const string & computation_name,const Shape & tuple_shape)143 XlaComputation CreateTupleSubComputation(const string& computation_name,
144 const Shape& tuple_shape) {
145 XlaBuilder builder(computation_name);
146 auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
147 auto x = GetTupleElement(tuple, 0);
148 auto y = GetTupleElement(tuple, 1);
149 Sub(x, y);
150 auto build_status = builder.Build();
151 EXPECT_IS_OK(build_status.status());
152 return build_status.ConsumeValueOrDie();
153 }
154
CreateR0TupleSubComputation()155 XlaComputation CreateR0TupleSubComputation() {
156 return CreateTupleSubComputation("SubR0", tuple_2_r0f32_);
157 }
158
CreateR1TupleSubComputation()159 XlaComputation CreateR1TupleSubComputation() {
160 return CreateTupleSubComputation("SubR1", tuple_2_r1s2f32_);
161 }
162
163 Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
164 Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2});
165 Shape tuple_2_r0f32_ = ShapeUtil::MakeTupleShape(
166 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})});
167 Shape tuple_2_r1s2f32_ = ShapeUtil::MakeTupleShape(
168 {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(F32, {2})});
169 Shape empty_tuple_ = ShapeUtil::MakeTupleShape({});
170 ErrorSpec error_spec_{0.001};
171 };
172
173 // Test fixture to run indexed conditional (switch/case) tests with varying
174 // number of branches.
175 class CaseOpTest : public ConditionalOpTest,
176 public ::testing::WithParamInterface<int> {};
177
178 // Test true and false computations that do not take any parameters.
XLA_TEST_F(ConditionalOpTest,Parameters0)179 XLA_TEST_F(ConditionalOpTest, Parameters0) {
180 XlaBuilder builder(TestName());
181 XlaOp pred;
182 auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
183 auto operands = Tuple(&builder, {});
184 auto true_computation = CreateR0ConstantComputation(56.0f);
185 auto false_computation = CreateR0ConstantComputation(12.0f);
186 Conditional(pred, operands, true_computation, operands, false_computation);
187
188 ComputeAndCompareR0<float>(&builder, 56.0f, {pred_arg.get()}, error_spec_);
189 }
190
191 // Test branch computations that do not take any parameters.
XLA_TEST_P(CaseOpTest,Parameters0)192 XLA_TEST_P(CaseOpTest, Parameters0) {
193 int num_branches = GetParam();
194 for (int bi = -1; bi <= num_branches; ++bi) {
195 SCOPED_TRACE(bi);
196 XlaBuilder builder(TestName());
197 XlaOp branch_index;
198 auto branch_index_arg = CreateR0Parameter<int32>(bi, 0, "branch_index_arg",
199 &builder, &branch_index);
200 auto operand = Tuple(&builder, {});
201
202 std::vector<XlaOp> operands(num_branches, operand);
203 std::vector<XlaComputation> branches;
204 branches.reserve(num_branches);
205 std::vector<const XlaComputation*> branches_p(num_branches);
206 for (int i = 0; i < num_branches; ++i) {
207 branches.emplace_back(
208 CreateR0ConstantComputation(static_cast<float>(i) * 10));
209 branches_p[i] = &branches[i];
210 }
211 Conditional(branch_index, branches_p, operands);
212
213 float expected = 10 * static_cast<float>((bi < 0 || bi >= num_branches)
214 ? num_branches - 1
215 : bi);
216 ComputeAndCompareR0<float>(&builder, expected, {branch_index_arg.get()},
217 error_spec_);
218 }
219 }
220
221 // Test true and false computations that take in 1 parameter.
XLA_TEST_F(ConditionalOpTest,Parameters1)222 XLA_TEST_F(ConditionalOpTest, Parameters1) {
223 XlaBuilder builder(TestName());
224 XlaOp pred;
225 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
226 auto operand1 = ConstantR0<float>(&builder, 56.0f);
227 auto operand2 = ConstantR0<float>(&builder, 12.0f);
228 auto identity = CreateR0IdentityComputation();
229 Conditional(pred, operand1, identity, operand2, identity);
230
231 ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
232 }
233
234 // Test branch computations that take in 1 parameter.
XLA_TEST_P(CaseOpTest,Parameters1)235 XLA_TEST_P(CaseOpTest, Parameters1) {
236 int num_branches = GetParam();
237 for (int bi = -1; bi <= num_branches; ++bi) {
238 SCOPED_TRACE(bi);
239 XlaBuilder builder(TestName());
240 XlaOp branch_index;
241 auto branch_index_arg = CreateR0Parameter<int32>(bi, 0, "branch_index_arg",
242 &builder, &branch_index);
243
244 auto make_branch = [&builder, this](int i) {
245 auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
246 Add(ConstantR0<float>(sb.get(), static_cast<float>(i)),
247 Parameter(sb.get(), 0, r0f32_, "p0"));
248 return sb->BuildAndNoteError();
249 };
250 std::vector<XlaComputation> branches;
251 branches.reserve(num_branches);
252 std::vector<const XlaComputation*> branches_p(num_branches);
253 std::vector<XlaOp> operands;
254 operands.reserve(num_branches);
255 std::vector<float> expecteds(num_branches);
256 for (int i = 0; i < num_branches; ++i) {
257 branches.emplace_back(make_branch(i));
258 branches_p[i] = &branches[i];
259 auto fi = static_cast<float>(i);
260 operands.emplace_back(ConstantR0<float>(&builder, 10 * fi + 7));
261 expecteds[i] = 10 * fi + 7 + fi;
262 }
263
264 Conditional(branch_index, branches_p, operands);
265 float expected = (bi < 0 || bi >= num_branches)
266 ? expecteds[num_branches - 1]
267 : expecteds[bi];
268 ComputeAndCompareR0<float>(&builder, expected, {branch_index_arg.get()},
269 error_spec_);
270 }
271 }
272
273 // Test conditional with two different computations in the true and false cases
274 // that take in different arguments.
XLA_TEST_F(ConditionalOpTest,DiffComputationsDiffArgs)275 XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) {
276 XlaBuilder builder(TestName());
277 XlaOp pred;
278 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
279 auto operand1 = ConstantR0<float>(&builder, 56.4f);
280 auto operand2 = ConstantR0<float>(&builder, 12.6f);
281 Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
282 CreateR0FloorComputation());
283
284 ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
285 }
286
287 // Test conditional with two different computations in the true and false cases
288 // that take in the same arguments.
XLA_TEST_F(ConditionalOpTest,DiffComputationsSameArg)289 XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) {
290 XlaBuilder builder(TestName());
291 XlaOp pred;
292 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
293 auto operand = ConstantR0<float>(&builder, 12.6f);
294 Conditional(pred, operand, CreateR0CeilComputation(), operand,
295 CreateR0FloorComputation());
296
297 ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
298 }
299
300 // Test conditional with the same computation in the true and false cases but
301 // take in different arguments.
XLA_TEST_F(ConditionalOpTest,SameComputationDiffArgs)302 XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) {
303 XlaBuilder builder(TestName());
304 XlaOp pred;
305 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
306 auto operand1 = ConstantR0<float>(&builder, 56.4f);
307 auto operand2 = ConstantR0<float>(&builder, 12.6f);
308 auto floor = CreateR0FloorComputation();
309 Conditional(pred, operand1, floor, operand2, floor);
310
311 ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
312 }
313
314 // Test conditional with the same computation in the true and false cases that
315 // take in the same arguments.
XLA_TEST_F(ConditionalOpTest,SameComputationSameArg)316 XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) {
317 XlaBuilder builder(TestName());
318 XlaOp pred;
319 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
320 auto operand = ConstantR0<float>(&builder, 12.6f);
321 auto floor = CreateR0FloorComputation();
322 Conditional(pred, operand, floor, operand, floor);
323
324 ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
325 }
326
327 // Test conditional with different instances of the same computation in the true
328 // and false cases.
XLA_TEST_F(ConditionalOpTest,SameComputationDiffInstances)329 XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) {
330 XlaBuilder builder(TestName());
331 XlaOp pred;
332 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
333 auto operand1 = ConstantR0<float>(&builder, 56.4f);
334 auto operand2 = ConstantR0<float>(&builder, 12.6f);
335 Conditional(pred, operand1, CreateR0FloorComputation(), operand2,
336 CreateR0FloorComputation());
337
338 ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
339 }
340
341 // Test the case when a call invokes a computation that contains a conditional.
XLA_TEST_F(ConditionalOpTest,ConditionalWithCall)342 XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) {
343 Shape r0bool = ShapeUtil::MakeShape(PRED, {});
344 XlaBuilder inner_builder(TestName() + ".inner_conditional");
345 auto pred_cond = Parameter(&inner_builder, 0, r0bool, "param0");
346 auto true_operand = Parameter(&inner_builder, 1, r0f32_, "param1");
347 auto false_operand = Parameter(&inner_builder, 2, r0f32_, "param2");
348 Conditional(pred_cond, true_operand, CreateR0CeilComputation(), false_operand,
349 CreateR0FloorComputation());
350 auto inner_builder_result = inner_builder.Build();
351
352 XlaBuilder builder(TestName());
353 XlaOp pred;
354 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
355 auto operand1 = ConstantR0<float>(&builder, 56.4f);
356 auto operand2 = ConstantR0<float>(&builder, 12.6f);
357 Call(&builder, inner_builder_result.ConsumeValueOrDie(),
358 {pred, operand1, operand2});
359
360 ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
361 }
362
363 // Test true and false computations that take in 2 parameters and predicate is
364 // true.
XLA_TEST_F(ConditionalOpTest,Parameters2TrueBranch)365 XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) {
366 XlaBuilder builder(TestName());
367 XlaOp pred;
368 auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
369 auto operand1 = ConstantR0<float>(&builder, 56.0f);
370 auto operand2 = ConstantR0<float>(&builder, 12.0f);
371 auto operands = Tuple(&builder, {operand1, operand2});
372 Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
373 CreateR0TupleSubComputation());
374
375 ComputeAndCompareR0<float>(&builder, 68.0f, {pred_arg.get()}, error_spec_);
376 }
377
378 // Test true and false computations that take in 2 parameters and predicate is
379 // false.
XLA_TEST_F(ConditionalOpTest,Parameters2FalseBranch)380 XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) {
381 XlaBuilder builder(TestName());
382 XlaOp pred;
383 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
384 auto operand1 = ConstantR0<float>(&builder, 56.0f);
385 auto operand2 = ConstantR0<float>(&builder, 12.0f);
386 auto operands = Tuple(&builder, {operand1, operand2});
387 Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
388 CreateR0TupleSubComputation());
389
390 ComputeAndCompareR0<float>(&builder, 44.0f, {pred_arg.get()}, error_spec_);
391 }
392
393 // Test true and false computations that take in 2 array parameters and
394 // predicate is true.
XLA_TEST_F(ConditionalOpTest,Parameters2ArrayTrueBranch)395 XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
396 XlaBuilder builder(TestName());
397 XlaOp pred;
398 auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
399 auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
400 auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
401 auto operands = Tuple(&builder, {operand1, operand2});
402 Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
403 CreateR1TupleSubComputation());
404
405 ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {pred_arg.get()},
406 error_spec_);
407 }
408
409 // Test branch computations that take in 2 array parameters.
XLA_TEST_P(CaseOpTest,Parameters2Array)410 XLA_TEST_P(CaseOpTest, Parameters2Array) {
411 int num_branches = GetParam();
412 for (int bi = -1; bi <= num_branches; ++bi) {
413 SCOPED_TRACE(bi);
414 XlaBuilder builder(TestName());
415 XlaOp branch_index;
416 auto branch_index_arg =
417 CreateR0Parameter<int32>(bi, 0, "pred", &builder, &branch_index);
418 auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
419 auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
420 auto operands = Tuple(&builder, {operand1, operand2});
421 auto make_branch = [&builder, this](int i) {
422 auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
423 auto p = Parameter(sb.get(), 0, tuple_2_r1s2f32_, "p0");
424 Add(Mul(ConstantR0<float>(sb.get(), static_cast<float>(i)),
425 GetTupleElement(p, 0)),
426 GetTupleElement(p, 1));
427 return sb->BuildAndNoteError();
428 };
429 std::vector<XlaComputation> branches;
430 branches.reserve(num_branches);
431 std::vector<const XlaComputation*> branches_p(num_branches);
432 for (int i = 0; i < num_branches; ++i) {
433 branches.emplace_back(make_branch(i));
434 branches_p[i] = &branches[i];
435 }
436 Conditional(branch_index, branches_p,
437 std::vector<XlaOp>(num_branches, operands));
438 auto modified_bi = static_cast<float>(
439 (bi < 0 || bi >= num_branches) ? num_branches - 1 : bi);
440 ComputeAndCompareR1<float>(
441 &builder, {24.0f * modified_bi + 10, 56.0f * modified_bi + 11},
442 {branch_index_arg.get()}, error_spec_);
443 }
444 }
445
446 INSTANTIATE_TEST_SUITE_P(CaseOpTest_Instantiation, CaseOpTest,
447 ::testing::Values(1, 2, 3, 4, 5));
448
449 // Test true and false computations that take in 2 array parameters and
450 // predicate is false.
XLA_TEST_F(ConditionalOpTest,Parameters2ArrayFalseBranch)451 XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) {
452 XlaBuilder builder(TestName());
453 XlaOp pred;
454 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
455 auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
456 auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
457 auto operands = Tuple(&builder, {operand1, operand2});
458 Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
459 CreateR1TupleSubComputation());
460
461 ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {pred_arg.get()},
462 error_spec_);
463 }
464
465 // Test true and false computations that return a tuple of scalars.
XLA_TEST_F(ConditionalOpTest,ReturnTupleOfScalars)466 XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
467 XlaBuilder builder(TestName());
468 XlaOp pred;
469 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
470 auto operands = Tuple(&builder, {ConstantR0<float>(&builder, 12.2f),
471 ConstantR0<float>(&builder, 25.6f)});
472 Conditional(pred, operands, CreateR0TupleCeilComputation(), operands,
473 CreateR0TupleFloorComputation());
474
475 ComputeAndCompareTuple(
476 &builder,
477 LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<float>(12.0f),
478 LiteralUtil::CreateR0<float>(25.0f)}),
479 {pred_arg.get()}, error_spec_);
480 }
481
482 // Test true and false computations that return a tuple of arrays.
XLA_TEST_F(ConditionalOpTest,ReturnTupleOfArrays)483 XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
484 XlaBuilder builder(TestName());
485 XlaOp pred;
486 auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
487 auto operands =
488 Tuple(&builder, {ConstantR1<float>(&builder, {12.2f, 15.8f}),
489 ConstantR1<float>(&builder, {25.6f, 29.2f})});
490 Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
491 CreateR1TupleFloorComputation());
492
493 ComputeAndCompareTuple(&builder,
494 LiteralUtil::MakeTupleFromSlices(
495 {LiteralUtil::CreateR1<float>({13.0f, 16.0f}),
496 LiteralUtil::CreateR1<float>({26.0f, 30.0f})}),
497 {pred_arg.get()}, error_spec_);
498 }
499
500 // Test true and false computations that return a tuple of a predicate, a
501 // scalar, and an array.
XLA_TEST_F(ConditionalOpTest,ReturnTupleofPredicateScalarArray)502 XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
503 XlaBuilder true_builder(TestName() + ".true");
504 {
505 Parameter(&true_builder, 0, empty_tuple_, "tuple");
506 auto true_pred = ConstantR0<bool>(&true_builder, true);
507 auto true_scalar = ConstantR0<float>(&true_builder, 12.2f);
508 auto true_array = ConstantR1<float>(&true_builder, {12.8f, 14.6f});
509 Tuple(&true_builder, {true_pred, true_scalar, true_array});
510 }
511 auto true_builder_result = true_builder.Build();
512 EXPECT_IS_OK(true_builder_result.status());
513
514 XlaBuilder false_builder(TestName() + ".false");
515 {
516 Parameter(&false_builder, 0, empty_tuple_, "tuple");
517 auto false_pred = ConstantR0<bool>(&false_builder, false);
518 auto false_scalar = ConstantR0<float>(&false_builder, 25.6f);
519 auto false_array = ConstantR1<float>(&false_builder, {26.4f, 32.6f});
520 Tuple(&false_builder, {false_pred, false_scalar, false_array});
521 }
522 auto false_builder_result = false_builder.Build();
523 EXPECT_IS_OK(false_builder_result.status());
524
525 XlaBuilder builder(TestName());
526 XlaOp pred;
527 auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
528 auto operands = Tuple(&builder, {});
529 Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
530 false_builder_result.ConsumeValueOrDie());
531
532 ComputeAndCompareTuple(&builder,
533 LiteralUtil::MakeTupleFromSlices(
534 {LiteralUtil::CreateR0<bool>(true),
535 LiteralUtil::CreateR0<float>(12.2f),
536 LiteralUtil::CreateR1<float>({12.8f, 14.6f})}),
537 {pred_arg.get()}, error_spec_);
538 }
539
540 // Test true and false computations that return a nested tuple.
XLA_TEST_F(ConditionalOpTest,ReturnNestedTuple)541 XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
542 XlaBuilder true_builder(TestName() + ".true");
543 {
544 Parameter(&true_builder, 0, empty_tuple_, "tuple");
545 auto true_constant1 = ConstantR0<float>(&true_builder, 12.2f);
546 auto true_constant2 = ConstantR1<float>(&true_builder, {12.8f, 14.6f});
547 auto true_constant3 = ConstantR1<float>(&true_builder, {25.4f, 29.8f});
548 auto true_constant4 = ConstantR0<float>(&true_builder, 35.6f);
549 Tuple(&true_builder,
550 {Tuple(&true_builder, {true_constant1, true_constant2}),
551 Tuple(&true_builder, {true_constant3, true_constant4})});
552 }
553 auto true_builder_result = true_builder.Build();
554 EXPECT_IS_OK(true_builder_result.status());
555
556 XlaBuilder false_builder(TestName() + ".false");
557 {
558 Parameter(&false_builder, 0, empty_tuple_, "tuple");
559 auto false_constant1 = ConstantR0<float>(&false_builder, 46.6f);
560 auto false_constant2 = ConstantR1<float>(&false_builder, {54.4f, 58.4f});
561 auto false_constant3 = ConstantR1<float>(&false_builder, {62.1f, 67.4f});
562 auto false_constant4 = ConstantR0<float>(&false_builder, 9.3f);
563 Tuple(&false_builder,
564 {Tuple(&false_builder, {false_constant1, false_constant2}),
565 Tuple(&false_builder, {false_constant3, false_constant4})});
566 }
567 auto false_builder_result = false_builder.Build();
568 EXPECT_IS_OK(false_builder_result.status());
569
570 XlaBuilder builder(TestName());
571 XlaOp pred;
572 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
573 auto operands = Tuple(&builder, {});
574 Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
575 false_builder_result.ConsumeValueOrDie());
576
577 ComputeAndCompareTuple(
578 &builder,
579 LiteralUtil::MakeTupleFromSlices(
580 {LiteralUtil::MakeTupleFromSlices(
581 {LiteralUtil::CreateR0<float>(46.6f),
582 LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
583 LiteralUtil::MakeTupleFromSlices(
584 {LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
585 LiteralUtil::CreateR0<float>(9.3f)})}),
586 {pred_arg.get()}, error_spec_);
587 }
588
589 // Test conditional that takes in scalar operands in the form of external
590 // params.
XLA_TEST_F(ConditionalOpTest,ScalarOperandsFromExternalParams)591 XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) {
592 Shape r0bool = ShapeUtil::MakeShape(PRED, {});
593 XlaBuilder builder(TestName());
594
595 XlaOp pred, operand1, operand2;
596 auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
597 auto operand1_param =
598 CreateR0Parameter<float>(56.3f, 1, "operand1", &builder, &operand1);
599 auto operand2_param =
600 CreateR0Parameter<float>(12.7f, 2, "operand2", &builder, &operand2);
601 Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
602 CreateR0FloorComputation());
603
604 ComputeAndCompareR0<float>(
605 &builder, 57.0f,
606 {pred_arg.get(), operand1_param.get(), operand2_param.get()},
607 error_spec_);
608 }
609
610 // Test conditional that takes in array operands in the form of external params.
XLA_TEST_F(ConditionalOpTest,ArrayOperandsFromExternalParams)611 XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) {
612 Shape r0bool = ShapeUtil::MakeShape(PRED, {});
613 XlaBuilder builder(TestName());
614
615 XlaOp pred, operand1, operand2;
616 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
617 auto operand1_param = CreateR1Parameter<float>({24.3f, 56.7f}, 1, "operand1",
618 &builder, &operand1);
619 auto operand2_param = CreateR1Parameter<float>({10.2f, 11.6f}, 2, "operand2",
620 &builder, &operand2);
621 Conditional(pred, operand1, CreateR1CeilComputation(), operand2,
622 CreateR1FloorComputation());
623
624 ComputeAndCompareR1<float>(
625 &builder, {10.0f, 11.0f},
626 {pred_arg.get(), operand1_param.get(), operand2_param.get()},
627 error_spec_);
628 }
629
630 // Test the case where one conditional is nested within another.
XLA_TEST_F(ConditionalOpTest,NestedConditionals)631 XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
632 XlaBuilder inner_builder(TestName() + ".inner_conditional");
633 {
634 Shape r0bool = ShapeUtil::MakeShape(PRED, {});
635 Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
636 auto param0 = Parameter(&inner_builder, 0, tuple_shape, "param0");
637 auto pred_cond = GetTupleElement(param0, 0);
638 auto true_operand = GetTupleElement(param0, 1);
639 auto false_operand = GetTupleElement(param0, 2);
640 Conditional(pred_cond, true_operand, CreateR0CeilComputation(),
641 false_operand, CreateR0FloorComputation());
642 }
643 auto inner_builder_result = inner_builder.Build();
644 EXPECT_IS_OK(inner_builder_result.status());
645
646 XlaBuilder builder(TestName());
647 XlaOp pred1, pred2;
648 auto pred1_arg = CreateR0Parameter<bool>(true, 0, "pred1", &builder, &pred1);
649 auto pred2_arg = CreateR0Parameter<bool>(false, 1, "pred2", &builder, &pred2);
650 auto operand1 = ConstantR0<float>(&builder, 1.1f);
651 auto operand2 = ConstantR0<float>(&builder, 12.2f);
652 auto operand3 = ConstantR0<float>(&builder, 43.3f);
653 auto tuple_operand = Tuple(&builder, {pred2, operand1, operand2});
654 Conditional(pred1, tuple_operand, inner_builder_result.ConsumeValueOrDie(),
655 operand3, CreateR0IdentityComputation());
656
657 ComputeAndCompareR0<float>(&builder, 12.0f,
658 {pred1_arg.get(), pred2_arg.get()}, error_spec_);
659 }
660
// Test a conditional that lives inside a computation invoked through Call. The
// predicate is false, so the result is floor(12.2) = 12.
XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
  XlaBuilder inner_builder(TestName() + ".inner_conditional");
  {
    const Shape pred_shape = ShapeUtil::MakeShape(PRED, {});
    const Shape packed_shape =
        ShapeUtil::MakeTupleShape({pred_shape, r0f32_, r0f32_});
    auto packed = Parameter(&inner_builder, 0, packed_shape, "param0");
    auto inner_pred = GetTupleElement(packed, 0);
    auto inner_true_operand = GetTupleElement(packed, 1);
    auto inner_false_operand = GetTupleElement(packed, 2);
    Conditional(inner_pred, inner_true_operand, CreateR0CeilComputation(),
                inner_false_operand, CreateR0FloorComputation());
  }
  auto inner_result = inner_builder.Build();
  EXPECT_IS_OK(inner_result.status());

  XlaBuilder builder(TestName());
  XlaOp predicate;
  auto predicate_arg =
      CreateR0Parameter<bool>(false, 0, "pred", &builder, &predicate);
  auto small = ConstantR0<float>(&builder, 1.1f);
  auto medium = ConstantR0<float>(&builder, 12.2f);
  auto packed_operand = Tuple(&builder, {predicate, small, medium});
  Call(&builder, inner_result.ConsumeValueOrDie(), {packed_operand});

  ComputeAndCompareR0<float>(&builder, 12.0f, {predicate_arg.get()},
                             error_spec_);
}
686
687 // Test a mismatch in the shape of the true operand and true computation.
XLA_TEST_F(ConditionalOpTest,ShapeMismatch)688 XLA_TEST_F(ConditionalOpTest, ShapeMismatch) {
689 XlaBuilder builder(TestName());
690 auto pred = ConstantR0<bool>(&builder, true);
691 auto operand1 = ConstantR0<float>(&builder, 56.0f);
692 auto operand2 = ConstantR0<float>(&builder, 12.0f);
693 auto operands = Tuple(&builder, {operand1, operand2});
694 Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
695 CreateR0TupleSubComputation());
696
697 auto result = builder.Build();
698 EXPECT_FALSE(result.ok());
699 EXPECT_THAT(result.status().error_message(),
700 ::testing::HasSubstr("operand 0 must match the shape of the "
701 "only parameter of branch computation 0"));
702 }
703
// Test two back-to-back conditionals whose branches either swap or forward a
// pair. The first swaps unless x < y and the second swaps when x >= y (both
// predicates computed from the original inputs), so the swaps cancel and the
// output must equal the input pair.
XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
  Shape tuple_shape = ShapeUtil::MakeTupleShape({r0f32_, r0f32_});

  // Builds a computation over an (f32, f32) tuple that returns the two
  // elements either in their original order or swapped.
  auto build_pair_computation = [&](const string& suffix,
                                    const string& param_name, bool swap) {
    XlaBuilder pair_builder(TestName() + suffix);
    auto param0 = Parameter(&pair_builder, 0, tuple_shape, param_name);
    auto first = GetTupleElement(param0, 0);
    auto second = GetTupleElement(param0, 1);
    if (swap) {
      Tuple(&pair_builder, {second, first});
    } else {
      Tuple(&pair_builder, {first, second});
    }
    return pair_builder.Build().ConsumeValueOrDie();
  };
  const XlaComputation swapper =
      build_pair_computation(".swapper", "sp0", /*swap=*/true);
  const XlaComputation forwarder =
      build_pair_computation(".forwarder", "fp0", /*swap=*/false);

  XlaComputation main_computation;
  {
    XlaBuilder main_builder(TestName() + ".main");
    auto param0 = Parameter(&main_builder, 0, tuple_shape, "mp0");
    auto x = GetTupleElement(param0, 0);
    auto y = GetTupleElement(param0, 1);
    auto less_than = Lt(x, y);
    auto once = Conditional(less_than, param0, forwarder, param0, swapper);
    auto greater_equal = Ge(x, y);
    Conditional(greater_equal, once, swapper, once, forwarder);
    main_computation = main_builder.Build().ConsumeValueOrDie();
  }

  // Runs `main_computation` on (a, b) and expects (a, b) back.
  auto expect_round_trip = [&](float a, float b) {
    XlaBuilder builder(TestName());
    XlaOp x, y;
    auto x_arg = CreateR0Parameter<float>(a, 0, "x", &builder, &x);
    auto y_arg = CreateR0Parameter<float>(b, 1, "y", &builder, &y);
    auto input_pair = Tuple(&builder, {x, y});
    Call(&builder, main_computation, {input_pair});

    ComputeAndCompareTuple(
        &builder,
        LiteralUtil::MakeTupleFromSlices(
            {LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)}),
        {x_arg.get(), y_arg.get()}, error_spec_);
  };

  expect_round_trip(3.11f, 9.4f);
  expect_round_trip(11.24f, 5.55f);
}
755
756 // Test conditional that duplicates tuple elements in the then and else
757 // computations. This is a regression test for b/112550242.
XLA_TEST_F(ConditionalOpTest,DuplicateElementsConditional)758 XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
759 const Shape scalar = ShapeUtil::MakeShape(S32, {});
760 const Shape tuple2 = ShapeUtil::MakeTupleShape({scalar, scalar});
761 XlaComputation then_comp;
762 {
763 XlaBuilder builder(TestName() + ".then");
764 auto p = Parameter(&builder, 0, tuple2, "then.p");
765 auto e0 = GetTupleElement(p, 0);
766 auto e1 = GetTupleElement(p, 1);
767 Tuple(&builder, {e0, e1, e0});
768 then_comp = builder.Build().ConsumeValueOrDie();
769 }
770 XlaComputation else_comp;
771 {
772 XlaBuilder builder(TestName() + ".else");
773 auto p = Parameter(&builder, 0, tuple2, "else.p");
774 auto e0 = GetTupleElement(p, 0);
775 auto e1 = GetTupleElement(p, 1);
776 Tuple(&builder, {e0, e1, e1});
777 else_comp = builder.Build().ConsumeValueOrDie();
778 }
779
780 {
781 // Pred is true case.
782 std::vector<Literal> args;
783 args.push_back(
784 LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
785 LiteralUtil::CreateR0<int32>(-42)}));
786 args.push_back(LiteralUtil::CreateR0<bool>(true));
787 XlaBuilder builder(TestName() + ".main");
788 auto p = Parameter(&builder, 0, tuple2, "p0");
789 auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
790 Conditional(p_pred, p, then_comp, p, else_comp);
791 ComputeAndCompare(&builder, args);
792 }
793 {
794 // Pred is false case.
795 std::vector<Literal> args;
796 args.push_back(
797 LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
798 LiteralUtil::CreateR0<int32>(-42)}));
799 args.push_back(LiteralUtil::CreateR0<bool>(false));
800 XlaBuilder builder(TestName() + ".main");
801 auto p = Parameter(&builder, 0, tuple2, "p0");
802 auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
803 Conditional(p_pred, p, then_comp, p, else_comp);
804 ComputeAndCompare(&builder, args);
805 }
806 }
807
808 } // namespace
809 } // namespace xla
810