1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cmath>
17 #include <limits>
18 #include <memory>
19 #include <numeric>
20 #include <vector>
21
22 #include "absl/base/casts.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/array2d.h"
25 #include "tensorflow/compiler/xla/array3d.h"
26 #include "tensorflow/compiler/xla/array4d.h"
27 #include "tensorflow/compiler/xla/client/global_data.h"
28 #include "tensorflow/compiler/xla/client/local_client.h"
29 #include "tensorflow/compiler/xla/client/xla_builder.h"
30 #include "tensorflow/compiler/xla/layout_util.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/test.h"
34 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
35 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
36 #include "tensorflow/compiler/xla/tests/test_macros.h"
37 #include "tensorflow/compiler/xla/types.h"
38 #include "tensorflow/core/platform/types.h"
39
40 namespace xla {
41 namespace {
42
43 class ArrayElementwiseOpTest : public ClientLibraryTestBase {
44 public:
45 ErrorSpec error_spec_{0.0001, 0.0001};
46 ErrorSpec strict_error_spec_{3.6e-15, 3.6e-15};
47 };
48
49 class ArrayElementwiseOpTestParamCount
50 : public ArrayElementwiseOpTest,
51 public ::testing::WithParamInterface<int> {};
52
// Neg applied to an empty F32 vector is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) {
  XlaBuilder builder(TestName());
  Neg(ConstantR1<float>(&builder, {}));

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
60
// Neg applied to a small F32 vector flips the sign of every element.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
  const std::vector<float> input = {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f};
  const std::vector<float> negated = {2.5f, -3.14f, -2.25f, 10.0f, -6.0f};

  XlaBuilder builder(TestName());
  Neg(ConstantR1<float>(&builder, input));

  ComputeAndCompareR1<float>(&builder, negated, {}, error_spec_);
}
69
// Neg applied to an F64 vector; the result is checked through
// ComputeAndCompare with the strict tolerance (no explicit expected values).
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF64) {
  const std::vector<double> input = {-2.5, 3.14, 2.25, -10.0, 6.0};

  XlaBuilder builder(TestName());
  Neg(ConstantR1<double>(&builder, input));

  ComputeAndCompare(&builder, {}, strict_error_spec_);
}
77
// Neg applied to S32 values, including both extremes of the type.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();

  XlaBuilder builder(TestName());
  Neg(ConstantR1<int32>(&builder, {-1, 0, 1, 324, kMin, kMax}));

  // Negating the minimum int32 overflows and wraps back to the minimum. The
  // same computation is undefined behavior in C++, but XLA has not declared it
  // undefined, so it is expected to work.
  ComputeAndCompareR1<int32>(&builder, {1, 0, -1, -324, kMin, -kMax}, {});
}
93
// Neg applied to an empty C64 vector is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) {
  XlaBuilder builder(TestName());
  Neg(ConstantR1<complex64>(&builder, {}));

  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
101
// Neg applied to C64 values negates both the real and the imaginary part.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) {
  const std::vector<complex64> input = {
      {-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}};
  const std::vector<complex64> negated = {
      {2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}};

  XlaBuilder builder(TestName());
  Neg(ConstantR1<complex64>(&builder, input));

  ComputeAndCompareR1<complex64>(&builder, negated, {}, error_spec_);
}
112
// Neg applied to S64 values, including values that need more than 32 bits and
// the two most negative bit patterns. The stray debug LOG statement that used
// to sit between building the op and the comparison has been dropped; it did
// not contribute to the check.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) {
  XlaBuilder builder(TestName());
  auto a =
      ConstantR1<int64>(&builder, {
                                      -1,
                                      1,
                                      0,
                                      0x12345678,
                                      static_cast<int64>(0xffffffff12345678l),
                                      static_cast<int64>(0x8000000000000000LL),
                                      static_cast<int64>(0x8000000000000001LL),
                                  });
  Neg(a);

  // 0x8000000000000000 is the minimum int64; its negation wraps to itself.
  ComputeAndCompareR1<int64>(&builder,
                             {
                                 1,
                                 -1,
                                 0,
                                 -0x12345678,
                                 0xedcba988,
                                 static_cast<int64>(0x8000000000000000LL),
                                 -static_cast<int64>(0x8000000000000001LL),
                             },
                             {});
}
140
// IsFinite applied to an empty F32 vector is expected to give an empty PRED
// vector.
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) {
  XlaBuilder builder(TestName());
  IsFinite(ConstantR1<float>(&builder, {}));

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
148
// Pow on S32 operands, covering zero, negative bases and negative exponents.
XLA_TEST_F(ArrayElementwiseOpTest, IntPow) {
  const std::vector<int32> bases = {0, 1, 2, 3, 4, 5, -1, -2, 3, 5, 3, 1};
  const std::vector<int32> exponents = {0, 3, 3, 3, 3,    3,
                                        2, 3, 2, 10, -100, -2};
  const std::vector<int32> powers = {1, 1,  8, 27,      64, 125,
                                     1, -8, 9, 9765625, 0,  1};

  XlaBuilder builder(TestName());
  XlaOp base_op = ConstantR1<int32>(&builder, bases);
  XlaOp exponent_op = ConstantR1<int32>(&builder, exponents);
  Pow(base_op, exponent_op);

  ComputeAndCompareR1<int32>(&builder, powers, {});
}
161
// Pow on S64 operands with a result (2^62) that does not fit in 32 bits.
XLA_TEST_F(ArrayElementwiseOpTest, IntPowLarge) {
  XlaBuilder builder(TestName());
  XlaOp base_op = ConstantR1<int64>(&builder, {2});
  XlaOp exponent_op = ConstantR1<int64>(&builder, {62});
  Pow(base_op, exponent_op);

  const std::vector<int64> two_to_the_62 = {4611686018427387904};
  ComputeAndCompareR1<int64>(&builder, two_to_the_62, {});
}
172
173 // A non-canonical quiet NaN value.
174 static const float kNonCanonicalNaN = absl::bit_cast<float>(0x7FD01234);
175
// IsFinite on individual F32 scalars: both NaN encodings and both infinities
// are reported as non-finite, zero as finite. All checks reuse one builder.
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) {
  XlaBuilder builder(TestName());

  // Appends IsFinite(value) to the builder and compares the computed result.
  auto check_is_finite = [&](float value, bool expected_finite) {
    IsFinite(ConstantR0<float>(&builder, value));
    ComputeAndCompareR0<bool>(&builder, expected_finite, {});
  };

  check_is_finite(NAN, false);

  EXPECT_TRUE(std::isnan(kNonCanonicalNaN));
  check_is_finite(kNonCanonicalNaN, false);

  const float inf = std::numeric_limits<float>::infinity();
  check_is_finite(inf, false);
  check_is_finite(-inf, false);
  check_is_finite(0.0f, true);
}
195
// IsFinite on an F32 vector mixing NaNs, infinities and ordinary values.
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) {
  XlaBuilder builder(TestName());
  const float inf = std::numeric_limits<float>::infinity();
  EXPECT_TRUE(std::isnan(kNonCanonicalNaN));

  const std::vector<float> values = {NAN,   7.0f, kNonCanonicalNaN,
                                     -1.0f, inf,  -inf};
  IsFinite(ConstantR1<float>(&builder, values));

  ComputeAndCompareR1<bool>(&builder, {false, true, false, true, false, false},
                            {});
}
207
// Element-wise Add of two F32 constant vectors.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
  XlaOp rhs =
      ConstantR1<float>(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
  Add(lhs, rhs);

  const std::vector<float> sums = {97.5f, 6.27f, 5.0f, 0.5f, -993.0f};
  ComputeAndCompareR1<float>(&builder, sums, {}, error_spec_);
}
217
// Add of two empty F32 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<float>(&builder, {});
  XlaOp rhs = ConstantR1<float>(&builder, {});
  Add(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
226
// Element-wise Add of two C64 constant vectors.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<complex64>(
      &builder, {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}});
  XlaOp rhs = ConstantR1<complex64>(
      &builder, {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}});
  Add(lhs, rhs);

  const std::vector<complex64> sums = {
      {97.5f, 0.0f}, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}};
  ComputeAndCompareR1<complex64>(&builder, sums, {}, error_spec_);
}
239
// Add of two empty C64 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<complex64>(&builder, {});
  XlaOp rhs = ConstantR1<complex64>(&builder, {});
  Add(lhs, rhs);

  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
248
// Element-wise Add of two U64 parameter vectors. The operands include values
// around 2^32, 2^63 and 2^64 - 1, so several sums wrap around; the expected
// values are produced with the same (well-defined) unsigned arithmetic on the
// host.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
  XlaBuilder b(TestName());

  std::vector<uint64> lhs{0xFFFFFFFF,
                          static_cast<uint64>(-1),
                          0,
                          0,
                          0x7FFFFFFFFFFFFFFFLL,
                          0x7FFFFFFFFFFFFFFLL,
                          0x8000000000000000ULL,
                          0x8000000000000000ULL,
                          1};
  Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  std::vector<uint64> rhs{1,
                          0x7FFFFFFFFFFFFFFLL,
                          0x7FFFFFFFFFFFFFFFLL,
                          0x8000000000000000ULL,
                          0,
                          static_cast<uint64>(-1),
                          0,
                          1,
                          0x8000000000000000ULL};
  Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  Add(lhs_param, rhs_param);

  // Index with size_t so the loop bound comparison is unsigned on both sides.
  std::vector<uint64> expected(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) {
    expected[i] = lhs[i] + rhs[i];
  }

  ComputeAndCompareR1<uint64>(&b, expected, {lhs_data.get(), rhs_data.get()});
}
289
// Element-wise Sub of two S64 parameter vectors using operands at and near
// the extremes of int64. The expected values are computed on the host with
// the same subtraction.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
  XlaBuilder b(TestName());

  std::vector<int64> lhs{static_cast<int64>(0x8000000000000000LL),
                         static_cast<int64>(0x8000000000000000LL),
                         -1,
                         0x7FFFFFFFFFFFFFFLL,
                         0x7FFFFFFFFFFFFFFFLL,
                         1,
                         0,
                         -1};
  Literal lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  std::vector<int64> rhs{-1,
                         0,
                         static_cast<int64>(0x8000000000000000LL),
                         1,
                         0,
                         0x7FFFFFFFFFFFFFFLL,
                         0x7FFFFFFFFFFFFFFFLL,
                         0x7FFFFFFFFFFFFFFFLL};
  Literal rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  Sub(lhs_param, rhs_param);

  // Index with size_t so the loop bound comparison is unsigned on both sides.
  std::vector<int64> expected(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) {
    expected[i] = lhs[i] - rhs[i];
  }

  ComputeAndCompareR1<int64>(&b, expected, {lhs_data.get(), rhs_data.get()});
}
328
// Lt on U64 parameters where the left value has the top bit set and the right
// value does not. The result is checked through ComputeAndCompare.
XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) {
  XlaBuilder builder(TestName());

  const std::vector<uint64> lhs_values{
      static_cast<uint64>(0x8000000000000000ULL)};
  Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs_values});
  XlaOp lhs_param = Parameter(&builder, 0, lhs_literal.shape(), "lhs_param");

  const std::vector<uint64> rhs_values{
      static_cast<uint64>(0x7FFFFFFFFFFFFFFFULL)};
  Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs_values});
  XlaOp rhs_param = Parameter(&builder, 1, rhs_literal.shape(), "rhs_param");

  Lt(lhs_param, rhs_param);

  ComputeAndCompare(&builder,
                    {std::move(lhs_literal), std::move(rhs_literal)});
}
344
// Adds two F32 vectors of GetParam() elements in every constant/parameter
// combination (constant+param, constant+constant, param+param,
// param+constant) and sums the four results, so the expected value of each
// element is 4 * (a + b).
TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
  const int count = GetParam();
  XlaBuilder builder(TestName());
  std::vector<float> a_values;
  std::vector<float> b_values;
  for (int i = 0; i < count; ++i) {
    a_values.push_back(i / static_cast<float>(count));
    b_values.push_back(2 * i / static_cast<float>(count + 2));
  }

  Literal a_literal = LiteralUtil::CreateR1<float>({a_values});
  std::unique_ptr<GlobalData> a_data =
      client_->TransferToServer(a_literal).ConsumeValueOrDie();
  auto a_constant = ConstantR1<float>(&builder, a_values);
  auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param");

  Literal b_literal = LiteralUtil::CreateR1<float>({b_values});
  std::unique_ptr<GlobalData> b_data =
      client_->TransferToServer(b_literal).ConsumeValueOrDie();
  // Declare the parameter with its own literal's shape. Both literals hold
  // `count` floats, so the shapes agree, but b should not depend on a.
  auto b_param = Parameter(&builder, 1, b_literal.shape(), "b_param");
  auto b_constant = ConstantR1<float>(&builder, b_values);

  auto sum1 = Add(a_constant, b_param);
  auto sum2 = Add(a_constant, b_constant);
  auto sum3 = Add(a_param, b_param);
  auto sum4 = Add(a_param, b_constant);

  auto sum = Add(sum1, sum2);
  sum = Add(sum, sum3);
  sum = Add(sum, sum4);

  std::vector<float> expected;
  for (int64_t i = 0; i < count; ++i) {
    expected.push_back(4 * (a_values[i] + b_values[i]));
  }

  ComputeAndCompareR1<float>(&builder, expected, {a_data.get(), b_data.get()},
                             error_spec_);
}
384
// Builds a deep chain of add/slice "diamonds" on top of Add(a, b) for two
// 30-element all-zero parameters and checks that the final single element is
// zero.
XLA_TEST_F(ArrayElementwiseOpTest, DeeplyNestedAddWithSlices) {
  XlaBuilder builder(TestName());
  std::vector<float> values(30, 0.0);
  auto a_literal = LiteralUtil::CreateR1<float>(values);
  auto a = Parameter(&builder, 0, a_literal.shape(), "x");
  auto b_literal = LiteralUtil::CreateR1<float>(values);
  // Give the second parameter its own name instead of reusing "x".
  auto b = Parameter(&builder, 1, b_literal.shape(), "y");

  // Construct a sequence of diamond-shaped gadgets like this:
  //
  //      add
  //    /    \
  //  slice  slice
  //     \   /
  //      add
  //
  // Each 'left' slice removes the last element, each 'right' slice removes the
  // first element.  In this way, we index into the add with different
  // multi-dimensional index arrays, which defeats the caching we use to avoid
  // exponential compile time.
  const int64_t full_size = static_cast<int64_t>(values.size());
  std::function<XlaOp(int64_t)> generate_recursive =
      [&](int64_t slice_size) -> XlaOp {
    // Base of the recursion: the innermost op adds the two full parameters.
    if (slice_size == full_size) {
      return Add(a, b);
    }
    XlaOp param = generate_recursive(slice_size + 1);
    auto slice1 = Slice(param, {0}, {slice_size}, {1});
    auto slice2 = Slice(param, {1}, {slice_size + 1}, {1});
    return Add(slice1, slice2);
  };
  generate_recursive(1);
  auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie();
  auto b_data = client_->TransferToServer(b_literal).ConsumeValueOrDie();
  ComputeAndCompareR1<float>(&builder, {0.0}, {a_data.get(), b_data.get()});
}
420
// Element-wise Sub of two F32 constant vectors.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
  XlaOp rhs =
      ConstantR1<float>(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
  Sub(lhs, rhs);

  const std::vector<float> differences = {-102.5f, 0.01f, -0.5f, -20.5f,
                                          1005.0f};
  ComputeAndCompareR1<float>(&builder, differences, {}, error_spec_);
}
430
// Sub of two empty F32 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<float>(&builder, {});
  XlaOp rhs = ConstantR1<float>(&builder, {});
  Sub(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
439
// Element-wise Sub of two S32 constant vectors.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<int32>(&builder, {-1, 0, 2, 1000000000});
  XlaOp rhs = ConstantR1<int32>(&builder, {-1, 2, 1, -1});
  Sub(lhs, rhs);

  const std::vector<int32> differences = {0, -2, 1, 1000000001};
  ComputeAndCompareR1<int32>(&builder, differences, {});
}
448
// Sub of two empty S32 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<int32>(&builder, {});
  XlaOp rhs = ConstantR1<int32>(&builder, {});
  Sub(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, {}, {});
}
457
// Element-wise Sub of two C64 constant vectors.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<complex64>(
      &builder, {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}});
  XlaOp rhs = ConstantR1<complex64>(
      &builder, {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}});
  Sub(lhs, rhs);

  const std::vector<complex64> differences = {
      {-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}};
  ComputeAndCompareR1<complex64>(&builder, differences, {}, error_spec_);
}
470
// Sub of two empty C64 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<complex64>(&builder, {});
  XlaOp rhs = ConstantR1<complex64>(&builder, {});
  Sub(lhs, rhs);

  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
479
// Element-wise Sub of two F64 constant vectors; the result is checked through
// ComputeAndCompare with the strict tolerance.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF64s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<double>(&builder, {-2.5, 3.14, 2.25, -10.0, 6.0});
  XlaOp rhs = ConstantR1<double>(&builder, {100.0, 3.13, 2.75, 10.5, -999.0});
  Sub(lhs, rhs);

  ComputeAndCompare(&builder, {}, strict_error_spec_);
}
488
// Element-wise Div of two F32 constant vectors.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
  XlaBuilder builder(TestName());
  XlaOp numerators =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  XlaOp denominators =
      ConstantR1<float>(&builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f});
  Div(numerators, denominators);

  const std::vector<float> quotients = {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f};
  ComputeAndCompareR1<float>(&builder, quotients, {}, error_spec_);
}
498
// Div of two empty F32 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  XlaOp numerators = ConstantR1<float>(&builder, {});
  XlaOp denominators = ConstantR1<float>(&builder, {});
  Div(numerators, denominators);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
507
// Element-wise Div of two F64 constant vectors; the result is checked through
// ComputeAndCompare with the strict tolerance.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF64s) {
  const std::vector<double> numerators = {-2.5, 25.5, 2.25, -10.0, 6.0,  1.0,
                                          2.0,  3.2,  -4.0, 0.45,  5.7,  0.1,
                                          1.0,  2.0,  0.5,  -1.0,  -0.5, 1.0};
  const std::vector<double> denominators = {
      10.0, 5.1, 1.0, 10.0, -6.0, 0.1,   1.0,   2.0, 0.5,
      -1.0, -0.5, 2.1, 3.1, 9.9,  -4.5, -11.0, -21.5, M_PI};

  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<double>(&builder, numerators);
  XlaOp rhs = ConstantR1<double>(&builder, denominators);
  Div(lhs, rhs);

  ComputeAndCompare(&builder, {}, strict_error_spec_);
}
520
521 class IntegerDivideOpTest : public ArrayElementwiseOpTest {
522 protected:
523 template <typename T>
TestDivRem(absl::Span<const T> dividends,absl::Span<const T> divisors,absl::Span<const T> quotients,absl::Span<const T> remainders)524 void TestDivRem(absl::Span<const T> dividends, absl::Span<const T> divisors,
525 absl::Span<const T> quotients,
526 absl::Span<const T> remainders) {
527 {
528 XlaBuilder builder(TestName());
529 XlaOp dividend;
530 XlaOp divisor;
531 auto dividend_data =
532 CreateR1Parameter<T>(dividends, 0, "dividend", &builder, ÷nd);
533 auto divisor_data =
534 CreateR1Parameter<T>(divisors, 1, "divisor", &builder, &divisor);
535 Div(dividend, divisor);
536
537 ComputeAndCompareR1<T>(&builder, quotients,
538 {dividend_data.get(), divisor_data.get()});
539 }
540
541 // Test with a compile-time constant divisor.
542 {
543 XlaBuilder builder(TestName());
544 XlaOp dividend;
545 auto dividend_data =
546 CreateR1Parameter<T>(dividends, 0, "dividend", &builder, ÷nd);
547 Div(dividend, ConstantR1<T>(&builder, divisors));
548
549 ComputeAndCompareR1<T>(&builder, quotients, {dividend_data.get()});
550 }
551
552 {
553 XlaBuilder builder(TestName());
554 XlaOp dividend;
555 XlaOp divisor;
556 auto dividend_data =
557 CreateR1Parameter<T>(dividends, 0, "dividend", &builder, ÷nd);
558 auto divisor_data =
559 CreateR1Parameter<T>(divisors, 1, "divisor", &builder, &divisor);
560 Rem(dividend, divisor);
561
562 ComputeAndCompareR1<T>(&builder, remainders,
563 {dividend_data.get(), divisor_data.get()});
564 }
565
566 // Test with a compile-time constant divisor.
567 {
568 XlaBuilder builder(TestName());
569 XlaOp dividend;
570 auto dividend_data =
571 CreateR1Parameter<T>(dividends, 0, "dividend", &builder, ÷nd);
572 Rem(dividend, ConstantR1<T>(&builder, divisors));
573
574 ComputeAndCompareR1<T>(&builder, remainders, {dividend_data.get()});
575 }
576 }
577 };
578
// Div/Rem over the cross product of a set of interesting S32 values, skipping
// zero divisors and the one overflowing pair (INT32_MIN / -1). Expected
// results are computed with the host's integer division.
XLA_TEST_F(IntegerDivideOpTest, DivS32s) {
  // clang-format off
  // Some interesting values to test.
  const std::vector<int32> vals = {
    INT32_MIN, INT32_MIN + 1, INT32_MIN + 2, -0x40000000, -0x3fffffff,
    -271181, -1309, -17, -10, -5, -3, -2, -1, 0, 1, 2, 3, 5, 10, 17, 26, 101,
    7919, 0x40000000, INT32_MAX - 2, INT32_MAX - 1, INT32_MAX};
  // clang-format on

  std::vector<int32> dividends;
  std::vector<int32> divisors;
  std::vector<int32> quotients;
  std::vector<int32> remainders;
  for (int32_t divisor : vals) {
    if (divisor == 0) {
      continue;
    }
    for (int32_t dividend : vals) {
      // Avoid integer overflow.
      if (dividend == INT32_MIN && divisor == -1) {
        continue;
      }
      dividends.push_back(dividend);
      divisors.push_back(divisor);
      quotients.push_back(dividend / divisor);
      remainders.push_back(dividend % divisor);
    }
  }

  TestDivRem<int32>(dividends, divisors, quotients, remainders);
}
605
// The two S32 cases excluded from DivS32s: division by zero and
// INT32_MIN / -1, with the results this test expects for each.
XLA_TEST_F(IntegerDivideOpTest, SignedOverflow) {
  const std::vector<int32> dividends = {5, INT32_MIN};
  const std::vector<int32> divisors = {0, -1};
  const std::vector<int32> quotients = {-1, INT32_MIN};
  const std::vector<int32> remainders = {5, 0};

  TestDivRem<int32>(dividends, divisors, quotients, remainders);
}
612
// Div/Rem over the cross product of a set of interesting U32 values, skipping
// zero divisors. Expected results are computed with the host's integer
// division.
XLA_TEST_F(IntegerDivideOpTest, DivU32s) {
  // clang-format off
  // Some interesting values to test.
  const std::vector<uint32> vals = {
    0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0xABCDEF12, 0xCAFEBEEF, 0x80000000,
    0x80000001, UINT32_MAX - 2, UINT32_MAX - 1, UINT32_MAX};
  // clang-format on

  std::vector<uint32> dividends;
  std::vector<uint32> divisors;
  std::vector<uint32> quotients;
  std::vector<uint32> remainders;
  for (uint32 divisor : vals) {
    if (divisor == 0) {
      continue;
    }
    for (uint32 dividend : vals) {
      dividends.push_back(dividend);
      divisors.push_back(divisor);
      quotients.push_back(dividend / divisor);
      remainders.push_back(dividend % divisor);
    }
  }

  TestDivRem<uint32>(dividends, divisors, quotients, remainders);
}
635
// Unsigned division by zero. This previously instantiated TestDivRem<int32>,
// which repeated the first case of SignedOverflow and never exercised an
// unsigned type. The expected values carry over the same convention: the
// quotient is the all-ones pattern (-1 as int32, UINT32_MAX as uint32) and the
// remainder equals the dividend.
// NOTE(review): confirm every backend under test produces these values for
// U32 division by zero.
XLA_TEST_F(IntegerDivideOpTest, UnsignedOverflow) {
  std::vector<uint32> dividends = {5}, divisors = {0},
                      quotients = {UINT32_MAX}, remainders = {5};

  TestDivRem<uint32>(dividends, divisors, quotients, remainders);
}
642
// Element-wise Div of two C64 constant vectors.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) {
  XlaBuilder builder(TestName());
  XlaOp numerators = ConstantR1<complex64>(
      &builder, {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}});
  XlaOp denominators = ConstantR1<complex64>(
      &builder, {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}});
  Div(numerators, denominators);

  const std::vector<complex64> quotients = {
      {-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}};
  ComputeAndCompareR1<complex64>(&builder, quotients, {}, error_spec_);
}
654
// Div of two empty C64 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  XlaOp numerators = ConstantR1<complex64>(&builder, {});
  XlaOp denominators = ConstantR1<complex64>(&builder, {});
  Div(numerators, denominators);

  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
663
// Element-wise Rem of two F32 constant vectors with mixed signs.
XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) {
  XlaBuilder builder(TestName());
  XlaOp dividends = ConstantR1<float>(
      &builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f});
  XlaOp divisors = ConstantR1<float>(
      &builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f});
  Rem(dividends, divisors);

  const std::vector<float> remainders = {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f,
                                         1.0f,  1.0f, -1.0f, -0.0f};
  ComputeAndCompareR1<float>(&builder, remainders, {}, error_spec_);
}
676
// Rem of two empty F32 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) {
  XlaBuilder builder(TestName());
  XlaOp dividends = ConstantR1<float>(&builder, {});
  XlaOp divisors = ConstantR1<float>(&builder, {});
  Rem(dividends, divisors);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
685
// Element-wise Rem of two F64 constant vectors with mixed signs, compared
// with the strict tolerance.
XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) {
  XlaBuilder builder(TestName());
  XlaOp dividends = ConstantR1<double>(
      &builder, {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0});
  XlaOp divisors = ConstantR1<double>(
      &builder, {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0});
  Rem(dividends, divisors);

  const std::vector<double> remainders = {-2.5, 0.0, 0.25, 0.0, -0.0,
                                          1.0,  1.0, -1.0, -0.0};
  ComputeAndCompareR1<double>(&builder, remainders, {}, strict_error_spec_);
}
698
// Element-wise Mul of two F32 constant vectors.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  XlaOp rhs = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Mul(lhs, rhs);

  const std::vector<float> products = {-25.0f, 127.5f, 2.25f, -100.0f,
                                       -36.0f};
  ComputeAndCompareR1<float>(&builder, products, {}, error_spec_);
}
708
// Mul of two empty F32 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<float>(&builder, {});
  XlaOp rhs = ConstantR1<float>(&builder, {});
  Mul(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
717
// Mul over every ordered pair drawn from a set of S32 values that includes
// both extremes. Expected products are computed in uint32 on the host so the
// multiplication wraps instead of overflowing a signed type.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) {
  const std::vector<int32> data = {0,
                                   1,
                                   -1,
                                   1234,
                                   0x1a243514,
                                   std::numeric_limits<int32>::max(),
                                   std::numeric_limits<int32>::min()};

  // Form the test data set using all products of 'data' with itself.
  std::vector<int32> lhs_values;
  std::vector<int32> rhs_values;
  std::vector<int32> products;
  for (int32_t lhs_value : data) {
    for (int32_t rhs_value : data) {
      lhs_values.push_back(lhs_value);
      rhs_values.push_back(rhs_value);
      const uint32 wrapped =
          static_cast<uint32>(lhs_value) * static_cast<uint32>(rhs_value);
      products.push_back(wrapped);
    }
  }

  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<int32>(&builder, lhs_values);
  XlaOp rhs = ConstantR1<int32>(&builder, rhs_values);
  Mul(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, products, {});
}
743
// Mul of two empty S32 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<int32>(&builder, {});
  XlaOp rhs = ConstantR1<int32>(&builder, {});
  Mul(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, {}, {});
}
752
// Mul over every ordered pair drawn from a set of U32 values; expected
// products are computed with the same uint32 multiplication on the host.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
  const std::vector<uint32> data = {0,          1,          0xDEADBEEF, 1234,
                                    0x1a243514, 0xFFFFFFFF, 0x80808080};

  // Form the test data set using all products of 'data' with itself.
  std::vector<uint32> lhs_values;
  std::vector<uint32> rhs_values;
  std::vector<uint32> products;
  for (uint32 lhs_value : data) {
    for (uint32 rhs_value : data) {
      lhs_values.push_back(lhs_value);
      rhs_values.push_back(rhs_value);
      products.push_back(lhs_value * rhs_value);
    }
  }

  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<uint32>(&builder, lhs_values);
  XlaOp rhs = ConstantR1<uint32>(&builder, rhs_values);
  Mul(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, products, {});
}
774
// Element-wise Mul of two C64 constant vectors.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<complex64>(
      &builder, {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}});
  XlaOp rhs = ConstantR1<complex64>(
      &builder, {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}});
  Mul(lhs, rhs);

  const std::vector<complex64> products = {
      {0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0f}};
  ComputeAndCompareR1<complex64>(&builder, products, {}, error_spec_);
}
787
// Mul of two empty C64 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<complex64>(&builder, {});
  XlaOp rhs = ConstantR1<complex64>(&builder, {});
  Mul(lhs, rhs);

  ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
796
// Logical And over all four combinations of two PRED values (rank 1).
XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<bool>(&builder, {false, false, true, true});
  XlaOp rhs = ConstantR1<bool>(&builder, {false, true, false, true});
  And(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, {false, false, false, true}, {});
}
805
// Logical And over all four combinations of two PRED values (rank 2).
XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  XlaOp rhs = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  And(lhs, rhs);

  const Array2D<bool> conjunction({{false, false}, {false, true}});
  ComputeAndCompareR2<bool>(&builder, conjunction, {});
}
815
// And of two empty PRED vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<bool>(&builder, {});
  XlaOp rhs = ConstantR1<bool>(&builder, {});
  And(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
824
// Bitwise And of two S32 vectors containing negative values.
XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<int32>(&builder, {0, -1, -8});
  XlaOp rhs = ConstantR1<int32>(&builder, {5, -7, 12});
  And(lhs, rhs);

  const std::vector<int32> masked = {0, -7, 8};
  ComputeAndCompareR1<int32>(&builder, masked, {});
}
833
// Bitwise And of two 2x2 S32 arrays containing negative values.
XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR2<int32>(&builder, {{0, -5}, {-1, 5}});
  XlaOp rhs = ConstantR2<int32>(&builder, {{1, -6}, {4, 5}});
  And(lhs, rhs);

  const Array2D<int32> masked({{0, -6}, {4, 5}});
  ComputeAndCompareR2<int32>(&builder, masked, {});
}
843
// And of two empty S32 vectors is expected to give an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<int32>(&builder, {});
  XlaOp rhs = ConstantR1<int32>(&builder, {});
  And(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, {}, {});
}
852
// Bitwise And of two U32 vectors. The operands and the expected result are
// now uint32, matching the test name and the sibling AndU32R2 /
// AndZeroElementU32R1 tests; previously everything here was int32, so no
// rank-1 U32 And was actually exercised.
XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<uint32>(&builder, {0, 1, 8});
  auto b = ConstantR1<uint32>(&builder, {5, 7, 12});
  And(a, b);

  // 0 & 5 == 0, 1 & 7 == 1, 8 & 12 == 8.
  ComputeAndCompareR1<uint32>(&builder, {0, 1, 8}, {});
}
861
// Bitwise And of two 2x2 U32 arrays.
XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR2<uint32>(&builder, {{0, 1}, {3, 8}});
  XlaOp rhs = ConstantR2<uint32>(&builder, {{1, 0}, {7, 6}});
  And(lhs, rhs);

  const Array2D<uint32> masked({{0, 0}, {3, 0}});
  ComputeAndCompareR2<uint32>(&builder, masked, {});
}
871
// And of two empty rank-1 uint32 constants is expected to yield an empty
// result.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) {
  XlaBuilder builder(TestName());
  XlaOp empty_lhs = ConstantR1<uint32>(&builder, {});
  XlaOp empty_rhs = ConstantR1<uint32>(&builder, {});
  And(empty_lhs, empty_rhs);

  ComputeAndCompareR1<uint32>(&builder, {}, {});
}
880
// Logical Or over all four combinations of bool inputs (rank 1).
XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<bool>(&builder, {false, false, true, true});
  XlaOp rhs = ConstantR1<bool>(&builder, {false, true, false, true});
  Or(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, true}, {});
}
889
// Logical Or over all four combinations of bool inputs, laid out as 2x2.
XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  XlaOp rhs = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  Or(lhs, rhs);

  Array2D<bool> want({{false, true}, {true, true}});
  ComputeAndCompareR2<bool>(&builder, want, {});
}
899
// Or of two empty rank-1 bool constants is expected to yield an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) {
  XlaBuilder builder(TestName());
  XlaOp empty_lhs = ConstantR1<bool>(&builder, {});
  XlaOp empty_rhs = ConstantR1<bool>(&builder, {});
  Or(empty_lhs, empty_rhs);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
908
// Bitwise Or of rank-1 int32 constants, including negative operands.
XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<int32>(&builder, {0, -1, 8});
  XlaOp rhs = ConstantR1<int32>(&builder, {5, -7, 4});
  Or(lhs, rhs);

  // 0 | 5 == 5, -1 | -7 == -1, 8 | 4 == 12.
  ComputeAndCompareR1<int32>(&builder, {5, -1, 12}, {});
}
917
// Bitwise Or of 2x2 int32 constants, including negative operands.
XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR2<int32>(&builder, {{0, -1}, {8, 8}});
  XlaOp rhs = ConstantR2<int32>(&builder, {{5, -7}, {4, 1}});
  Or(lhs, rhs);

  Array2D<int32> want({{5, -1}, {12, 9}});
  ComputeAndCompareR2<int32>(&builder, want, {});
}
927
// Or of two empty rank-1 int32 constants is expected to yield an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) {
  XlaBuilder builder(TestName());
  XlaOp empty_lhs = ConstantR1<int32>(&builder, {});
  XlaOp empty_rhs = ConstantR1<int32>(&builder, {});
  Or(empty_lhs, empty_rhs);

  ComputeAndCompareR1<int32>(&builder, {}, {});
}
936
// Bitwise Or of rank-1 uint32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<uint32>(&builder, {0, 1, 8});
  XlaOp rhs = ConstantR1<uint32>(&builder, {5, 7, 4});
  Or(lhs, rhs);

  // 0 | 5 == 5, 1 | 7 == 7, 8 | 4 == 12.
  ComputeAndCompareR1<uint32>(&builder, {5, 7, 12}, {});
}
945
// Bitwise Or of 2x2 uint32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR2<uint32>(&builder, {{0, 1}, {8, 8}});
  XlaOp rhs = ConstantR2<uint32>(&builder, {{5, 7}, {4, 1}});
  Or(lhs, rhs);

  Array2D<uint32> want({{5, 7}, {12, 9}});
  ComputeAndCompareR2<uint32>(&builder, want, {});
}
955
// Or of two empty rank-1 uint32 constants is expected to yield an empty
// result.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) {
  XlaBuilder builder(TestName());
  XlaOp empty_lhs = ConstantR1<uint32>(&builder, {});
  XlaOp empty_rhs = ConstantR1<uint32>(&builder, {});
  Or(empty_lhs, empty_rhs);

  ComputeAndCompareR1<uint32>(&builder, {}, {});
}
964
// Logical Xor over all four combinations of bool inputs (rank 1).
XLA_TEST_F(ArrayElementwiseOpTest, XorPredR1) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<bool>(&builder, {false, false, true, true});
  XlaOp rhs = ConstantR1<bool>(&builder, {false, true, false, true});
  Xor(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, false}, {});
}
973
// Logical Xor over all four combinations of bool inputs, laid out as 2x2.
XLA_TEST_F(ArrayElementwiseOpTest, XorPredR2) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  XlaOp rhs = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  Xor(lhs, rhs);

  Array2D<bool> want({{false, true}, {true, false}});
  ComputeAndCompareR2<bool>(&builder, want, {});
}
983
// Xor of two empty rank-1 bool constants is expected to yield an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementPredR1) {
  XlaBuilder builder(TestName());
  XlaOp empty_lhs = ConstantR1<bool>(&builder, {});
  XlaOp empty_rhs = ConstantR1<bool>(&builder, {});
  Xor(empty_lhs, empty_rhs);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
992
// Bitwise Xor of rank-1 int32 constants, including negative operands.
XLA_TEST_F(ArrayElementwiseOpTest, XorS32R1) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<int32>(&builder, {0, -1, 8});
  XlaOp rhs = ConstantR1<int32>(&builder, {5, -7, 4});
  Xor(lhs, rhs);

  // 0 ^ 5 == 5, -1 ^ -7 == 6, 8 ^ 4 == 12.
  ComputeAndCompareR1<int32>(&builder, {5, 6, 12}, {});
}
1001
// Bitwise Xor of 2x2 int32 constants, including negative operands.
XLA_TEST_F(ArrayElementwiseOpTest, XorS32R2) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR2<int32>(&builder, {{0, -1}, {8, 8}});
  XlaOp rhs = ConstantR2<int32>(&builder, {{5, -7}, {4, 1}});
  Xor(lhs, rhs);

  Array2D<int32> want({{5, 6}, {12, 9}});
  ComputeAndCompareR2<int32>(&builder, want, {});
}
1011
// Xor of two empty rank-1 int32 constants is expected to yield an empty
// result.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementS32R1) {
  XlaBuilder builder(TestName());
  XlaOp empty_lhs = ConstantR1<int32>(&builder, {});
  XlaOp empty_rhs = ConstantR1<int32>(&builder, {});
  Xor(empty_lhs, empty_rhs);

  ComputeAndCompareR1<int32>(&builder, {}, {});
}
1020
// Bitwise Xor of rank-1 uint32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, XorU32R1) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR1<uint32>(&builder, {0, 1, 8});
  XlaOp rhs = ConstantR1<uint32>(&builder, {5, 7, 4});
  Xor(lhs, rhs);

  // 0 ^ 5 == 5, 1 ^ 7 == 6, 8 ^ 4 == 12.
  ComputeAndCompareR1<uint32>(&builder, {5, 6, 12}, {});
}
1029
// Bitwise Xor of 2x2 uint32 constants.
XLA_TEST_F(ArrayElementwiseOpTest, XorU32R2) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR2<uint32>(&builder, {{0, 1}, {8, 8}});
  XlaOp rhs = ConstantR2<uint32>(&builder, {{5, 7}, {4, 1}});
  Xor(lhs, rhs);

  Array2D<uint32> want({{5, 6}, {12, 9}});
  ComputeAndCompareR2<uint32>(&builder, want, {});
}
1039
// Xor of two empty rank-1 uint32 constants is expected to yield an empty
// result.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementU32R1) {
  XlaBuilder builder(TestName());
  XlaOp empty_lhs = ConstantR1<uint32>(&builder, {});
  XlaOp empty_rhs = ConstantR1<uint32>(&builder, {});
  Xor(empty_lhs, empty_rhs);

  ComputeAndCompareR1<uint32>(&builder, {}, {});
}
// Logical Not of a rank-1 bool constant flips every element.
XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) {
  XlaBuilder builder(TestName());
  XlaOp operand = ConstantR1<bool>(&builder, {false, true, true, false});
  Not(operand);

  ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
}
1055
// Logical Not of a 2x2 bool constant flips every element.
XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) {
  XlaBuilder builder(TestName());
  XlaOp operand = ConstantR2<bool>(&builder, {{false, true}, {true, false}});
  Not(operand);

  Array2D<bool> want({{true, false}, {false, true}});
  ComputeAndCompareR2<bool>(&builder, want, {});
}
1064
// Not of an empty rank-1 bool constant is expected to yield an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) {
  XlaBuilder builder(TestName());
  XlaOp empty_operand = ConstantR1<bool>(&builder, {});
  Not(empty_operand);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
1072
// Bitwise Not of rank-1 int32 values: ~-1 == 0, ~0 == -1, ~1 == -2.
XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) {
  XlaBuilder builder(TestName());
  XlaOp operand = ConstantR1<int32>(&builder, {-1, 0, 1});
  Not(operand);

  ComputeAndCompareR1<int32>(&builder, {0, -1, -2}, {});
}
1080
// Bitwise Not of a 2x2 int32 constant.
XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) {
  XlaBuilder builder(TestName());
  XlaOp operand = ConstantR2<int32>(&builder, {{-1, 0}, {1, 8}});
  Not(operand);

  Array2D<int32> want({{0, -1}, {-2, -9}});
  ComputeAndCompareR2<int32>(&builder, want, {});
}
1089
// Not of an empty rank-1 int32 constant is expected to yield an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) {
  XlaBuilder builder(TestName());
  XlaOp empty_operand = ConstantR1<int32>(&builder, {});
  Not(empty_operand);

  ComputeAndCompareR1<int32>(&builder, {}, {});
}
1097
// Bitwise Not of rank-1 uint32 values at both ends of the range
// (0 <-> 4294967295).
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) {
  XlaBuilder builder(TestName());
  XlaOp operand = ConstantR1<uint32>(&builder, {0, 4294967295});
  Not(operand);

  ComputeAndCompareR1<uint32>(&builder, {4294967295, 0}, {});
}
1105
// Bitwise Not of a 2x2 uint32 constant holding values near both range ends.
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) {
  XlaBuilder builder(TestName());
  XlaOp operand =
      ConstantR2<uint32>(&builder, {{0, 4294967295}, {1, 4294967294}});
  Not(operand);

  Array2D<uint32> want({{4294967295, 0}, {4294967294, 1}});
  ComputeAndCompareR2<uint32>(&builder, want, {});
}
1114
// Not of an empty rank-1 uint32 constant is expected to yield an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
  XlaBuilder builder(TestName());
  XlaOp empty_operand = ConstantR1<uint32>(&builder, {});
  Not(empty_operand);

  ComputeAndCompareR1<uint32>(&builder, {}, {});
}
1122
// PopulationCount on rank-1 int32 values; -15 has 29 set bits in 32-bit
// two's complement and 341 (0b101010101) has 5.
XLA_TEST_F(ArrayElementwiseOpTest, PopcntR1) {
  XlaBuilder builder(TestName());
  XlaOp operand = ConstantR1<int32>(&builder, {0, 1, -15, 341});
  PopulationCount(operand);

  ComputeAndCompareR1<int32>(&builder, {0, 1, 29, 5}, {});
}
1129
// PopulationCount on a 2x2 int32 constant (same values as PopcntR1).
XLA_TEST_F(ArrayElementwiseOpTest, PopcntR2) {
  XlaBuilder builder(TestName());
  XlaOp operand = ConstantR2<int32>(&builder, {{0, 1}, {-15, 341}});
  PopulationCount(operand);

  Array2D<int32> want({{0, 1}, {29, 5}});
  ComputeAndCompareR2<int32>(&builder, want, {});
}
1137
// PopulationCount on 64-bit values: 0, all ones (-1), INT64_MAX (63 set bits)
// and INT64_MAX - 1 (62 set bits).
XLA_TEST_F(ArrayElementwiseOpTest, PopcntS64) {
  XlaBuilder builder(TestName());
  XlaOp operand =
      ConstantR2<int64>(&builder, {{0, -1}, {INT64_MAX, INT64_MAX - 1}});
  PopulationCount(operand);

  Array2D<int64> want({{0, 64}, {63, 62}});
  ComputeAndCompareR2<int64>(&builder, want, {});
}
1145
// ShiftLeft on int32 values. The last three shift amounts (32, 100, -1) fall
// outside [0, 31]; the expected result for each of them is 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
  XlaBuilder builder(TestName());
  XlaOp values = ConstantR1<int32>(
      &builder, {static_cast<int32>(0x12345678), static_cast<int32>(0xF0001000),
                 1, 3, 77, 1, -3, 77});
  XlaOp shift_amounts =
      ConstantR1<int32>(&builder, {4, 8, 2, 7, 15, 32, 100, -1});
  ShiftLeft(values, shift_amounts);

  ComputeAndCompareR1<int32>(
      &builder,
      {static_cast<int32>(0x23456780), 0x00100000, 0x4, 0x180, 2523136, 0, 0,
       0},
      {});
}
1159
// ShiftRightArithmetic on int32 values. For the out-of-range shift amounts
// (32, 100, -1) the expected result is 0 for the non-negative inputs and -1
// for the negative input (-3).
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) {
  XlaBuilder builder(TestName());
  XlaOp values = ConstantR1<int32>(
      &builder, {static_cast<int32>(0x92345678), static_cast<int32>(0x10001000),
                 1, 3, 77, 1, -3, 77});
  XlaOp shift_amounts =
      ConstantR1<int32>(&builder, {4, 8, 2, 7, 2, 32, 100, -1});
  ShiftRightArithmetic(values, shift_amounts);

  ComputeAndCompareR1<int32>(&builder,
                             {static_cast<int32>(0xF9234567),
                              static_cast<int32>(0x00100010), 0, 0, 19, 0, -1,
                              0},
                             {});
}
1174
// ShiftRightLogical on int32 values. The sign bit is not replicated
// (0x92345678 >> 4 is expected to be 0x09234567), and the out-of-range shift
// amounts (32, 100, -1) are expected to give 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) {
  XlaBuilder builder(TestName());
  XlaOp values = ConstantR1<int32>(
      &builder, {static_cast<int32>(0x92345678), static_cast<int32>(0x10001000),
                 1, 3, 77, 1, -3, 77});
  XlaOp shift_amounts =
      ConstantR1<int32>(&builder, {4, 8, 2, 7, 5, 32, 100, -1});
  ShiftRightLogical(values, shift_amounts);

  ComputeAndCompareR1<int32>(
      &builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {});
}
1186
// ShiftLeft on uint32 values. The last three shift amounts (32, 100, ~0u) are
// at least the bit width; the expected result for each of them is 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) {
  XlaBuilder builder(TestName());
  XlaOp values = ConstantR1<uint32>(
      &builder, {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77});
  XlaOp shift_amounts =
      ConstantR1<uint32>(&builder, {4, 8, 2, 7, 15, 32, 100, ~0u});
  ShiftLeft(values, shift_amounts);

  ComputeAndCompareR1<uint32>(
      &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136, 0, 0, 0}, {});
}
1197
// ShiftRightArithmetic on uint32 values. The expected values show the top bit
// being replicated even for the unsigned type (0x92345678 >> 4 gives
// 0xF9234567, and ~3u shifted by 100 gives ~0u).
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) {
  XlaBuilder builder(TestName());
  XlaOp values = ConstantR1<uint32>(
      &builder, {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
  XlaOp shift_amounts =
      ConstantR1<uint32>(&builder, {4, 8, 2, 7, 2, 32, 100, ~0u});
  ShiftRightArithmetic(values, shift_amounts);

  ComputeAndCompareR1<uint32>(
      &builder, {0xF9234567, 0x00100010, 0, 0, 19, 0, ~0u, 0}, {});
}
1208
// ShiftRightLogical on uint32 values. Shift amounts of 32, 100 and ~0u are
// expected to give 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) {
  XlaBuilder builder(TestName());
  XlaOp values = ConstantR1<uint32>(
      &builder, {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
  XlaOp shift_amounts =
      ConstantR1<uint32>(&builder, {4, 8, 2, 7, 5, 32, 100, ~0u});
  ShiftRightLogical(values, shift_amounts);

  ComputeAndCompareR1<uint32>(
      &builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {});
}
1219
// Eq on float vectors; a NaN on either side is expected to compare unequal.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
  // The inputs contain NaNs, so fast-math is turned off.
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  XlaOp right = ConstantR1<float>(&builder, {10.0f, 5.0f, 2.25f, 10.0f, NAN});
  Eq(left, right);

  ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
}
1229
// EqTotalOrder on float vectors; NAN versus NAN is expected to compare equal,
// while a finite value versus NAN is not.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32sTO) {
  // The inputs contain NaNs, so fast-math is turned off.
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  XlaOp right = ConstantR1<float>(&builder, {10.0f, 5.0f, 2.25f, NAN, NAN});
  EqTotalOrder(left, right);

  ComputeAndCompareR1<bool>(&builder, {false, false, true, true, false}, {});
}
1239
// Eq on two empty float vectors is expected to yield an empty bool result.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
  XlaBuilder builder(TestName());
  XlaOp empty_left = ConstantR1<float>(&builder, {});
  XlaOp empty_right = ConstantR1<float>(&builder, {});
  Eq(empty_left, empty_right);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
1248
// Ge on float vectors; comparisons involving NaN are expected to be false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
  // The inputs contain NaNs, so fast-math is turned off.
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  XlaOp right = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Ge(left, right);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
}
1258
// GeTotalOrder on float vectors. The expected values place the positive NaN
// above every finite value and the negated NaN below them.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32sTO) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  // quiet_NaN() is not guaranteed by the C++ standard to have a clear sign
  // bit, so std::fabs is applied to obtain a portable NaN whose sign bit is
  // not set.
  const float positive_nan =
      std::fabs(std::numeric_limits<float>::quiet_NaN());
  XlaOp left = ConstantR1<float>(
      &builder, {-2.5f, 25.5f, 2.25f, positive_nan, 6.0f, 6.0f});
  XlaOp right = ConstantR1<float>(
      &builder, {10.0f, 5.0f, 1.0f, 10.0f, positive_nan, -positive_nan});
  GeTotalOrder(left, right);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, true, false, true},
                            {});
}
1274
// Gt on float vectors; comparisons involving NaN are expected to be false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
  // The inputs contain NaNs, so fast-math is turned off.
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  XlaOp right = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Gt(left, right);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
}
1284
// Le on float vectors (including an equal pair, 5.0f vs 5.0f); comparisons
// involving NaN are expected to be false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
  // The inputs contain NaNs, so fast-math is turned off.
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<float>(&builder, {-2.5f, 5.0f, 2.25f, NAN, 6.0f});
  XlaOp right = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Le(left, right);

  ComputeAndCompareR1<bool>(&builder, {true, true, false, false, false}, {});
}
1294
// Lt on float vectors; comparisons involving NaN are expected to be false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
  // The inputs contain NaNs, so fast-math is turned off.
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  XlaOp right = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Lt(left, right);

  ComputeAndCompareR1<bool>(&builder, {true, false, false, false, false}, {});
}
1304
// Eq on int32 vectors covering the minimum, zero and the maximum on the left
// against smaller, equal and larger values on the right.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
  const int32_t int_min = std::numeric_limits<int32>::min();
  const int32_t int_max = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<int32>(
      &builder,
      {int_min, int_min, int_min, 0, 0, 0, int_max, int_max, int_max});
  XlaOp right = ConstantR1<int32>(
      &builder, {int_min, 0, int_max, -1, 0, 1, int_min, 0, int_max});
  Eq(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, false, true, false, false, false, true},
      {});
}
1318
// Eq on two empty int32 vectors is expected to yield an empty bool result.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
  XlaBuilder builder(TestName());
  XlaOp empty_left = ConstantR1<int32>(&builder, {});
  XlaOp empty_right = ConstantR1<int32>(&builder, {});
  Eq(empty_left, empty_right);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
1327
// Eq on complex64 vectors: only the pair with identical real and imaginary
// parts is expected to be equal; pairs with a NaN component are not.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) {
  // The inputs contain NaNs, so fast-math is turned off.
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<complex64>(&builder, {{-2.5f, 10.0f},
                                                {1.0f, 25.5f},
                                                {2.25f, -3.0f},
                                                {NAN, 0.0f},
                                                {1.0f, 6.0f}});
  XlaOp right = ConstantR1<complex64>(&builder, {{0.0f, 10.0f},
                                                 {1.0f, 5.0f},
                                                 {2.25f, -3.0f},
                                                 {10.0f, 0.0f},
                                                 {1.0f, NAN}});
  Eq(left, right);

  ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
}
1345
// Eq on two empty complex64 vectors is expected to yield an empty bool result.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) {
  XlaBuilder builder(TestName());
  XlaOp empty_left = ConstantR1<complex64>(&builder, {});
  XlaOp empty_right = ConstantR1<complex64>(&builder, {});
  Eq(empty_left, empty_right);

  ComputeAndCompareR1<bool>(&builder, {}, {});
}
1354
// Ne on complex64 vectors: only the pair with identical real and imaginary
// parts is expected to be "not unequal"; pairs with a NaN component are
// expected to be unequal.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) {
  // NaNs are among the operands, which requires fast-math to be off.
  SetFastMathDisabled(true);

  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<complex64>(&builder, {{-2.5f, 10.0f},
                                                {1.0f, 25.5f},
                                                {2.25f, -3.0f},
                                                {NAN, 0.0f},
                                                {1.0f, 6.0f}});
  XlaOp right = ConstantR1<complex64>(&builder, {{0.0f, 10.0f},
                                                 {1.0f, 5.0f},
                                                 {2.25f, -3.0f},
                                                 {10.0f, 0.0f},
                                                 {1.0f, NAN}});
  Ne(left, right);

  ComputeAndCompareR1<bool>(&builder, {true, true, false, true, true}, {});
}
1374
// Ne on float vectors; a NaN on either side is expected to compare unequal.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
  // NaNs are among the operands, which requires fast-math to be off.
  SetFastMathDisabled(true);

  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  XlaOp right = ConstantR1<float>(&builder, {10.0f, 25.5f, 1.0f, 10.0f, NAN});
  Ne(left, right);

  ComputeAndCompareR1<bool>(&builder, {true, false, true, true, true}, {});
}
1386
// Ne on int32 vectors covering the minimum, zero and the maximum on the left
// against smaller, equal and larger values on the right.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
  const int32_t int_min = std::numeric_limits<int32>::min();
  const int32_t int_max = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<int32>(
      &builder,
      {int_min, int_min, int_min, 0, 0, 0, int_max, int_max, int_max});
  XlaOp right = ConstantR1<int32>(
      &builder, {int_min, 0, int_max, -1, 0, 1, int_min, 0, int_max});
  Ne(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, true, false, true, true, true, false}, {});
}
1399
// Ge on int32 vectors covering the minimum, zero and the maximum on the left
// against smaller, equal and larger values on the right.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
  const int32_t int_min = std::numeric_limits<int32>::min();
  const int32_t int_max = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<int32>(
      &builder,
      {int_min, int_min, int_min, 0, 0, 0, int_max, int_max, int_max});
  XlaOp right = ConstantR1<int32>(
      &builder, {int_min, 0, int_max, -1, 0, 1, int_min, 0, int_max});
  Ge(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, true, true, false, true, true, true}, {});
}
1412
// Gt on int32 vectors covering the minimum, zero and the maximum on the left
// against smaller, equal and larger values on the right.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
  const int32_t int_min = std::numeric_limits<int32>::min();
  const int32_t int_max = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<int32>(
      &builder,
      {int_min, int_min, int_min, 0, 0, 0, int_max, int_max, int_max});
  XlaOp right = ConstantR1<int32>(
      &builder, {int_min, 0, int_max, -1, 0, 1, int_min, 0, int_max});
  Gt(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, false, false, true, false, false, true, true, false},
      {});
}
1426
// Le on int32 vectors covering the minimum, zero and the maximum on the left
// against smaller, equal and larger values on the right.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
  const int32_t int_min = std::numeric_limits<int32>::min();
  const int32_t int_max = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<int32>(
      &builder,
      {int_min, int_min, int_min, 0, 0, 0, int_max, int_max, int_max});
  XlaOp right = ConstantR1<int32>(
      &builder, {int_min, 0, int_max, -1, 0, 1, int_min, 0, int_max});
  Le(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, true, true, false, true, true, false, false, true}, {});
}
1439
// Lt on int32 vectors covering the minimum, zero and the maximum on the left
// against smaller, equal and larger values on the right.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
  const int32_t int_min = std::numeric_limits<int32>::min();
  const int32_t int_max = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<int32>(
      &builder,
      {int_min, int_min, int_min, 0, 0, 0, int_max, int_max, int_max});
  XlaOp right = ConstantR1<int32>(
      &builder, {int_min, 0, int_max, -1, 0, 1, int_min, 0, int_max});
  Lt(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, false, false, true, false, false, false},
      {});
}
1453
// Eq on uint32 vectors covering 0, a mid value (5) and the maximum on the
// left against smaller, equal and larger values on the right.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
  const uint32 uint_max = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<uint32>(
      &builder, {0, 0, 0, 5, 5, 5, uint_max, uint_max, uint_max});
  XlaOp right = ConstantR1<uint32>(
      &builder, {0, 1, uint_max, 4, 5, 6, 0, 1, uint_max});
  Eq(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, false, true, false, false, false, true},
      {});
}
1465
// Ne on uint32 vectors covering 0, a mid value (5) and the maximum on the
// left against smaller, equal and larger values on the right.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
  const uint32 uint_max = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<uint32>(
      &builder, {0, 0, 0, 5, 5, 5, uint_max, uint_max, uint_max});
  XlaOp right = ConstantR1<uint32>(
      &builder, {0, 1, uint_max, 4, 5, 6, 0, 1, uint_max});
  Ne(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, true, false, true, true, true, false}, {});
}
1476
// Ge on uint32 vectors covering 0, a mid value (5) and the maximum on the
// left against smaller, equal and larger values on the right.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
  const uint32 uint_max = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<uint32>(
      &builder, {0, 0, 0, 5, 5, 5, uint_max, uint_max, uint_max});
  XlaOp right = ConstantR1<uint32>(
      &builder, {0, 1, uint_max, 4, 5, 6, 0, 1, uint_max});
  Ge(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, true, true, false, true, true, true}, {});
}
1487
// Gt on uint32 vectors covering 0, a mid value (5) and the maximum on the
// left against smaller, equal and larger values on the right.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
  const uint32 uint_max = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<uint32>(
      &builder, {0, 0, 0, 5, 5, 5, uint_max, uint_max, uint_max});
  XlaOp right = ConstantR1<uint32>(
      &builder, {0, 1, uint_max, 4, 5, 6, 0, 1, uint_max});
  Gt(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, false, false, true, false, false, true, true, false},
      {});
}
1499
// Le on uint32 vectors covering 0, a mid value (5) and the maximum on the
// left against smaller, equal and larger values on the right.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
  const uint32 uint_max = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<uint32>(
      &builder, {0, 0, 0, 5, 5, 5, uint_max, uint_max, uint_max});
  XlaOp right = ConstantR1<uint32>(
      &builder, {0, 1, uint_max, 4, 5, 6, 0, 1, uint_max});
  Le(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {true, true, true, false, true, true, false, false, true}, {});
}
1510
// Lt on uint32 vectors covering 0, a mid value (5) and the maximum on the
// left against smaller, equal and larger values on the right.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
  const uint32 uint_max = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());
  XlaOp left = ConstantR1<uint32>(
      &builder, {0, 0, 0, 5, 5, 5, uint_max, uint_max, uint_max});
  XlaOp right = ConstantR1<uint32>(
      &builder, {0, 1, uint_max, 4, 5, 6, 0, 1, uint_max});
  Lt(left, right);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, false, false, true, false, false, false},
      {});
}
1522
// Pow on float vectors: 0^0, positive and negative integer-valued exponents,
// negative bases with odd/even exponents, and NaN as base or exponent.
XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) {
  // The inputs contain NaNs, so fast-math is turned off.
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  XlaOp bases = ConstantR1<float>(
      &builder, {0.0f, 4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f});
  XlaOp exponents = ConstantR1<float>(
      &builder, {0.0f, 2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f});
  Pow(bases, exponents);

  ComputeAndCompareR1<float>(
      &builder, {1.0f, 16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {},
      error_spec_);
}
1536
// Pow with non-integer exponents: negative bases are expected to give NaN and
// 0 raised to a negative exponent is expected to give +infinity.
XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) {
  // The expected results contain NaNs, so fast-math is turned off.
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  XlaOp bases = ConstantR1<float>(&builder, {-2.0f, -0.6f, -0.6f, 0.0f});
  XlaOp exponents = ConstantR1<float>(&builder, {0.5f, 0.6f, -0.6f, -0.6f});
  Pow(bases, exponents);

  ComputeAndCompareR1<float>(&builder, {NAN, NAN, NAN, INFINITY}, {},
                             error_spec_);
}
1547
// Pow on complex64 vectors built from real-valued inputs. Negative bases with
// fractional exponents are expected to give results with a non-zero imaginary
// part, and 0^0 is expected to give 1.
XLA_TEST_F(ArrayElementwiseOpTest, PowC64s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  XlaOp bases =
      ConstantR1<complex64>(&builder, {-2.0f, -0.6f, -0.6f, 0.0f, 0.0f, 0.0f});
  XlaOp exponents =
      ConstantR1<complex64>(&builder, {0.5f, 0.6f, -0.6f, 0.5f, 0.6f, 0.0f});
  Pow(bases, exponents);

  ComputeAndCompareR1<complex64>(&builder,
                                 {
                                     {0, 1.41421356},
                                     {-2.27443288e-01, 0.69999846},
                                     {-4.19847531e-01, -1.29215783},
                                     {0, 0},
                                     {0, 0},
                                     {1, 0},
                                 },
                                 {}, error_spec_);
}
1568
// Pow on two empty float vectors is expected to yield an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
  XlaBuilder builder(TestName());
  XlaOp empty_bases = ConstantR1<float>(&builder, {});
  XlaOp empty_exponents = ConstantR1<float>(&builder, {});
  Pow(empty_bases, empty_exponents);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
1577
1578 // Some Pow cases that can be implemented more efficiently.
XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
  XlaBuilder b(TestName());

  const std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
  const std::vector<float> exponents = {0.0f, 1.0f, 2.0f,
                                        0.5f, -1.0f, -0.5f};

  Literal param_literal = LiteralUtil::CreateR1<float>(values);
  std::unique_ptr<GlobalData> param_data =
      client_->TransferToServer(param_literal).ConsumeValueOrDie();

  // Computation under test: the sum over all exponents of param^exponent,
  // with each exponent supplied as a scalar constant.
  XlaOp pow_sum = ConstantR0<float>(&b, 0.0f);
  XlaOp param = Parameter(&b, 0, param_literal.shape(), "param");
  for (float exponent : exponents) {
    pow_sum = Add(pow_sum, Pow(param, ConstantR0<float>(&b, exponent)));
  }

  // Reference result computed with std::pow. The accumulator has its own
  // name so that it does not shadow the XlaOp accumulator above.
  std::vector<float> expected;
  expected.reserve(values.size());
  for (float value : values) {
    float reference_sum = 0.0f;
    for (float exponent : exponents) {
      reference_sum += std::pow(value, exponent);
    }
    expected.push_back(reference_sum);
  }

  ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_);
}
1606
// Pow(Exp(param0), param1) on float parameters, checked against
// std::pow(std::exp(x), y) computed element by element.
XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Pow(Exp(param0), param1);

  // The index is unsigned so that it matches the type of size() and the loop
  // condition does not mix signed and unsigned operands.
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = std::pow(std::exp(values0[i]), values1[i]);
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1631
// Log(Pow(param0, param1)) on float parameters, checked against
// std::log(std::pow(x, y)) computed element by element. The inputs include
// negative bases.
XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, -10.0f, -2.0f, 2.0f,
                                3.2f, 4.0f,   0.5f,  5.7f};
  std::vector<float> values1 = {0.0f, 10.0f, -4.0f, 1.0f,
                                2.0f, 0.5f,  -1.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Log(Pow(param0, param1));

  // The index is unsigned so that it matches the type of size() and the loop
  // condition does not mix signed and unsigned operands.
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = std::log(std::pow(values0[i], values1[i]));
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1658
// Evaluates exp(x) * exp(y) elementwise on two F32 parameter vectors and
// compares against the host-side result.
XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
  XlaBuilder builder(TestName());

  const std::vector<float> lhs_values = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  const std::vector<float> rhs_values = {0.0f, 1.0f,  2.0f,
                                         0.5f, -1.0f, -0.5f};

  Literal lhs_literal = LiteralUtil::CreateR1<float>(lhs_values);
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
  Literal rhs_literal = LiteralUtil::CreateR1<float>(rhs_values);
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  auto lhs = Parameter(&builder, 0, lhs_literal.shape(), "param0");
  auto rhs = Parameter(&builder, 1, rhs_literal.shape(), "param1");
  auto exp_lhs = Exp(lhs);
  auto exp_rhs = Exp(rhs);
  Mul(exp_lhs, exp_rhs);

  // Host-side reference values.
  std::vector<float> expected;
  expected.reserve(lhs_values.size());
  for (size_t i = 0; i < lhs_values.size(); ++i) {
    expected.push_back(std::exp(lhs_values[i]) * std::exp(rhs_values[i]));
  }

  ComputeAndCompareR1<float>(&builder, expected,
                             {lhs_data.get(), rhs_data.get()}, error_spec_);
}
1683
// Evaluates x / exp(y) elementwise on two F32 parameter vectors and compares
// against the host-side result.
XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
  XlaBuilder builder(TestName());

  const std::vector<float> numerators = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  const std::vector<float> exp_args = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  Literal numerators_literal = LiteralUtil::CreateR1<float>(numerators);
  std::unique_ptr<GlobalData> numerators_data =
      client_->TransferToServer(numerators_literal).ConsumeValueOrDie();
  Literal exp_args_literal = LiteralUtil::CreateR1<float>(exp_args);
  std::unique_ptr<GlobalData> exp_args_data =
      client_->TransferToServer(exp_args_literal).ConsumeValueOrDie();

  auto numerator =
      Parameter(&builder, 0, numerators_literal.shape(), "param0");
  auto exp_arg = Parameter(&builder, 1, exp_args_literal.shape(), "param1");
  auto denominator = Exp(exp_arg);
  Div(numerator, denominator);

  // Host-side reference values.
  std::vector<float> expected;
  expected.reserve(numerators.size());
  for (size_t i = 0; i < numerators.size(); ++i) {
    expected.push_back(numerators[i] / std::exp(exp_args[i]));
  }

  ComputeAndCompareR1<float>(&builder, expected,
                             {numerators_data.get(), exp_args_data.get()},
                             error_spec_);
}
1708
// Evaluates the left-nested quotient (a / b) / c on three F32 parameter
// vectors and compares against the host-side result.
XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
  XlaBuilder builder(TestName());

  const std::vector<float> a_values = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  const std::vector<float> b_values = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  const std::vector<float> c_values = {0.1f,  1.1f,   6.9f,
                                       12.5f, -15.0f, -0.5f};

  Literal a_literal = LiteralUtil::CreateR1<float>(a_values);
  std::unique_ptr<GlobalData> a_data =
      client_->TransferToServer(a_literal).ConsumeValueOrDie();

  Literal b_literal = LiteralUtil::CreateR1<float>(b_values);
  std::unique_ptr<GlobalData> b_data =
      client_->TransferToServer(b_literal).ConsumeValueOrDie();

  Literal c_literal = LiteralUtil::CreateR1<float>(c_values);
  std::unique_ptr<GlobalData> c_data =
      client_->TransferToServer(c_literal).ConsumeValueOrDie();

  auto a = Parameter(&builder, 0, a_literal.shape(), "param0");
  auto b = Parameter(&builder, 1, b_literal.shape(), "param1");
  auto c = Parameter(&builder, 2, c_literal.shape(), "param2");
  auto a_over_b = Div(a, b);
  Div(a_over_b, c);

  // Host-side reference values.
  std::vector<float> expected;
  expected.reserve(a_values.size());
  for (size_t i = 0; i < a_values.size(); ++i) {
    expected.push_back((a_values[i] / b_values[i]) / c_values[i]);
  }

  ComputeAndCompareR1<float>(
      &builder, expected, {a_data.get(), b_data.get(), c_data.get()},
      error_spec_);
}
1740
// Evaluates the right-nested quotient a / (b / c) on three F32 parameter
// vectors and compares against the host-side result.
XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
  XlaBuilder builder(TestName());

  const std::vector<float> a_values = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  const std::vector<float> b_values = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  const std::vector<float> c_values = {0.1f,  1.1f,   6.9f,
                                       12.5f, -15.0f, -0.5f};

  Literal a_literal = LiteralUtil::CreateR1<float>(a_values);
  std::unique_ptr<GlobalData> a_data =
      client_->TransferToServer(a_literal).ConsumeValueOrDie();

  Literal b_literal = LiteralUtil::CreateR1<float>(b_values);
  std::unique_ptr<GlobalData> b_data =
      client_->TransferToServer(b_literal).ConsumeValueOrDie();

  Literal c_literal = LiteralUtil::CreateR1<float>(c_values);
  std::unique_ptr<GlobalData> c_data =
      client_->TransferToServer(c_literal).ConsumeValueOrDie();

  auto a = Parameter(&builder, 0, a_literal.shape(), "param0");
  auto b = Parameter(&builder, 1, b_literal.shape(), "param1");
  auto c = Parameter(&builder, 2, c_literal.shape(), "param2");
  auto b_over_c = Div(b, c);
  Div(a, b_over_c);

  // Host-side reference values.
  std::vector<float> expected;
  expected.reserve(a_values.size());
  for (size_t i = 0; i < a_values.size(); ++i) {
    expected.push_back(a_values[i] / (b_values[i] / c_values[i]));
  }

  ComputeAndCompareR1<float>(
      &builder, expected, {a_data.get(), b_data.get(), c_data.get()},
      error_spec_);
}
1773
// Evaluates a / pow(b, c) on three F32 parameter vectors and compares against
// the host-side result.
XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
  XlaBuilder builder(TestName());

  const std::vector<float> numerators = {1.0f,  2.0f,  3.2f,
                                         -4.0f, 0.45f, 5.7f};
  const std::vector<float> bases = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
  const std::vector<float> exponents = {0.1f, 1.1f,   6.9f,
                                        9.5f, -11.0f, -0.5f};

  Literal numerators_literal = LiteralUtil::CreateR1<float>(numerators);
  std::unique_ptr<GlobalData> numerators_data =
      client_->TransferToServer(numerators_literal).ConsumeValueOrDie();

  Literal bases_literal = LiteralUtil::CreateR1<float>(bases);
  std::unique_ptr<GlobalData> bases_data =
      client_->TransferToServer(bases_literal).ConsumeValueOrDie();

  Literal exponents_literal = LiteralUtil::CreateR1<float>(exponents);
  std::unique_ptr<GlobalData> exponents_data =
      client_->TransferToServer(exponents_literal).ConsumeValueOrDie();

  auto numerator =
      Parameter(&builder, 0, numerators_literal.shape(), "param0");
  auto base = Parameter(&builder, 1, bases_literal.shape(), "param1");
  auto exponent = Parameter(&builder, 2, exponents_literal.shape(), "param2");
  auto power = Pow(base, exponent);
  Div(numerator, power);

  // Host-side reference values.
  std::vector<float> expected;
  expected.reserve(numerators.size());
  for (size_t i = 0; i < numerators.size(); ++i) {
    expected.push_back(numerators[i] / std::pow(bases[i], exponents[i]));
  }

  ComputeAndCompareR1<float>(
      &builder, expected,
      {numerators_data.get(), bases_data.get(), exponents_data.get()},
      error_spec_);
}
1806
// Evaluates (p0 / p1) / (p2 / p3) on four F32 parameter vectors and compares
// the result against the same quotient computed on the host.
XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
  XlaBuilder b(TestName());

  std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
  std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();

  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();

  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
  std::unique_ptr<GlobalData> data2 =
      client_->TransferToServer(literal2).ConsumeValueOrDie();

  Literal literal3 = LiteralUtil::CreateR1<float>(values3);
  std::unique_ptr<GlobalData> data3 =
      client_->TransferToServer(literal3).ConsumeValueOrDie();

  // Parameter names mirror their parameter numbers.
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
  auto param3 = Parameter(&b, 3, literal3.shape(), "param3");
  Div(Div(param0, param1), Div(param2, param3));

  // Host-side reference values; the index type matches vector::size() so the
  // loop condition does not mix signed and unsigned operands.
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = (values0[i] / values1[i]) / (values2[i] / values3[i]);
  }

  ComputeAndCompareR1<float>(
      &b, expected, {data0.get(), data1.get(), data2.get(), data3.get()},
      error_spec_);
}
1846
// Squares `count` evenly spaced F32 values (i / count for i in [0, count))
// via Pow(x, 2) and compares against value * value computed on the host.
// `count` comes from the test parameter.
TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
  const int count = GetParam();
  XlaBuilder builder(TestName());

  std::vector<float> inputs;
  inputs.reserve(count);
  const float denominator = static_cast<float>(count);
  for (int i = 0; i < count; ++i) {
    inputs.push_back(i / denominator);
  }

  auto x = ConstantR1<float>(&builder, inputs);
  auto two = ConstantR0<float>(&builder, 2.0f);
  Pow(x, two);

  std::vector<float> squares;
  squares.reserve(inputs.size());
  for (const float input : inputs) {
    squares.push_back(input * input);
  }

  ComputeAndCompareR1<float>(&builder, squares, {}, error_spec_);
}
1866
// Squares a 2x2x2x2 F32 array filled with i / num_elements via Pow(x, 2) and
// compares against the squares computed on the host.
XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) {
  XlaBuilder builder(TestName());
  Array4D<float> inputs(2, 2, 2, 2);

  const auto element_count = inputs.num_elements();
  std::vector<float> flat_inputs;
  std::vector<float> flat_squares;
  for (int i = 0; i < element_count; ++i) {
    const float value = static_cast<float>(i) / element_count;
    flat_inputs.push_back(value);
    flat_squares.push_back(value * value);
  }
  inputs.SetValues(flat_inputs);

  Array4D<float> expected(2, 2, 2, 2, flat_squares);

  auto x = ConstantR4FromArray4D<float>(&builder, inputs);
  auto two = ConstantR0<float>(&builder, 2.0f);
  Pow(x, two);

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
1886
// Pow(x, 2) on a 2x2x0x2 F32 array (no elements) must yield an array of the
// same shape.
XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) {
  XlaBuilder builder(TestName());
  Array4D<float> empty_input(2, 2, 0, 2);
  Array4D<float> empty_expected(2, 2, 0, 2);

  auto x = ConstantR4FromArray4D<float>(&builder, empty_input);
  auto two = ConstantR0<float>(&builder, 2.0f);
  Pow(x, two);

  ComputeAndCompareR4<float>(&builder, empty_expected, {}, error_spec_);
}
1897
// Elementwise Min of two F32 vectors with fast-math disabled; the expected
// output is NaN wherever either operand is NaN.
XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);

  const std::vector<float> lhs_values = {1.0f, 1.0f, 2.25f, NAN, 6.0f};
  const std::vector<float> rhs_values = {2.0f, -5.0f, 1.0f, 10.0f, NAN};
  auto lhs = ConstantR1<float>(&builder, lhs_values);
  auto rhs = ConstantR1<float>(&builder, rhs_values);
  Min(lhs, rhs);

  const std::vector<float> expected = {1.0f, -5.0f, 1.0f, NAN, NAN};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1908
// Min of two empty F32 vectors must produce an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto empty_lhs = ConstantR1<float>(&builder, {});
  auto empty_rhs = ConstantR1<float>(&builder, {});
  Min(empty_lhs, empty_rhs);

  const std::vector<float> expected;
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1916
// Elementwise Min of two F64 vectors with fast-math disabled; the expected
// output is NaN wherever either operand is NaN. Uses the strict error spec.
XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);

  const std::vector<double> lhs_values = {1.0, 1.0, 2.25, NAN, 6.0};
  const std::vector<double> rhs_values = {2.0, -5.0, 1.0, 10.0, NAN};
  auto lhs = ConstantR1<double>(&builder, lhs_values);
  auto rhs = ConstantR1<double>(&builder, rhs_values);
  Min(lhs, rhs);

  const std::vector<double> expected = {1.0, -5.0, 1.0, NAN, NAN};
  ComputeAndCompareR1<double>(&builder, expected, {}, strict_error_spec_);
}
1927
// Elementwise Max of two F32 vectors with fast-math disabled; the expected
// output is NaN wherever either operand is NaN.
XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);

  const std::vector<float> lhs_values = {1.0f, 1.0f, 2.25f, NAN, 6.0f};
  const std::vector<float> rhs_values = {2.0f, -5.0f, 1.0f, 10.0f, NAN};
  auto lhs = ConstantR1<float>(&builder, lhs_values);
  auto rhs = ConstantR1<float>(&builder, rhs_values);
  Max(lhs, rhs);

  const std::vector<float> expected = {2.0f, 1.0f, 2.25f, NAN, NAN};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1938
// Max of two empty F32 vectors must produce an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto empty_lhs = ConstantR1<float>(&builder, {});
  auto empty_rhs = ConstantR1<float>(&builder, {});
  Max(empty_lhs, empty_rhs);

  const std::vector<float> expected;
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1946
// Elementwise Max of two F64 vectors with fast-math disabled; the expected
// output is NaN wherever either operand is NaN. Uses the strict error spec.
XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) {
  XlaBuilder builder(TestName());
  SetFastMathDisabled(true);

  const std::vector<double> lhs_values = {1.0, 1.0, 2.25, NAN, 6.0};
  const std::vector<double> rhs_values = {2.0, -5.0, 1.0, 10.0, NAN};
  auto lhs = ConstantR1<double>(&builder, lhs_values);
  auto rhs = ConstantR1<double>(&builder, rhs_values);
  Max(lhs, rhs);

  const std::vector<double> expected = {2.0, 1.0, 2.25, NAN, NAN};
  ComputeAndCompareR1<double>(&builder, expected, {}, strict_error_spec_);
}
1957
// Elementwise Max of two S32 vectors whose entries include the smallest and
// largest representable int32 values.
XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) {
  const int32_t lowest = std::numeric_limits<int32>::min();
  const int32_t highest = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());

  auto lhs = ConstantR1<int32>(&builder,
                               {lowest, lowest, lowest, -1, -1, 0, 0, 0, 1, 1,
                                highest, highest, highest});
  auto rhs = ConstantR1<int32>(&builder,
                               {lowest, highest, 0, -10, 0, -1, 0, 1, 0, 10,
                                0, highest, lowest});
  Max(lhs, rhs);

  const std::vector<int32> expected = {lowest, highest, 0,       -1,      0,
                                       0,      0,       1,       1,       10,
                                       highest, highest, highest};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
1972
// Elementwise Min of two S32 vectors whose entries include the smallest and
// largest representable int32 values.
XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) {
  const int32_t lowest = std::numeric_limits<int32>::min();
  const int32_t highest = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());

  auto lhs = ConstantR1<int32>(&builder,
                               {lowest, lowest, lowest, -1, -1, 0, 0, 0, 1, 1,
                                highest, highest, highest});
  auto rhs = ConstantR1<int32>(&builder,
                               {lowest, highest, 0, -10, 0, -1, 0, 1, 0, 10,
                                0, highest, lowest});
  Min(lhs, rhs);

  const std::vector<int32> expected = {lowest, lowest, lowest,  -10,   -1,
                                       -1,     0,      0,       0,     1,
                                       0,      highest, lowest};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
1987
// Elementwise Max of two U32 vectors whose entries include the largest
// representable uint32 value.
XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) {
  const uint32 highest = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());

  auto lhs = ConstantR1<uint32>(&builder,
                                {0, 0, 1, 1, 1, highest, highest, highest});
  auto rhs =
      ConstantR1<uint32>(&builder, {0, 1, 0, 1, 10, 0, 234234, highest});
  Max(lhs, rhs);

  const std::vector<uint32> expected = {0,  1,       1,       1,
                                        10, highest, highest, highest};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
1998
// Elementwise Min of two U32 vectors whose entries include the largest
// representable uint32 value.
XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) {
  const uint32 highest = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());

  auto lhs = ConstantR1<uint32>(&builder,
                                {0, 0, 1, 1, 1, highest, highest, highest});
  auto rhs =
      ConstantR1<uint32>(&builder, {0, 1, 0, 1, 10, 0, 234234, highest});
  Min(lhs, rhs);

  const std::vector<uint32> expected = {0, 0, 0, 1, 1, 0, 234234, highest};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
2009
// Elementwise Max of two ten-element F32 vectors that are negations of each
// other (both start with -0.0).
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
  XlaBuilder builder(TestName());

  const std::vector<float> lhs_values = {-0.0, 1.0, 2.0,  -3.0, -4.0,
                                         5.0,  6.0, -7.0, -8.0, 9.0};
  const std::vector<float> rhs_values = {-0.0, -1.0, -2.0, 3.0, 4.0,
                                         -5.0, -6.0, 7.0,  8.0, -9.0};
  auto lhs = ConstantR1<float>(&builder, lhs_values);
  auto rhs = ConstantR1<float>(&builder, rhs_values);
  Max(lhs, rhs);

  const std::vector<float> expected = {-0.0, 1.0, 2.0, 3.0, 4.0,
                                       5.0,  6.0, 7.0, 8.0, 9.0};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2022
// Max of a one-element F32 vector with a zero-element F32 vector; the
// expected result is an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) {
  XlaBuilder builder(TestName());
  auto single = ConstantR1<float>(&builder, {3.5});
  auto empty = ConstantR1<float>(&builder, {});
  Max(single, empty);

  const std::vector<float> expected;
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2031
// Max of a rank-1 F32 constant with a 0x2 rank-2 F32 constant, broadcasting
// the vector along dimension 0 and then dimension 1; each case expects a 0x2
// result.
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) {
  for (int broadcast_dim : {0, 1}) {
    XlaBuilder builder(TestName());
    const Array2D<float> empty_matrix(0, 2);
    auto vec = ConstantR1<float>(&builder, {3.5});
    auto matrix = ConstantR2FromArray2D<float>(&builder, empty_matrix);
    Max(vec, matrix, /*broadcast_dimensions=*/{broadcast_dim});

    const Array2D<float> expected(0, 2);
    ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
  }
}
2042
// Max of a 3-element F32 vector with a 2x3 F32 matrix, broadcasting the
// vector along dimension 1 of the matrix.
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
  XlaBuilder builder(TestName());
  auto row = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  auto matrix = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  Max(row, matrix, /*broadcast_dimensions=*/{1});

  const Array2D<float> expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
2053
// Max of an empty F32 vector with a matrix made of two empty rows,
// broadcasting along dimension 1; the result has two empty rows as well.
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto empty_row = ConstantR1<float>(&builder, {});
  auto empty_matrix = ConstantR2<float>(&builder, {{}, {}});
  Max(empty_row, empty_matrix, /*broadcast_dimensions=*/{1});

  const Array2D<float> expected({{}, {}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
2063
// Max of a 2x2x3 S32 array with the scalar 2: every element below 2 is
// expected to be replaced by 2.
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) {
  XlaBuilder builder(TestName());
  auto floor_value = ConstantR0<int32>(&builder, 2);
  const Array3D<int32> input(
      {{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}});
  auto array = ConstantR3FromArray3D<int32>(&builder, input);
  Max(array, floor_value, /*broadcast_dimensions=*/{});

  const Array3D<int32> expected(
      {{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}});
  ComputeAndCompareR3<int32>(&builder, expected, {});
}
2074
// Max of a 2x0x3 S32 array (no elements) with a scalar; the result keeps the
// 2x0x3 shape.
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) {
  XlaBuilder builder(TestName());
  auto floor_value = ConstantR0<int32>(&builder, 2);
  const Array3D<int32> empty_input(2, 0, 3);
  auto array = ConstantR3FromArray3D<int32>(&builder, empty_input);
  Max(array, floor_value, /*broadcast_dimensions=*/{});

  const Array3D<int32> empty_expected(2, 0, 3);
  ComputeAndCompareR3<int32>(&builder, empty_expected, {});
}
2085
// Min of a 2x3 F32 matrix with a 2-element F32 vector, broadcasting the
// vector along dimension 0 so each row is compared with one vector entry.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
  XlaBuilder builder(TestName());
  auto matrix = ConstantR2<float>(
      &builder, {{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});
  auto per_row = ConstantR1<float>(&builder, {-10.2f, 16.4f});
  Min(matrix, per_row, /*broadcast_dimensions=*/{0});

  const Array2D<float> expected(
      {{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
2096
// Min of a matrix made of two empty rows with a 2-element F32 vector
// broadcast along dimension 0; the result has two empty rows.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto empty_matrix = ConstantR2<float>(&builder, {{}, {}});
  auto per_row = ConstantR1<float>(&builder, {-10.2f, 16.4f});
  Min(empty_matrix, per_row, /*broadcast_dimensions=*/{0});

  const Array2D<float> expected({{}, {}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
2106
// Min of a 2x3 F32 matrix with a 2x2x1x3 F32 array, mapping the matrix
// dimensions onto dimensions 1 and 3 of the 4D operand.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) {
  XlaBuilder builder(TestName());
  auto matrix = ConstantR2<float>(
      &builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});

  const Array4D<float> tensor_values(
      {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}},
       {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}});
  auto tensor = ConstantR4FromArray4D<float>(&builder, tensor_values);
  Min(matrix, tensor, /*broadcast_dimensions=*/{1, 3});

  const Array4D<float> expected(
      {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}},
       {{{-12.2f, 64.29f, 6.1f}}, {{-0.01f, 32.2f, 2.5f}}}});
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
2121
// Min of a 2x3 F32 matrix with a 2x2x0x3 F32 array (no elements), mapping the
// matrix onto dimensions 1 and 3; the result keeps the 2x2x0x3 shape.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto matrix = ConstantR2<float>(
      &builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  const Array4D<float> empty_tensor(2, 2, 0, 3);
  auto tensor = ConstantR4FromArray4D<float>(&builder, empty_tensor);
  Min(matrix, tensor, /*broadcast_dimensions=*/{1, 3});

  const Array4D<float> empty_expected(2, 2, 0, 3);
  ComputeAndCompareR4<float>(&builder, empty_expected, {}, error_spec_);
}
2133
// Elementwise Min of the ascending sequence 0..9 with the descending
// sequence 9..0.
XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) {
  XlaBuilder builder(TestName());
  const std::vector<int32> ascending = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  const std::vector<int32> descending = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  auto lhs = ConstantR1<int32>(&builder, ascending);
  auto rhs = ConstantR1<int32>(&builder, descending);
  Min(lhs, rhs);

  const std::vector<int32> expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2143
// Elementwise Max of the ascending sequence 0..9 with the descending
// sequence 9..0.
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) {
  XlaBuilder builder(TestName());
  const std::vector<int32> ascending = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  const std::vector<int32> descending = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
  auto lhs = ConstantR1<int32>(&builder, ascending);
  auto rhs = ConstantR1<int32>(&builder, descending);
  Max(lhs, rhs);

  const std::vector<int32> expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2153
// Elementwise Rem of two S32 constant vectors with mixed-sign operands; the
// expected remainders carry the sign of the dividend.
XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) {
  XlaBuilder builder(TestName());
  auto dividend = ConstantR1<int32>(&builder, {-3, 26, 2, -1, 1});
  auto divisor = ConstantR1<int32>(&builder, {10, 5, 1, 10, -10});
  Rem(dividend, divisor);

  const std::vector<int32> expected = {-3, 1, 0, -1, 1};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2162
// Clamp with per-element F32 lower and upper bounds, none of which is NaN.
XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
  XlaBuilder builder(TestName());
  auto lower = ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
  auto operand =
      ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
  auto upper = ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0f});
  Clamp(lower, operand, upper);

  const std::vector<float> expected = {2.0f, 0.5f, 1.0f, 2.25f, 10.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2174
// Clamp with per-element F32 bounds and fast-math disabled; positions whose
// lower or upper bound is NaN are expected to yield NaN.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto lower = ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, NAN});
  auto operand =
      ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
  auto upper = ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, NAN, 123.0f});
  Clamp(lower, operand, upper);

  const std::vector<float> expected = {2.0f, 0.5f, 1.0f, NAN, NAN};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2187
// Clamp of an F32 vector between the scalar bounds 0 and 5.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
  XlaBuilder builder(TestName());
  auto lower = ConstantR0<float>(&builder, 0.0f);
  auto operand = ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
  auto upper = ConstantR0<float>(&builder, 5.0f);
  Clamp(lower, operand, upper);

  const std::vector<float> expected = {2.0f, 5.0f, 0.0f, 1.0f, 4.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2198
// Sums four F32 clamps of the same operand that cover every scalar/vector
// combination of lower and upper bound.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
  XlaBuilder builder(TestName());
  auto lower_scalar = ConstantR0<float>(&builder, 0.0f);
  auto lower_vector =
      ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
  auto operand = ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
  auto upper_scalar = ConstantR0<float>(&builder, 3.0f);
  auto upper_vector =
      ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0f});

  // Clamp with each mix of broadcast scalar and vector bounds.
  auto vector_lo_scalar_hi = Clamp(lower_vector, operand, upper_scalar);
  auto scalar_lo_vector_hi = Clamp(lower_scalar, operand, upper_vector);
  auto vector_lo_vector_hi = Clamp(lower_vector, operand, upper_vector);
  auto scalar_lo_scalar_hi = Clamp(lower_scalar, operand, upper_scalar);
  auto mixed_sum = Add(vector_lo_scalar_hi, scalar_lo_vector_hi);
  auto uniform_sum = Add(vector_lo_vector_hi, scalar_lo_scalar_hi);
  Add(mixed_sum, uniform_sum);

  const std::vector<float> expected = {8.0f, 7.0f, 2.0f, 6.5f, 14.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2218
// Clamp with per-element S32 lower and upper bounds.
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) {
  XlaBuilder builder(TestName());
  auto lower = ConstantR1<int32>(&builder, {1, -6, 1, 2, 0, -5});
  auto operand = ConstantR1<int32>(&builder, {2, 10, -5, 1, 4, 10});
  auto upper = ConstantR1<int32>(&builder, {3, 0, 25, 5, 123, -1});
  Clamp(lower, operand, upper);

  const std::vector<int32> expected = {2, 0, 1, 2, 4, -1};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2228
// Sums four S32 clamps of the same operand that cover every scalar/vector
// combination of lower and upper bound.
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) {
  XlaBuilder builder(TestName());
  auto lower_scalar = ConstantR0<int32>(&builder, 0);
  auto lower_vector = ConstantR1<int32>(&builder, {1, -6, 1, 2, 0});
  auto operand = ConstantR1<int32>(&builder, {2, 10, -5, 1, 4});
  auto upper_scalar = ConstantR0<int32>(&builder, 3);
  auto upper_vector = ConstantR1<int32>(&builder, {3, 1, 25, 5, 123});

  // Clamp with each mix of broadcast scalar and vector bounds.
  auto vector_lo_scalar_hi = Clamp(lower_vector, operand, upper_scalar);
  auto scalar_lo_vector_hi = Clamp(lower_scalar, operand, upper_vector);
  auto vector_lo_vector_hi = Clamp(lower_vector, operand, upper_vector);
  auto scalar_lo_scalar_hi = Clamp(lower_scalar, operand, upper_scalar);
  auto mixed_sum = Add(vector_lo_scalar_hi, scalar_lo_vector_hi);
  auto uniform_sum = Add(vector_lo_vector_hi, scalar_lo_scalar_hi);
  Add(mixed_sum, uniform_sum);

  const std::vector<int32> expected = {8, 8, 2, 6, 14};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2244
// Clamp with per-element U32 bounds, including bounds at and just below the
// all-ones bit pattern (~0u).
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) {
  XlaBuilder builder(TestName());
  auto lower = ConstantR1<uint32>(&builder, {1, 2, 1, 2, 0, ~0u - 4});
  auto operand = ConstantR1<uint32>(&builder, {2, 10, 5, 1, 4, 10});
  auto upper = ConstantR1<uint32>(&builder, {3, 5, 25, 5, 123, ~0u});
  Clamp(lower, operand, upper);

  const std::vector<uint32> expected = {2, 5, 5, 2, 4, ~0u - 4};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
2254
// Sums four U32 clamps of the same operand that cover every scalar/vector
// combination of lower and upper bound.
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
  XlaBuilder builder(TestName());
  auto lower_scalar = ConstantR0<uint32>(&builder, 0);
  auto lower_vector = ConstantR1<uint32>(&builder, {1, 0, 1, 2, 0});
  auto operand = ConstantR1<uint32>(&builder, {2, 10, 0, 1, 4});
  auto upper_scalar = ConstantR0<uint32>(&builder, 3);
  auto upper_vector = ConstantR1<uint32>(&builder, {3, 1, 25, 5, 123});

  // Clamp with each mix of broadcast scalar and vector bounds.
  auto vector_lo_scalar_hi = Clamp(lower_vector, operand, upper_scalar);
  auto scalar_lo_vector_hi = Clamp(lower_scalar, operand, upper_vector);
  auto vector_lo_vector_hi = Clamp(lower_vector, operand, upper_vector);
  auto scalar_lo_scalar_hi = Clamp(lower_scalar, operand, upper_scalar);
  auto mixed_sum = Add(vector_lo_scalar_hi, scalar_lo_vector_hi);
  auto uniform_sum = Add(vector_lo_vector_hi, scalar_lo_scalar_hi);
  Add(mixed_sum, uniform_sum);

  const std::vector<uint32> expected = {8, 8, 2, 6, 14};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
2270
// Adds two F32 vectors that are both supplied as computation parameters.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
  XlaBuilder builder(TestName());

  Literal lhs_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  Literal rhs_literal = LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  auto lhs = Parameter(&builder, 0, lhs_literal.shape(), "param0");
  auto rhs = Parameter(&builder, 1, rhs_literal.shape(), "param1");
  Add(lhs, rhs);

  const std::vector<float> expected = {8.3f, 4.5f, 6.7f, 11.1f};
  ComputeAndCompareR1<float>(&builder, expected,
                             {lhs_data.get(), rhs_data.get()}, error_spec_);
}
2292
// Adds two 0x7x0 F32 arrays (no elements) supplied as computation parameters;
// the result keeps the 0x7x0 shape.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
  XlaBuilder builder(TestName());

  Literal lhs_literal =
      LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  Literal rhs_literal =
      LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  auto lhs = Parameter(&builder, 0, lhs_literal.shape(), "param0");
  auto rhs = Parameter(&builder, 1, rhs_literal.shape(), "param1");
  Add(lhs, rhs);

  const Array3D<float> empty_expected(0, 7, 0);
  ComputeAndCompareR3<float>(&builder, empty_expected,
                             {lhs_data.get(), rhs_data.get()}, error_spec_);
}
2314
// Adds an F32 constant vector to an F32 vector supplied as a computation
// parameter.
XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
  XlaBuilder builder(TestName());

  Literal param_literal =
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
  std::unique_ptr<GlobalData> param_data =
      client_->TransferToServer(param_literal).ConsumeValueOrDie();

  auto constant = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
  auto param = Parameter(&builder, 0, param_literal.shape(), "param0");
  Add(constant, param);

  const std::vector<float> expected = {2.2f, 4.4f, 6.6f, 9.9f};
  ComputeAndCompareR1<float>(&builder, expected, {param_data.get()},
                             error_spec_);
}
2330
// Cos on four F32 angles given in radians, checked against hand-written
// expected values.
XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) {
  XlaBuilder builder(TestName());
  const std::vector<float> angles = {3.14159f, 0.0f, 1.570796f, -0.78539f};
  auto operand = ConstantR1<float>(&builder, angles);
  Cos(operand);

  const std::vector<float> expected = {-1.0f, 1.0f, 0.0f, 0.707107f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2339
// Sin on four F32 angles given in radians, checked against hand-written
// expected values.
XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) {
  XlaBuilder builder(TestName());
  const std::vector<float> angles = {3.14159f, 0.0f, 1.570796f, -0.78539f};
  auto operand = ConstantR1<float>(&builder, angles);
  Sin(operand);

  const std::vector<float> expected = {0.0f, 0.0f, 1.0f, -0.707107f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2348
// Atan2 over the cross product of a set of F32 y and x samples that includes
// signed zeros and infinities. No expected values are listed here; the check
// is delegated to ComputeAndCompare.
XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) {
  XlaBuilder builder(TestName());
  const float inf = std::numeric_limits<float>::infinity();
  const std::vector<float> y_samples = {+0.0f, -0.0f, inf,   -inf, 5.0f,
                                        -3.0f, 2.0f,  -8.0f, 1.0f};
  const std::vector<float> x_samples = {+0.0f, -0.0f, inf,  -inf,
                                        6.0f,  -4.0f, 2.0f, 8.0f};

  // Pair every y sample with every x sample.
  std::vector<float> ys;
  std::vector<float> xs;
  for (const float y_sample : y_samples) {
    for (const float x_sample : x_samples) {
      ys.push_back(y_sample);
      xs.push_back(x_sample);
    }
  }

  auto y = ConstantR1<float>(&builder, ys);
  auto x = ConstantR1<float>(&builder, xs);
  Atan2(y, x);

  ComputeAndCompare(&builder, {}, error_spec_);
}
2366
// Atan2 over the cross product of a set of y and x samples, each stored as a
// complex<float> built from a real float (signed zeros and infinities
// included). No expected values are listed here; the check is delegated to
// ComputeAndCompare.
XLA_TEST_F(ArrayElementwiseOpTest, Atan2C64s) {
  XlaBuilder builder(TestName());
  const float inf = std::numeric_limits<float>::infinity();
  const std::vector<float> y_samples = {+0.0f, -0.0f, inf,   -inf, 5.0f,
                                        -3.0f, 2.0f,  -8.0f, 1.0f};
  const std::vector<float> x_samples = {+0.0f, -0.0f, inf,  -inf,
                                        6.0f,  -4.0f, 2.0f, 8.0f};

  // Pair every y sample with every x sample.
  std::vector<std::complex<float>> ys;
  std::vector<std::complex<float>> xs;
  for (const float y_sample : y_samples) {
    for (const float x_sample : x_samples) {
      ys.emplace_back(y_sample);
      xs.emplace_back(x_sample);
    }
  }

  auto y = ConstantR1<std::complex<float>>(&builder, ys);
  auto x = ConstantR1<std::complex<float>>(&builder, xs);
  Atan2(y, x);

  ComputeAndCompare(&builder, {}, error_spec_);
}
2384
// Tanh on a three-element F32 constant vector, checked against hand-written
// expected values.
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) {
  XlaBuilder builder(TestName());
  const std::vector<float> inputs = {-2.5f, 3.14f, 2.25f};
  auto operand = ConstantR1<float>(&builder, inputs);
  Tanh(operand);

  const std::vector<float> expected = {-0.986614f, 0.996260f, 0.978026};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2393
// Same idea as ArrayElementwiseOpTest.TanhF32s, but the input (64 values,
// passed as a parameter) is large enough to exercise the vectorized tanh
// implementation on XLA CPU.
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
  XlaBuilder builder(TestName());

  const std::vector<float> inputs = {
      1.02,  -0.32, 0.85,  0.90,  1.23,  -0.91, -0.49, 0.80,  -0.67, 0.16,
      -0.07, 0.39,  -0.41, 0.04,  1.36,  1.25,  0.41,  0.65,  -1.08, 0.32,
      -1.45, -0.77, -1.09, 0.91,  -1.03, -0.30, -1.11, -1.17, 1.50,  -0.85,
      0.04,  1.02,  0.34,  -0.61, 0.41,  0.07,  -0.02, 1.42,  -0.62, 0.81,
      0.08,  0.81,  -0.30, 1.17,  -0.65, -0.44, 0.92,  1.26,  -1.29, 1.35,
      0.08,  -1.24, -0.92, 0.49,  1.17,  -0.45, -1.31, -1.44, -0.13, -1.31,
      -0.79, 1.41,  1.21,  1.05};
  auto input_literal = LiteralUtil::CreateR1<float>(inputs);
  TF_ASSERT_OK_AND_ASSIGN(auto input_data,
                          client_->TransferToServer(input_literal));

  auto input = Parameter(&builder, 0, input_literal.shape(), "input");
  Tanh(input);

  const std::vector<float> expected = {
      0.77009583,  -0.30665702, 0.69070244,  0.71401149,  0.84400684,
      -0.71985596, -0.45764771, 0.66664988,  -0.58278900, 0.16050975,
      -0.06770509, 0.36843640,  -0.38476998, 0.04018109,  0.87562293,
      0.84788644,  0.38603750,  0.57294142,  -0.79140943, 0.31032649,
      -0.89590985, -0.64770776, -0.79625875, 0.72234446,  -0.77389336,
      -0.28871772, -0.80428445, -0.82541436, 0.90456349,  -0.68856895,
      0.03877772,  0.76877952,  0.32561871,  -0.54546672, 0.39072621,
      0.07273290,  -0.01924866, 0.88924897,  -0.55283129, 0.67183107,
      0.08006320,  0.66944766,  -0.29068485, 0.82573754,  -0.57170743,
      -0.41581789, 0.72739530,  0.85025692,  -0.85931867, 0.87357593,
      0.07782833,  -0.84597743, -0.72748238, 0.45396307,  0.82449573,
      -0.42462519, -0.86363792, -0.89368379, -0.12621804, -0.86445558,
      -0.65565848, 0.88789743,  0.83566397,  0.78287679};

  // The tolerance is unusually loose here because tanh is approximated with a
  // rational interpolant.
  const ErrorSpec loose_error_spec(0.004, 0.004);
  ComputeAndCompareR1<float>(&builder, expected, {input_data.get()},
                             loose_error_spec);
}
2433
XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
  // The input tensor is large enough to exercise the vectorized exp
  // implementation on XLA CPU. Expected values come from std::exp applied to
  // each float element of the input literal.
  //
  // For a sense of the scales involved: exp(89) saturates float32 and
  // exp(-10) is smaller than our error spec.
  Literal exp_input = LiteralUtil::CreateR1<float>(
      {1.02,   -0.32,  0.85,   0.9,    1.23,   -0.91,  -0.49,  0.8,    -1.31,
       -1.44,  -0.13,  -1.31,  -0.79,  1.41,   1.21,   1.05,   -195.6, -194.5,
       -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6,  -18.5,  -17.4,
       -16.3,  -15.2,  -14.1,  -13.0,  -11.9,  -10.8,  -9.7,   -8.6,   -7.5,
       -6.4,   -5.3,   -4.2,   -3.1,   -2.0,   -0.9,   0.2,    1.3,    2.4,
       3.5,    4.6,    5.7,    6.8,    7.9,    9.0,    10.1,   11.2,   12.3,
       13.4,   14.5,   15.6,   16.7,   17.8,   18.9,   20.0,   21.1,   22.2,
       23.3,   24.4,   25.5,   26.6,   27.7,   28.8,   29.9,   31.0,   32.1,
       68.4,   69.5,   70.6,   71.7,   72.8,   73.9,   75.0,   76.1,   77.2,
       78.3,   79.4,   80.5,   81.6,   82.7,   83.8,   84.9,   85.2,   86.3,
       86.4,   86.5,   87.6,   87.7,   87.8,   87.9});

  // Host-side reference: exp of every element, in order.
  const int64_t element_count = exp_input.shape().dimensions(0);
  std::vector<float> reference;
  reference.reserve(element_count);
  for (int64_t idx = 0; idx < element_count; ++idx) {
    const float element = exp_input.Get<float>({idx});
    reference.push_back(std::exp(element));
  }

  TF_ASSERT_OK_AND_ASSIGN(auto device_input,
                          client_->TransferToServer(exp_input));

  XlaBuilder test_builder(TestName());
  auto param = Parameter(&test_builder, 0, exp_input.shape(), "input");
  Exp(param);

  ComputeAndCompareR1<float>(&test_builder, reference, {device_input.get()},
                             error_spec_);
}
2469
XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
  // Log over an input of more than one hundred elements, passed as a
  // parameter. Expected values come from std::log applied to each float
  // element of the input literal. The first eight inputs are negative, so the
  // host-side reference for them is whatever std::log returns for a negative
  // argument.
  Literal log_input = LiteralUtil::CreateR1<float>(
      {-1.29,    -1.41,    -1.25,    -13.5,    -11.7,    -17.9,    -198,
       -167,     1.29,     1.41,     1.25,     13.5,     11.7,     17.9,
       198,      167,      1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04,  1.84e+04,
       1.74e+04, 1.89e+05, 1.9e+05,  1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07,
       1.66e+07, 1e+07,    1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09,
       1.44e+10, 1.5e+10,  1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12,
       1.4e+12,  1.03e+13, 1.6e+13,  1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15,
       1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17,
       2e+18,    1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20,
       1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21,  1.35e+22, 1.84e+22, 1.02e+22,
       1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25,
       1.62e+25, 1.2e+26,  1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28,
       1.5e+28,  1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30,  1.81e+30, 1.34e+30,
       1.7e+31,  1.44e+31, 1.1e+31,  1.4e+32,  1.67e+32, 1.96e+33, 1.11e+33,
       1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});

  // Host-side reference: log of every element, in order.
  const int64_t element_count = log_input.shape().dimensions(0);
  std::vector<float> reference;
  reference.reserve(element_count);
  for (int64_t idx = 0; idx < element_count; ++idx) {
    const float element = log_input.Get<float>({idx});
    reference.push_back(std::log(element));
  }

  TF_ASSERT_OK_AND_ASSIGN(auto device_input,
                          client_->TransferToServer(log_input));

  XlaBuilder test_builder(TestName());
  auto param = Parameter(&test_builder, 0, log_input.shape(), "input");
  Log(param);

  ComputeAndCompareR1<float>(&test_builder, reference, {device_input.get()},
                             error_spec_);
}
2507
XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) {
  // Count-leading-zeros on 32-bit unsigned values: zero yields 32 and a value
  // with the top bit set yields 0.
  const std::vector<uint32> inputs = {0,        1,          0x10,      0x10000,
                                      0x700000, 0x12345678, 0xF2345678};
  const std::vector<uint32> leading_zeros = {32, 31, 27, 15, 9, 3, 0};

  XlaBuilder test_builder(TestName());
  auto operand = ConstantR1<uint32>(&test_builder, inputs);
  Clz(operand);

  ComputeAndCompareR1<uint32>(&test_builder, leading_zeros, {});
}
2516
XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) {
  // Count-leading-zeros on 64-bit signed values: zero yields 64 and -1 (all
  // bits set) yields 0.
  const std::vector<int64> inputs = {0, 1, 0x80000000, 0x7FFFFFFFF2345678ul,
                                     -1};
  const std::vector<int64> leading_zeros = {64, 63, 32, 1, 0};

  XlaBuilder test_builder(TestName());
  auto operand = ConstantR1<int64>(&test_builder, inputs);
  Clz(operand);

  ComputeAndCompareR1<int64>(&test_builder, leading_zeros, {});
}
2525
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
  // Left-nested chain of two adds: (a + b) + c.
  //
  //   a --.
  //        (add) --.
  //   b --'         (add)
  //   c -----------'
  XlaBuilder test_builder(TestName());

  auto first = ConstantR1<float>(&test_builder, {1.1f, 2.2f, 3.3f, 4.4f});
  auto second = ConstantR1<float>(&test_builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto third =
      ConstantR1<float>(&test_builder, {-3.3f, -15.5f, -7.7f, -29.9f});

  // The inner add is built before the outer one, which becomes the result.
  Add(Add(first, second), third);

  ComputeAndCompareR1<float>(&test_builder, {-0.1f, -10.1f, -0.1f, -20.1f},
                             {}, error_spec_);
}
2543
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
  // Right-nested chain of two adds: a + (b + c).
  //
  //   b --.
  //        (add) --.
  //   c --'         (add)
  //   a -----------'
  XlaBuilder test_builder(TestName());

  auto first = ConstantR1<float>(&test_builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto second = ConstantR1<float>(&test_builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto third =
      ConstantR1<float>(&test_builder, {-3.3f, -15.5f, -7.7f, -29.9f});

  // The inner add is built before the outer one, which becomes the result.
  Add(first, Add(second, third));

  ComputeAndCompareR1<float>(&test_builder, {89.9f, -10.1f, -0.1f, -20.1f},
                             {}, error_spec_);
}
2561
XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
  // Sum of two negated operands: (-a) + (-b).
  //
  //   a -- (neg) --.
  //                 (add)
  //   b -- (neg) --'
  XlaBuilder test_builder(TestName());

  auto lhs = ConstantR1<float>(&test_builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto rhs = ConstantR1<float>(&test_builder, {2.1f, 3.2f, 4.3f, 5.4f});

  auto negated_lhs = Neg(lhs);
  auto negated_rhs = Neg(rhs);
  Add(negated_lhs, negated_rhs);

  ComputeAndCompareR1<float>(&test_builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {},
                             error_spec_);
}
2578
XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
  // Balanced tree of three adds: (a + b) + (c + d).
  //
  //   a --.
  //        (add) --.
  //   b --'         |
  //                 (add)
  //   c --.         |
  //        (add) --'
  //   d --'
  XlaBuilder test_builder(TestName());

  auto first = ConstantR1<float>(&test_builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto second = ConstantR1<float>(&test_builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto third =
      ConstantR1<float>(&test_builder, {-3.3f, -15.5f, -7.7f, -29.9f});
  auto fourth =
      ConstantR1<float>(&test_builder, {-19.0f, 10.0f, -40.0f, 20.2f});

  // Named intermediates keep the two inner adds in a fixed build order.
  auto upper_sum = Add(first, second);
  auto lower_sum = Add(third, fourth);
  Add(upper_sum, lower_sum);

  ComputeAndCompareR1<float>(&test_builder, {70.9f, -0.1f, -40.1f, 0.1f}, {},
                             error_spec_);
}
2601
2602 XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
2603 XlaBuilder builder(TestName());
2604 auto a = ConstantR2<float>(&builder,
2605 {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2606 auto b = ConstantR2<float>(&builder,
2607 {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
2608 Add(a, b);
2609
2610 Array2D<float> expected_array(
2611 {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
2612 ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2613 }
2614
XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
  // scalar + matrix: the scalar is the left operand and is added to every
  // matrix element.
  XlaBuilder test_builder(TestName());
  auto matrix = ConstantR2<float>(
      &test_builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto three = ConstantR0<float>(&test_builder, 3.0f);
  Add(three, matrix);

  const Array2D<float> want({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
  ComputeAndCompareR2<float>(&test_builder, want, {}, error_spec_);
}
2626
2627 XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
2628 // Add a matrix + scalar.
2629 XlaBuilder builder(TestName());
2630 auto a = ConstantR2<float>(&builder,
2631 {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2632 auto scalar = ConstantR0<float>(&builder, 3.0f);
2633 Add(a, scalar);
2634
2635 Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
2636 ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2637 }
2638
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) {
  // Simple broadcast of an R1 F32 over an R2 F32. The vector has three
  // elements, which matches only the second extent of the 2x3 matrix; with
  // broadcast_dimensions={1} the vector is added to each matrix row.
  XlaBuilder test_builder(TestName());
  auto vec = ConstantR1<float>(&test_builder, {20.0f, 40.0f, 60.0f});
  // clang-format off
  auto mat = ConstantR2<float>(&test_builder, {
    {-2.5f, 3.14f, 1.0f},
    {2.25f, -10.0f, 3.33f}});
  // clang-format on
  Add(vec, mat, /*broadcast_dimensions=*/{1});

  const Array2D<float> want(
      {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}});
  ComputeAndCompareR2<float>(&test_builder, want, {}, error_spec_);
}
2654
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
  // Eq between a 2-element vector and a 2x2 matrix, exercising both possible
  // broadcast dimensions for a vector/matrix comparison. The two results are
  // returned together as a tuple.
  XlaBuilder test_builder(TestName());
  auto vec = ConstantR1<int32>(&test_builder, {42, 73});
  auto mat = ConstantR2<int32>(&test_builder, {{42, 73}, {42, 52}});

  // broadcast_dimensions={1}: vec[j] is compared with mat[i][j].
  auto eq_over_dim1 = Eq(vec, mat, /*broadcast_dimensions=*/{1});
  // broadcast_dimensions={0}: vec[i] is compared with mat[i][j].
  auto eq_over_dim0 = Eq(vec, mat, /*broadcast_dimensions=*/{0});
  Tuple(&test_builder, {eq_over_dim1, eq_over_dim0});

  const Literal want_over_dim1 =
      LiteralUtil::CreateR2<bool>({{true, true}, {true, false}});
  const Literal want_over_dim0 =
      LiteralUtil::CreateR2<bool>({{true, false}, {false, false}});
  auto want_tuple =
      LiteralUtil::MakeTupleFromSlices({want_over_dim1, want_over_dim0});
  ComputeAndCompareTuple(&test_builder, want_tuple, {}, error_spec_);
}
2672
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
  // Test broadcasting in Ne comparison: with broadcast_dimensions={1}, v[j]
  // is compared against m[i][j] for every row i.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {42, 73});
  auto m = ConstantR2<int32>(&builder, {{42, 73}, {42, 52}});
  Ne(v, m, /*broadcast_dimensions=*/{1});

  // Compare the result as a PRED literal (as Compare1DTo2DS32Eq does) rather
  // than matching the printed form of the result, so the test checks values
  // only and does not depend on how literals are rendered as text.
  const Literal expected =
      LiteralUtil::CreateR2<bool>({{false, false}, {false, true}});
  ComputeAndCompareLiteral(&builder, expected, {});
}
2686
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
  // Test broadcasting in Ge comparison: with broadcast_dimensions={1}, v[j]
  // is compared against m[i][j] for every row i.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Ge(v, m, /*broadcast_dimensions=*/{1});

  // Compare the result as a PRED literal (as Compare1DTo2DS32Eq does) rather
  // than matching the printed form of the result, so the test checks values
  // only and does not depend on how literals are rendered as text.
  const Literal expected = LiteralUtil::CreateR2<bool>(
      {{true, true, false, false}, {false, false, false, true}});
  ComputeAndCompareLiteral(&builder, expected, {});
}
2700
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
  // Test broadcasting in Gt comparison: with broadcast_dimensions={1}, v[j]
  // is compared against m[i][j] for every row i.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Gt(v, m, /*broadcast_dimensions=*/{1});

  // Compare the result as a PRED literal (as Compare1DTo2DS32Eq does) rather
  // than matching the printed form of the result, so the test checks values
  // only and does not depend on how literals are rendered as text.
  const Literal expected = LiteralUtil::CreateR2<bool>(
      {{false, true, false, false}, {false, false, false, false}});
  ComputeAndCompareLiteral(&builder, expected, {});
}
2714
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
  // Test broadcasting in Le comparison: with broadcast_dimensions={1}, v[j]
  // is compared against m[i][j] for every row i.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Le(v, m, /*broadcast_dimensions=*/{1});

  // Compare the result as a PRED literal (as Compare1DTo2DS32Eq does) rather
  // than matching the printed form of the result, so the test checks values
  // only and does not depend on how literals are rendered as text.
  const Literal expected = LiteralUtil::CreateR2<bool>(
      {{true, false, true, true}, {true, true, true, true}});
  ComputeAndCompareLiteral(&builder, expected, {});
}
2728
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
  // Test broadcasting in Lt comparison: with broadcast_dimensions={1}, v[j]
  // is compared against m[i][j] for every row i.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Lt(v, m, /*broadcast_dimensions=*/{1});

  // Compare the result as a PRED literal (as Compare1DTo2DS32Eq does) rather
  // than matching the printed form of the result, so the test checks values
  // only and does not depend on how literals are rendered as text.
  const Literal expected = LiteralUtil::CreateR2<bool>(
      {{false, false, true, true}, {true, true, true, false}});
  ComputeAndCompareLiteral(&builder, expected, {});
}
2742
XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
  // Simple broadcast of an R1 F32 over an R2 F32 with the matrix as the left
  // operand and the vector as the right operand of the binary op.
  XlaBuilder test_builder(TestName());
  auto mat = ConstantR2<float>(&test_builder,
                               {{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}});
  auto vec = ConstantR1<float>(&test_builder, {2.0f, 4.0f, 6.0f});
  Mul(mat, vec, /*broadcast_dimensions=*/{1});

  const Array2D<float> want({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}});
  ComputeAndCompareR2<float>(&test_builder, want, {}, error_spec_);
}
2754
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
  // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
  //
  // `full` is written as 2 rows of 3 values and `single_row` as 1 row of 3
  // values. No broadcast dimensions are given; the single row is added to
  // every row of `full`, giving a 2-row, 3-column result.
  XlaBuilder test_builder(TestName());
  auto full = ConstantR2<float>(
      &test_builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto single_row = ConstantR2<float>(&test_builder, {{10.0f, 20.0f, 30.0f}});
  Add(full, single_row);

  const Array2D<float> want(
      {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}});
  ComputeAndCompareR2<float>(&test_builder, want, {}, error_spec_);
}
2769
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) {
  // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
  //
  // `full` is written as 2 rows of 3 values and `single_column` as 2 rows of
  // 1 value. No broadcast dimensions are given; each row of `full` has its
  // own single value added to all of its entries, giving a 2-row, 3-column
  // result.
  XlaBuilder test_builder(TestName());
  auto full = ConstantR2<float>(
      &test_builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto single_column = ConstantR2<float>(&test_builder, {{10.0f}, {20.0f}});
  Add(full, single_column);

  const Array2D<float> want(
      {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}});
  ComputeAndCompareR2<float>(&test_builder, want, {}, error_spec_);
}
2784
XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) {
  // Tests broadcasting when both operands have a degenerate dimension, which
  // effectively creates an "outer product" style operation (here, an outer
  // sum). This is taken from the Numpy docs example at:
  // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html
  //
  // `column` is written as 4 rows of 1 value and `row` as 1 row of 3 values;
  // the result has 4 rows and 3 columns.
  XlaBuilder test_builder(TestName());
  auto column = ConstantR2<float>(&test_builder,
                                  {{0.0f}, {10.0f}, {20.0f}, {30.0f}});
  auto row = ConstantR2<float>(&test_builder, {{1.0f, 2.0f, 3.0f}});
  Add(column, row);

  const Array2D<float> want({{1.0f, 2.0f, 3.0f},
                             {11.0f, 12.0f, 13.0f},
                             {21.0f, 22.0f, 23.0f},
                             {31.0f, 32.0f, 33.0f}});
  ComputeAndCompareR2<float>(&test_builder, want, {}, error_spec_);
}
2803
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) {
  // Adds a (2) vector to a (2,2) matrix. Either matrix dimension could host
  // the vector; this test passes broadcast_dimensions={1}, so vec[j] is added
  // to mat[i][j].
  XlaBuilder test_builder(TestName());
  auto vec = ConstantR1<float>(&test_builder, {20.0f, 40.0f});
  auto mat =
      ConstantR2<float>(&test_builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
  Add(vec, mat, /*broadcast_dimensions=*/{1});

  const Array2D<float> want({{30.0f, 90.0f}, {97.0f, 128.0f}});
  ComputeAndCompareR2<float>(&test_builder, want, {}, error_spec_);
}
2814
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
  // Adds a (2) vector to a (2,2) matrix. Either matrix dimension could host
  // the vector; this test passes broadcast_dimensions={0}, so vec[i] is added
  // to mat[i][j].
  XlaBuilder test_builder(TestName());
  auto vec = ConstantR1<float>(&test_builder, {20.0f, 40.0f});
  auto mat =
      ConstantR2<float>(&test_builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
  Add(vec, mat, /*broadcast_dimensions=*/{0});

  const Array2D<float> want({{30.0f, 70.0f}, {117.0f, 128.0f}});
  ComputeAndCompareR2<float>(&test_builder, want, {}, error_spec_);
}
2825
2826 XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
2827 // Binary add of two R3s together
2828 XlaBuilder builder(TestName());
2829 Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
2830 {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
2831 auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
2832
2833 Array3D<float> b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}},
2834 {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}});
2835 auto b = ConstantR3FromArray3D<float>(&builder, b_3d);
2836 Add(a, b);
2837
2838 Array3D<float> expected_3d(
2839 {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}},
2840 {{21.0f, 24.0f}, {27.0f, 30.0f}, {33.0f, 36.0f}}});
2841 ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
2842 }
2843
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) {
  // Adds a (2) vector to a (2, 3, 2) array. Two of the array's dimensions
  // have size 2 and could host the vector; this test passes
  // broadcast_dimensions={2}, so vec[k] is added to tensor(i, j, k).
  XlaBuilder test_builder(TestName());
  // clang-format off
  Array3D<float> tensor_values({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  // clang-format on
  auto tensor = ConstantR3FromArray3D<float>(&test_builder, tensor_values);
  auto vec = ConstantR1<float>(&test_builder, {10.0f, 20.0f});
  Add(tensor, vec, /*broadcast_dimensions=*/{2});

  const Array3D<float> want(
      {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}},
       {{17.0f, 28.0f}, {19.0f, 30.0f}, {21.0f, 32.0f}}});
  ComputeAndCompareR3<float>(&test_builder, want, {}, error_spec_);
}
2867
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) {
  // Adds a (2) vector to a (2, 3, 2) array. Two of the array's dimensions
  // have size 2 and could host the vector; this test passes
  // broadcast_dimensions={0}, so vec[i] is added to tensor(i, j, k).
  XlaBuilder test_builder(TestName());
  // clang-format off
  Array3D<float> tensor_values({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  // clang-format on
  auto tensor = ConstantR3FromArray3D<float>(&test_builder, tensor_values);
  auto vec = ConstantR1<float>(&test_builder, {10.0f, 20.0f});
  Add(tensor, vec, /*broadcast_dimensions=*/{0});

  // clang-format off
  const Array3D<float> want({
    {{11.0f, 12.0f},
     {13.0f, 14.0f},
     {15.0f, 16.0f}},
    {{27.0f, 28.0f},
     {29.0f, 30.0f},
     {31.0f, 32.0f}},
  });
  // clang-format on
  ComputeAndCompareR3<float>(&test_builder, want, {}, error_spec_);
}
2898
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) {
  // Adds a matrix written as 2 rows of 3 values to a (2, 3, 2) array using
  // broadcast_dimensions={0, 1}: matrix(i, j) is added to both entries of
  // tensor(i, j, *).
  XlaBuilder test_builder(TestName());
  // clang-format off
  Array3D<float> tensor_values({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  auto tensor = ConstantR3FromArray3D<float>(&test_builder, tensor_values);
  auto matrix = ConstantR2<float>(&test_builder, {
    {10.0f, 20.0f, 30.0f},
    {40.0f, 50.0f, 60.0f},
  });
  Add(tensor, matrix, /*broadcast_dimensions=*/{0, 1});

  const Array3D<float> want({
    {{11.0f, 12.0f},
     {23.0f, 24.0f},
     {35.0f, 36.0f}},
    {{47.0f, 48.0f},
     {59.0f, 60.0f},
     {71.0f, 72.0f}},
  });
  // clang-format on
  ComputeAndCompareR3<float>(&test_builder, want, {}, error_spec_);
}
2930
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
  // Greater-than between two rank-3 arrays whose extents differ only in one
  // degenerate (size == 1) dimension: `a` is written as 2 blocks of 3x2
  // values and `b` as a single 3x2 block. No broadcast dimensions are given;
  // each block of `a` is compared against the one block of `b`, producing a
  // 2x3x2 array of PREDs.
  XlaBuilder builder(TestName());
  Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
                       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);

  Array3D<float> b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}});
  auto b = ConstantR3FromArray3D<float>(&builder, b_3d);

  Gt(a, b);

  // The previously unused Array3D<int> of expected values is gone; the
  // expectation is now a PRED literal that is compared element by element, so
  // the test does not depend on how literals are rendered as text.
  const Literal expected = LiteralUtil::CreateR3<bool>(
      {{{false, true}, {false, false}, {false, false}},
       {{false, true}, {true, false}, {false, true}}});
  ComputeAndCompareLiteral(&builder, expected, {});
}
2960
2961 XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
2962 XlaBuilder builder(TestName());
2963
2964 std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
2965 std::unique_ptr<Array4D<float>> operand_b_4d(new Array4D<float>(2, 3, 4, 5));
2966 std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5));
2967 float value = 0.0;
2968 for (int64_t p = 0; p < 2; ++p) {
2969 for (int64_t z = 0; z < 3; ++z) {
2970 for (int64_t y = 0; y < 4; ++y) {
2971 for (int64_t x = 0; x < 5; ++x) {
2972 (*operand_a_4d)(p, z, y, x) = value;
2973 (*operand_b_4d)(p, z, y, x) = 2.0 * value;
2974 (*expected_4d)(p, z, y, x) = 3.0 * value;
2975 value += 0.1;
2976 }
2977 }
2978 }
2979 }
2980
2981 auto a = ConstantR4FromArray4D<float>(&builder, *operand_a_4d);
2982 auto b = ConstantR4FromArray4D<float>(&builder, *operand_b_4d);
2983 Add(a, b);
2984
2985 ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
2986 }
2987
XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
  // Adds a 3-element vector {1, 2, 3} to a 2x3x4x5 array along dimension 1:
  // operand_b_1d[z] is added to every element whose second index is z.
  XlaBuilder builder(TestName());

  // Plain locals instead of heap objects behind unique_ptr: nothing here
  // needs shared or transferable ownership.
  Array4D<float> operand_a_4d(2, 3, 4, 5);
  Array4D<float> expected_4d(2, 3, 4, 5);
  std::vector<float> operand_b_1d(3);
  std::iota(operand_b_1d.begin(), operand_b_1d.end(), 1.0);

  // Operand a holds a running value that grows by 0.1 per element.
  float value = 0.0;
  for (int64_t p = 0; p < 2; ++p) {
    for (int64_t z = 0; z < 3; ++z) {
      for (int64_t y = 0; y < 4; ++y) {
        for (int64_t x = 0; x < 5; ++x) {
          operand_a_4d(p, z, y, x) = value;
          expected_4d(p, z, y, x) = value + operand_b_1d[z];
          value += 0.1;
        }
      }
    }
  }

  auto a = ConstantR4FromArray4D<float>(&builder, operand_a_4d);
  auto b = ConstantR1<float>(&builder, operand_b_1d);
  Add(a, b, {1});

  ComputeAndCompareR4<float>(&builder, expected_4d, {}, error_spec_);
}
3015
XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
  // Adds a 16-element vector {1, 2, ..., 16} along dimension 1 of a
  // 16x16x2x2 array of ones. The R4 constant is created from a literal with
  // an explicit {0, 1, 2, 3} layout.
  constexpr int kDim0 = 16;
  constexpr int kDim1 = 16;
  constexpr int kDim2 = 2;
  constexpr int kDim3 = 2;
  Array4D<float> values(kDim0, kDim1, kDim2, kDim3);
  values.Fill(1.0);
  std::vector<float> addend(kDim1);
  std::iota(addend.begin(), addend.end(), 1.0);

  XlaBuilder test_builder(TestName());
  Literal r4_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      values, LayoutUtil::MakeLayout({0, 1, 2, 3}));
  auto r4_operand = ConstantLiteral(&test_builder, r4_literal);
  auto r1_operand = ConstantR1<float>(&test_builder, addend);
  Add(r4_operand, r1_operand, {1});

  // The literal above was created before this point, so `values` can now be
  // updated in place to hold the expected sum.
  for (int d0 = 0; d0 < kDim0; ++d0) {
    for (int d1 = 0; d1 < kDim1; ++d1) {
      for (int d2 = 0; d2 < kDim2; ++d2) {
        for (int d3 = 0; d3 < kDim3; ++d3) {
          values(d0, d1, d2, d3) += addend[d1];
        }
      }
    }
  }
  ComputeAndCompareR4<float>(&test_builder, values, {}, error_spec_);
}
3044
3045 // Show that we can't add two opaques.
XLA_TEST_F(ArrayElementwiseOpTest,CannotAddOpaques)3046 XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
3047 XlaBuilder builder(TestName());
3048 auto shape = ShapeUtil::MakeOpaqueShape();
3049 auto x = Parameter(&builder, 0, shape, "x");
3050 Add(x, x);
3051 auto computation_status = builder.Build();
3052 ASSERT_FALSE(computation_status.ok());
3053 EXPECT_THAT(computation_status.status().ToString(),
3054 ::testing::ContainsRegex(
3055 "Expected array argument for lhs of binary operation"));
3056 }
3057
XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) {
  // Two operands of the same rank may be added with explicit broadcast
  // dimensions when those dimensions are the identity mapping {0, 1}; the
  // result is the plain element-wise sum.
  XlaBuilder test_builder(TestName());
  auto lhs = ConstantR2<float>(
      &test_builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto rhs = ConstantR2<float>(
      &test_builder, {{-1.5f, 8.14f, 42.0f}, {-1.0f, -4.0f, 5.55f}});
  Add(lhs, rhs, /*broadcast_dimensions=*/{0, 1});

  const Array2D<float> want(
      {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
  ComputeAndCompareR2<float>(&test_builder, want, {}, error_spec_);
}
3070
XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) {
  // Two operands of the same rank with broadcast dimensions {1, 0}, which is
  // not the identity mapping: building the computation must fail with a
  // message saying the dimensions must be the identity.
  XlaBuilder test_builder(TestName());
  auto lhs = ConstantR2<float>(
      &test_builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto rhs = ConstantR2<float>(
      &test_builder, {{-1.5f, 8.14f, 42.0f}, {-1.0f, -4.0f, 5.55f}});
  Add(lhs, rhs, /*broadcast_dimensions=*/{1, 0});

  auto build_result = test_builder.Build();
  ASSERT_FALSE(build_result.ok());
  EXPECT_THAT(build_result.status().error_message(),
              ::testing::ContainsRegex("must.*be the identity"));
}
3084
3085 // Regression test for b/31927799. "slice - y" is fused and requires implicit
3086 // broadcast.
XLA_TEST_F(ArrayElementwiseOpTest,ImplicitBroadcastInFusedExpressions)3087 XLA_TEST_F(ArrayElementwiseOpTest, ImplicitBroadcastInFusedExpressions) {
3088 XlaBuilder builder(TestName());
3089 auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
3090 auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
3091 auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
3092 auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
3093
3094 auto x = Parameter(&builder, 0, x_literal.shape(), "x");
3095 auto y = Parameter(&builder, 1, y_literal.shape(), "y");
3096 auto slice = Slice(x, {1}, {2}, {1});
3097 Sub(slice, y);
3098
3099 ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()},
3100 error_spec_);
3101 }
3102
3103 INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount,
3104 ArrayElementwiseOpTestParamCount,
3105 ::testing::Values(127, 128, 129, 17 * 4096));
3106
3107 } // namespace
3108 } // namespace xla
3109