• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cmath>
17 #include <limits>
18 #include <memory>
19 #include <numeric>
20 #include <vector>
21 
22 #include "absl/base/casts.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/array2d.h"
25 #include "tensorflow/compiler/xla/array3d.h"
26 #include "tensorflow/compiler/xla/array4d.h"
27 #include "tensorflow/compiler/xla/client/global_data.h"
28 #include "tensorflow/compiler/xla/client/local_client.h"
29 #include "tensorflow/compiler/xla/client/xla_builder.h"
30 #include "tensorflow/compiler/xla/layout_util.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/test.h"
34 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
35 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
36 #include "tensorflow/compiler/xla/tests/test_macros.h"
37 #include "tensorflow/compiler/xla/types.h"
38 #include "tensorflow/core/platform/types.h"
39 
40 namespace xla {
41 namespace {
42 
43 class ArrayElementwiseOpTest : public ClientLibraryTestBase {
44  public:
45   ErrorSpec error_spec_{0.0001, 0.0001};
46   ErrorSpec strict_error_spec_{3.6e-15, 3.6e-15};
47 };
48 
49 class ArrayElementwiseOpTestParamCount
50     : public ArrayElementwiseOpTest,
51       public ::testing::WithParamInterface<int> {};
52 
// Negating an empty F32 vector yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) {
  XlaBuilder builder(TestName());
  const XlaOp operand = ConstantR1<float>(&builder, {});
  Neg(operand);

  ComputeAndCompareR1<float>(&builder, /*expected=*/{}, /*arguments=*/{},
                             error_spec_);
}
60 
// Elementwise negation of an F32 vector containing both signs.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
  XlaBuilder builder(TestName());
  const std::vector<float> input = {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f};
  const XlaOp operand = ConstantR1<float>(&builder, input);
  Neg(operand);

  const std::vector<float> expected = {2.5f, -3.14f, -2.25f, 10.0f, -6.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
69 
// Elementwise negation of an F64 vector. No explicit expected values are
// given; ComputeAndCompare is run with the strict F64 tolerance.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF64) {
  XlaBuilder builder(TestName());
  const std::vector<double> input = {-2.5, 3.14, 2.25, -10.0, 6.0};
  const XlaOp operand = ConstantR1<double>(&builder, input);
  Neg(operand);

  ComputeAndCompare(&builder, {}, strict_error_spec_);
}
77 
// Negation of S32 values including both extremes. In C++, -INT32_MIN is
// undefined behavior; XLA does not specify that, and this test expects the
// result to wrap back to INT32_MIN.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();

  XlaBuilder builder(TestName());
  const XlaOp operand =
      ConstantR1<int32>(&builder, {-1, 0, 1, 324, kMin, kMax});
  Neg(operand);

  ComputeAndCompareR1<int32>(&builder, {1, 0, -1, -324, kMin, -kMax}, {});
}
93 
// Negating an empty C64 vector yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) {
  XlaBuilder builder(TestName());
  const XlaOp operand = ConstantR1<complex64>(&builder, {});
  Neg(operand);

  ComputeAndCompareR1<complex64>(&builder, /*expected=*/{}, /*arguments=*/{},
                                 error_spec_);
}
101 
// Elementwise negation of C64 values: both the real and imaginary parts flip
// sign.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) {
  XlaBuilder builder(TestName());
  const std::vector<complex64> input = {
      {-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}};
  const XlaOp operand = ConstantR1<complex64>(&builder, input);
  Neg(operand);

  const std::vector<complex64> expected = {
      {2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}};
  ComputeAndCompareR1<complex64>(&builder, expected, {}, error_spec_);
}
112 
// Negation of S64 values, including one that only has bits set in the upper
// half, INT64_MIN and INT64_MIN + 1.
// The expected values encode two overflow-related results:
//   -INT64_MIN wraps back to INT64_MIN (undefined behavior in C++, but
//   expected of XLA), and
//   -(INT64_MIN + 1) is INT64_MAX.
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) {
  constexpr int64 kMin = std::numeric_limits<int64>::min();
  constexpr int64 kMax = std::numeric_limits<int64>::max();

  XlaBuilder builder(TestName());
  auto a = ConstantR1<int64>(
      &builder, {-1, 1, 0, 0x12345678,
                 static_cast<int64>(0xffffffff12345678ULL), kMin, kMin + 1});
  Neg(a);

  // static_cast<int64>(0xffffffff12345678) == -0xedcba988, so its negation is
  // 0xedcba988.
  ComputeAndCompareR1<int64>(
      &builder, {1, -1, 0, -0x12345678, 0xedcba988LL, kMin, kMax}, {});
}
140 
// IsFinite on an empty F32 vector yields an empty PRED vector.
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp operand = ConstantR1<float>(&builder, {});
  IsFinite(operand);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
148 
// Integer Pow on S32. The cases are:
//   - 0^0 (expected 1);
//   - small positive and negative bases;
//   - 5^10;
//   - negative exponents, where 3^-100 is expected to be 0 and 1^-2 to be 1.
XLA_TEST_F(ArrayElementwiseOpTest, IntPow) {
  XlaBuilder builder(TestName());
  const std::vector<int32> bases = {0, 1, 2, 3, 4, 5, -1, -2, 3, 5, 3, 1};
  const std::vector<int32> exponents = {0, 3, 3, 3, 3,  3,
                                        2, 3, 2, 10, -100, -2};
  const XlaOp lhs = ConstantR1<int32>(&builder, bases);
  const XlaOp rhs = ConstantR1<int32>(&builder, exponents);
  Pow(lhs, rhs);

  const std::vector<int32> expected = {1,  1, 8, 27,      64, 125,
                                       1, -8, 9, 9765625, 0,  1};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
161 
// Integer Pow on S64 whose result (2^62) does not fit in 32 bits.
XLA_TEST_F(ArrayElementwiseOpTest, IntPowLarge) {
  XlaBuilder builder(TestName());
  const XlaOp base = ConstantR1<int64>(&builder, {2});
  const XlaOp exponent = ConstantR1<int64>(&builder, {62});
  Pow(base, exponent);

  const std::vector<int64> expected = {int64{1} << 62};  // 4611686018427387904
  ComputeAndCompareR1<int64>(&builder, expected, {});
}
172 
173 // A non-canonical quiet NaN value.
174 static const float kNonCanonicalNaN = absl::bit_cast<float>(0x7FD01234);
175 
// IsFinite on F32 scalars: the canonical NaN, a non-canonical NaN and both
// infinities are non-finite; zero is finite. Each case adds an IsFinite op to
// the same builder and then runs ComputeAndCompareR0 on it.
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) {
  const float inf = std::numeric_limits<float>::infinity();
  struct Case {
    float input;
    bool is_finite;
  };
  const Case cases[] = {
      {NAN, false}, {kNonCanonicalNaN, false}, {inf, false},
      {-inf, false}, {0.0f, true},
  };

  XlaBuilder builder(TestName());
  // Sanity check that the constant really is a NaN.
  EXPECT_TRUE(std::isnan(kNonCanonicalNaN));
  for (const Case& c : cases) {
    IsFinite(ConstantR0<float>(&builder, c.input));
    ComputeAndCompareR0<bool>(&builder, c.is_finite, {});
  }
}
195 
// IsFinite on an F32 vector mixing NaNs (canonical and not), infinities and
// ordinary values.
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) {
  const float inf = std::numeric_limits<float>::infinity();
  EXPECT_TRUE(std::isnan(kNonCanonicalNaN));

  XlaBuilder builder(TestName());
  const XlaOp operand = ConstantR1<float>(
      &builder, {NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf});
  IsFinite(operand);

  ComputeAndCompareR1<bool>(&builder,
                            {false, true, false, true, false, false}, {});
}
207 
// Elementwise addition of two constant F32 vectors.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs =
      ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
  const XlaOp rhs =
      ConstantR1<float>(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
  Add(lhs, rhs);

  const std::vector<float> expected = {97.5f, 6.27f, 5.0f, 0.5f, -993.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
217 
// Adding two empty F32 vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<float>(&builder, {});
  const XlaOp rhs = ConstantR1<float>(&builder, {});
  Add(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, /*expected=*/{}, /*arguments=*/{},
                             error_spec_);
}
226 
// Elementwise addition of two constant C64 vectors.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<complex64>(
      &builder, {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}});
  const XlaOp rhs = ConstantR1<complex64>(
      &builder, {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}});
  Add(lhs, rhs);

  // Every expected element is spelled as an explicit (real, imag) pair.
  const std::vector<complex64> expected = {
      {97.5f, 0.0f}, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}};
  ComputeAndCompareR1<complex64>(&builder, expected, {}, error_spec_);
}
239 
// Adding two empty C64 vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<complex64>(&builder, {});
  const XlaOp rhs = ConstantR1<complex64>(&builder, {});
  Add(lhs, rhs);

  ComputeAndCompareR1<complex64>(&builder, /*expected=*/{}, /*arguments=*/{},
                                 error_spec_);
}
248 
// Adds two U64 parameter vectors elementwise, covering carries across the
// 32-bit boundary, the top bit, and sums that wrap modulo 2^64. The expected
// values come from C++ unsigned arithmetic, which wraps by definition.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
  XlaBuilder b(TestName());

  std::vector<uint64> lhs{0xFFFFFFFF,
                          static_cast<uint64>(-1),
                          0,
                          0,
                          0x7FFFFFFFFFFFFFFFLL,
                          0x7FFFFFFFFFFFFFFLL,
                          0x8000000000000000ULL,
                          0x8000000000000000ULL,
                          1};
  Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  std::vector<uint64> rhs{1,
                          0x7FFFFFFFFFFFFFFLL,
                          0x7FFFFFFFFFFFFFFFLL,
                          0x8000000000000000ULL,
                          0,
                          static_cast<uint64>(-1),
                          0,
                          1,
                          0x8000000000000000ULL};
  Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  Add(lhs_param, rhs_param);

  // Index with size_t so the loop bound is not a signed/unsigned comparison.
  std::vector<uint64> expected(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) {
    expected[i] = lhs[i] + rhs[i];
  }

  ComputeAndCompareR1<uint64>(&b, expected, {lhs_data.get(), rhs_data.get()});
}
289 
// Subtracts two S64 parameter vectors elementwise. The pairs include
// INT64_MIN and INT64_MAX but every difference stays in range (e.g.
// -1 - INT64_MAX == INT64_MIN), so the expected values can be computed with
// plain C++ signed arithmetic.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
  XlaBuilder b(TestName());

  std::vector<int64> lhs{static_cast<int64>(0x8000000000000000LL),
                         static_cast<int64>(0x8000000000000000LL),
                         -1,
                         0x7FFFFFFFFFFFFFFLL,
                         0x7FFFFFFFFFFFFFFFLL,
                         1,
                         0,
                         -1};
  Literal lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  std::vector<int64> rhs{-1,
                         0,
                         static_cast<int64>(0x8000000000000000LL),
                         1,
                         0,
                         0x7FFFFFFFFFFFFFFLL,
                         0x7FFFFFFFFFFFFFFFLL,
                         0x7FFFFFFFFFFFFFFFLL};
  Literal rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  Sub(lhs_param, rhs_param);

  // Index with size_t so the loop bound is not a signed/unsigned comparison.
  std::vector<int64> expected(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) {
    expected[i] = lhs[i] - rhs[i];
  }

  ComputeAndCompareR1<int64>(&b, expected, {lhs_data.get(), rhs_data.get()});
}
328 
// Lt on U64 values that differ in the top bit (0x8000000000000000 vs
// 0x7FFFFFFFFFFFFFFF), which distinguishes an unsigned comparison from a
// signed one. The literals are passed to ComputeAndCompare as arguments and
// no explicit expected value is given.
XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) {
  XlaBuilder b(TestName());

  const std::vector<uint64> lhs_values = {0x8000000000000000ULL};
  Literal lhs_literal = LiteralUtil::CreateR1<uint64>(lhs_values);
  const XlaOp lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");

  const std::vector<uint64> rhs_values = {0x7FFFFFFFFFFFFFFFULL};
  Literal rhs_literal = LiteralUtil::CreateR1<uint64>(rhs_values);
  const XlaOp rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");

  Lt(lhs_param, rhs_param);

  ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)});
}
344 
// Adds F32 vectors of length GetParam() using every combination of
// constant and parameter operands that hold the same values a and b. Each of
// the four partial sums equals a + b, so the final result is 4 * (a + b).
TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
  const int count = GetParam();
  XlaBuilder builder(TestName());
  std::vector<float> a_values;
  std::vector<float> b_values;
  for (int i = 0; i < count; ++i) {
    a_values.push_back(i / static_cast<float>(count));
    b_values.push_back(2 * i / static_cast<float>(count + 2));
  }

  Literal a_literal = LiteralUtil::CreateR1<float>({a_values});
  std::unique_ptr<GlobalData> a_data =
      client_->TransferToServer(a_literal).ConsumeValueOrDie();
  auto a_constant = ConstantR1<float>(&builder, a_values);
  auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param");

  Literal b_literal = LiteralUtil::CreateR1<float>({b_values});
  std::unique_ptr<GlobalData> b_data =
      client_->TransferToServer(b_literal).ConsumeValueOrDie();
  // Shape b_param from its own literal (previously copied from a_literal).
  auto b_param = Parameter(&builder, 1, b_literal.shape(), "b_param");
  auto b_constant = ConstantR1<float>(&builder, b_values);

  auto sum1 = Add(a_constant, b_param);
  auto sum2 = Add(a_constant, b_constant);
  auto sum3 = Add(a_param, b_param);
  auto sum4 = Add(a_param, b_constant);

  auto sum = Add(sum1, sum2);
  sum = Add(sum, sum3);
  sum = Add(sum, sum4);

  std::vector<float> expected;
  for (int i = 0; i < count; ++i) {
    expected.push_back(4 * (a_values[i] + b_values[i]));
  }

  ComputeAndCompareR1<float>(&builder, expected, {a_data.get(), b_data.get()},
                             error_spec_);
}
384 
// Builds a deep chain of add/slice "diamonds" over two all-zero F32 vectors of
// length 30. The chain shrinks the vector by one element per level, so the
// final result is the single value 0.
XLA_TEST_F(ArrayElementwiseOpTest, DeeplyNestedAddWithSlices) {
  XlaBuilder builder(TestName());
  std::vector<float> values(30, 0.0);
  const int64_t num_values = static_cast<int64_t>(values.size());
  auto a_literal = LiteralUtil::CreateR1<float>(values);
  auto a = Parameter(&builder, 0, a_literal.shape(), "x");
  auto b_literal = LiteralUtil::CreateR1<float>(values);
  auto b = Parameter(&builder, 1, b_literal.shape(), "y");

  // Construct a sequence of diamond-shaped gadgets like this:
  //
  //       add
  //      /   \    (both edges are slices of the add above)
  //   slice  slice
  //      \   /
  //       add
  //
  // (No line of this diagram may end in a backslash: that would splice the
  // following source line into the // comment.)
  //
  // Each 'left' slice removes the last element, each 'right' slice removes the
  // first element. In this way, we index into the add with different
  // multi-dimensional index arrays, which defeats the caching we use to avoid
  // exponential compile time.
  std::function<XlaOp(int64_t)> generate_recursive =
      [&](int64_t slice_size) -> XlaOp {
    if (slice_size == num_values) {
      return Add(a, b);
    }
    XlaOp param = generate_recursive(slice_size + 1);
    auto slice1 = Slice(param, {0}, {slice_size}, {1});
    auto slice2 = Slice(param, {1}, {slice_size + 1}, {1});
    return Add(slice1, slice2);
  };
  generate_recursive(1);
  auto a_data = client_->TransferToServer(a_literal).ConsumeValueOrDie();
  auto b_data = client_->TransferToServer(b_literal).ConsumeValueOrDie();
  ComputeAndCompareR1<float>(&builder, {0.0}, {a_data.get(), b_data.get()});
}
420 
// Elementwise subtraction of two constant F32 vectors.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
  XlaBuilder builder(TestName());
  const XlaOp minuend =
      ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
  const XlaOp subtrahend =
      ConstantR1<float>(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
  Sub(minuend, subtrahend);

  const std::vector<float> expected = {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
430 
// Subtracting two empty F32 vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<float>(&builder, {});
  const XlaOp rhs = ConstantR1<float>(&builder, {});
  Sub(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, /*expected=*/{}, /*arguments=*/{},
                             error_spec_);
}
439 
// Elementwise subtraction of two constant S32 vectors.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
  XlaBuilder builder(TestName());
  const XlaOp minuend = ConstantR1<int32>(&builder, {-1, 0, 2, 1000000000});
  const XlaOp subtrahend = ConstantR1<int32>(&builder, {-1, 2, 1, -1});
  Sub(minuend, subtrahend);

  const std::vector<int32> expected = {0, -2, 1, 1000000001};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
448 
// Subtracting two empty S32 vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32>(&builder, {});
  const XlaOp rhs = ConstantR1<int32>(&builder, {});
  Sub(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, /*expected=*/{}, /*arguments=*/{});
}
457 
// Elementwise subtraction of two constant C64 vectors.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) {
  XlaBuilder builder(TestName());
  const std::vector<complex64> minuend_values = {
      {-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}};
  const std::vector<complex64> subtrahend_values = {
      {0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}};
  const XlaOp minuend = ConstantR1<complex64>(&builder, minuend_values);
  const XlaOp subtrahend = ConstantR1<complex64>(&builder, subtrahend_values);
  Sub(minuend, subtrahend);

  const std::vector<complex64> expected = {
      {-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}};
  ComputeAndCompareR1<complex64>(&builder, expected, {}, error_spec_);
}
470 
// Subtracting two empty C64 vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<complex64>(&builder, {});
  const XlaOp rhs = ConstantR1<complex64>(&builder, {});
  Sub(lhs, rhs);

  ComputeAndCompareR1<complex64>(&builder, /*expected=*/{}, /*arguments=*/{},
                                 error_spec_);
}
479 
// Elementwise subtraction of two constant F64 vectors, checked by
// ComputeAndCompare (no explicit expected values) with the strict tolerance.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF64s) {
  XlaBuilder builder(TestName());
  const std::vector<double> minuend_values = {-2.5, 3.14, 2.25, -10.0, 6.0};
  const std::vector<double> subtrahend_values = {100.0, 3.13, 2.75, 10.5,
                                                 -999.0};
  const XlaOp minuend = ConstantR1<double>(&builder, minuend_values);
  const XlaOp subtrahend = ConstantR1<double>(&builder, subtrahend_values);
  Sub(minuend, subtrahend);

  ComputeAndCompare(&builder, {}, strict_error_spec_);
}
488 
// Elementwise division of two constant F32 vectors.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
  XlaBuilder builder(TestName());
  const XlaOp numerator =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  const XlaOp denominator =
      ConstantR1<float>(&builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f});
  Div(numerator, denominator);

  const std::vector<float> expected = {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
498 
// Dividing two empty F32 vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<float>(&builder, {});
  const XlaOp rhs = ConstantR1<float>(&builder, {});
  Div(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, /*expected=*/{}, /*arguments=*/{},
                             error_spec_);
}
507 
// Elementwise division of two constant F64 vectors, checked by
// ComputeAndCompare (no explicit expected values) with the strict tolerance.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF64s) {
  // M_PI is not part of standard C++ (<cmath> only provides it as an
  // extension, e.g. not on MSVC without _USE_MATH_DEFINES), so spell pi out.
  constexpr double kPi = 3.14159265358979323846;

  XlaBuilder builder(TestName());
  auto a = ConstantR1<double>(
      &builder, {-2.5, 25.5, 2.25, -10.0, 6.0, 1.0, 2.0, 3.2, -4.0, 0.45, 5.7,
                 0.1, 1.0, 2.0, 0.5, -1.0, -0.5, 1.0});
  auto b = ConstantR1<double>(
      &builder, {10.0, 5.1, 1.0, 10.0, -6.0, 0.1, 1.0, 2.0, 0.5, -1.0, -0.5,
                 2.1, 3.1, 9.9, -4.5, -11.0, -21.5, kPi});
  Div(a, b);

  ComputeAndCompare(&builder, {}, strict_error_spec_);
}
520 
521 class IntegerDivideOpTest : public ArrayElementwiseOpTest {
522  protected:
523   template <typename T>
TestDivRem(absl::Span<const T> dividends,absl::Span<const T> divisors,absl::Span<const T> quotients,absl::Span<const T> remainders)524   void TestDivRem(absl::Span<const T> dividends, absl::Span<const T> divisors,
525                   absl::Span<const T> quotients,
526                   absl::Span<const T> remainders) {
527     {
528       XlaBuilder builder(TestName());
529       XlaOp dividend;
530       XlaOp divisor;
531       auto dividend_data =
532           CreateR1Parameter<T>(dividends, 0, "dividend", &builder, &dividend);
533       auto divisor_data =
534           CreateR1Parameter<T>(divisors, 1, "divisor", &builder, &divisor);
535       Div(dividend, divisor);
536 
537       ComputeAndCompareR1<T>(&builder, quotients,
538                              {dividend_data.get(), divisor_data.get()});
539     }
540 
541     // Test with a compile-time constant divisor.
542     {
543       XlaBuilder builder(TestName());
544       XlaOp dividend;
545       auto dividend_data =
546           CreateR1Parameter<T>(dividends, 0, "dividend", &builder, &dividend);
547       Div(dividend, ConstantR1<T>(&builder, divisors));
548 
549       ComputeAndCompareR1<T>(&builder, quotients, {dividend_data.get()});
550     }
551 
552     {
553       XlaBuilder builder(TestName());
554       XlaOp dividend;
555       XlaOp divisor;
556       auto dividend_data =
557           CreateR1Parameter<T>(dividends, 0, "dividend", &builder, &dividend);
558       auto divisor_data =
559           CreateR1Parameter<T>(divisors, 1, "divisor", &builder, &divisor);
560       Rem(dividend, divisor);
561 
562       ComputeAndCompareR1<T>(&builder, remainders,
563                              {dividend_data.get(), divisor_data.get()});
564     }
565 
566     // Test with a compile-time constant divisor.
567     {
568       XlaBuilder builder(TestName());
569       XlaOp dividend;
570       auto dividend_data =
571           CreateR1Parameter<T>(dividends, 0, "dividend", &builder, &dividend);
572       Rem(dividend, ConstantR1<T>(&builder, divisors));
573 
574       ComputeAndCompareR1<T>(&builder, remainders, {dividend_data.get()});
575     }
576   }
577 };
578 
// Div/Rem on S32 for every (dividend, divisor) pair drawn from a set of
// boundary and ordinary values. Expected results come from C++ '/' and '%'.
XLA_TEST_F(IntegerDivideOpTest, DivS32s) {
  // clang-format off
  // Some interesting values to test.
  const std::vector<int32> vals = {
    INT32_MIN, INT32_MIN + 1, INT32_MIN + 2, -0x40000000, -0x3fffffff,
    -271181, -1309, -17, -10, -5, -3, -2, -1, 0, 1, 2, 3, 5, 10, 17, 26, 101,
    7919, 0x40000000, INT32_MAX - 2, INT32_MAX - 1, INT32_MAX};
  // clang-format on

  std::vector<int32> dividends, divisors, quotients, remainders;
  for (const int32_t divisor : vals) {
    // Division by zero is undefined in C++; it is tested by SignedOverflow.
    if (divisor == 0) continue;
    for (const int32_t dividend : vals) {
      // INT32_MIN / -1 overflows in C++; it is tested by SignedOverflow.
      if (dividend == INT32_MIN && divisor == -1) continue;
      dividends.push_back(dividend);
      divisors.push_back(divisor);
      quotients.push_back(dividend / divisor);
      remainders.push_back(dividend % divisor);
    }
  }

  TestDivRem<int32>(dividends, divisors, quotients, remainders);
}
605 
// Signed division cases that are undefined in C++. The expected values pin:
//   5 / 0 == -1 and 5 % 0 == 5;
//   INT32_MIN / -1 == INT32_MIN and INT32_MIN % -1 == 0.
XLA_TEST_F(IntegerDivideOpTest, SignedOverflow) {
  const std::vector<int32> dividends = {5, INT32_MIN};
  const std::vector<int32> divisors = {0, -1};
  const std::vector<int32> quotients = {-1, INT32_MIN};
  const std::vector<int32> remainders = {5, 0};

  TestDivRem<int32>(dividends, divisors, quotients, remainders);
}
612 
// Div/Rem on U32 for every (dividend, divisor) pair drawn from a set of
// boundary and ordinary values, skipping zero divisors. Expected results come
// from C++ unsigned '/' and '%'.
XLA_TEST_F(IntegerDivideOpTest, DivU32s) {
  // clang-format off
  // Some interesting values to test.
  const std::vector<uint32> vals = {
    0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0xABCDEF12, 0xCAFEBEEF, 0x80000000,
    0x80000001, UINT32_MAX - 2, UINT32_MAX - 1, UINT32_MAX};
  // clang-format on

  std::vector<uint32> dividends, divisors, quotients, remainders;
  for (const uint32 divisor : vals) {
    if (divisor == 0) continue;  // Division by zero is undefined in C++.
    for (const uint32 dividend : vals) {
      dividends.push_back(dividend);
      divisors.push_back(divisor);
      quotients.push_back(dividend / divisor);
      remainders.push_back(dividend % divisor);
    }
  }

  TestDivRem<uint32>(dividends, divisors, quotients, remainders);
}
635 
// Unsigned division by zero, mirroring SignedOverflow: the quotient is
// expected to be all ones (the unsigned bit pattern of -1) and the remainder
// to equal the dividend. This must instantiate TestDivRem with an unsigned
// type; using int32 would only repeat the signed 5 / 0 case.
XLA_TEST_F(IntegerDivideOpTest, UnsignedOverflow) {
  std::vector<uint32> dividends = {5}, divisors = {0},
                      quotients = {std::numeric_limits<uint32>::max()},
                      remainders = {5};

  TestDivRem<uint32>(dividends, divisors, quotients, remainders);
}
642 
// Elementwise complex division, covering a real divisor, a purely imaginary
// divisor and division of a value by itself.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) {
  XlaBuilder builder(TestName());
  const std::vector<complex64> numerators = {
      {-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}};
  const std::vector<complex64> denominators = {
      {10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}};
  const XlaOp numerator = ConstantR1<complex64>(&builder, numerators);
  const XlaOp denominator = ConstantR1<complex64>(&builder, denominators);
  Div(numerator, denominator);

  const std::vector<complex64> expected = {
      {-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}};
  ComputeAndCompareR1<complex64>(&builder, expected, {}, error_spec_);
}
654 
// Dividing two empty C64 vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<complex64>(&builder, {});
  const XlaOp rhs = ConstantR1<complex64>(&builder, {});
  Div(lhs, rhs);

  ComputeAndCompareR1<complex64>(&builder, /*expected=*/{}, /*arguments=*/{},
                                 error_spec_);
}
663 
// F32 remainder. As the expected values show, the result takes the sign of
// the dividend (e.g. 3 rem -2 == 1, -8 rem -4 == -0).
XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) {
  XlaBuilder builder(TestName());
  const std::vector<float> dividends = {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f,
                                        3.0f,  3.0f,  -1.0f, -8.0f};
  const std::vector<float> divisors = {10.0f, 5.1f, 1.0f, 10.0f, -6.0f,
                                       2.0f,  -2.0f, 7.0f, -4.0f};
  const XlaOp dividend = ConstantR1<float>(&builder, dividends);
  const XlaOp divisor = ConstantR1<float>(&builder, divisors);
  Rem(dividend, divisor);

  const std::vector<float> expected = {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f,
                                       1.0f,  1.0f, -1.0f, -0.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
676 
// Remainder of two empty F32 vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<float>(&builder, {});
  const XlaOp rhs = ConstantR1<float>(&builder, {});
  Rem(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, /*expected=*/{}, /*arguments=*/{},
                             error_spec_);
}
685 
// F64 remainder on the same inputs as RemF32s, compared with the strict
// tolerance.
XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) {
  XlaBuilder builder(TestName());
  const std::vector<double> dividends = {-2.5, 25.5, 2.25, -10.0, 6.0,
                                         3.0,  3.0,  -1.0, -8.0};
  const std::vector<double> divisors = {10.0, 5.1,  1.0, 10.0, -6.0,
                                        2.0,  -2.0, 7.0, -4.0};
  const XlaOp dividend = ConstantR1<double>(&builder, dividends);
  const XlaOp divisor = ConstantR1<double>(&builder, divisors);
  Rem(dividend, divisor);

  const std::vector<double> expected = {-2.5, 0.0, 0.25, 0.0, -0.0,
                                        1.0,  1.0, -1.0, -0.0};
  ComputeAndCompareR1<double>(&builder, expected, {}, strict_error_spec_);
}
698 
// Elementwise multiplication of two constant F32 vectors.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  const XlaOp rhs =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Mul(lhs, rhs);

  const std::vector<float> expected = {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
708 
// Multiplying two empty F32 vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<float>(&builder, {});
  const XlaOp rhs = ConstantR1<float>(&builder, {});
  Mul(lhs, rhs);

  ComputeAndCompareR1<float>(&builder, /*expected=*/{}, /*arguments=*/{},
                             error_spec_);
}
717 
// S32 multiplication over every ordered pair drawn from `data`, including
// products that overflow (e.g. INT32_MAX * INT32_MAX). Expected products are
// computed in uint32 so they wrap modulo 2^32 instead of invoking C++ signed
// overflow.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) {
  const std::vector<int32> data = {0,
                                   1,
                                   -1,
                                   1234,
                                   0x1a243514,
                                   std::numeric_limits<int32>::max(),
                                   std::numeric_limits<int32>::min()};

  std::vector<int32> a_data, b_data, expected;
  for (const int32_t lhs : data) {
    for (const int32_t rhs : data) {
      a_data.push_back(lhs);
      b_data.push_back(rhs);
      const uint32 wrapped_product =
          static_cast<uint32>(lhs) * static_cast<uint32>(rhs);
      expected.push_back(wrapped_product);
    }
  }

  XlaBuilder builder(TestName());
  const XlaOp a = ConstantR1<int32>(&builder, a_data);
  const XlaOp b = ConstantR1<int32>(&builder, b_data);
  Mul(a, b);

  ComputeAndCompareR1<int32>(&builder, expected, {});
}
743 
// Multiplying two empty S32 vectors yields an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) {
  XlaBuilder builder(TestName());
  const XlaOp lhs = ConstantR1<int32>(&builder, {});
  const XlaOp rhs = ConstantR1<int32>(&builder, {});
  Mul(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, /*expected=*/{}, /*arguments=*/{});
}
752 
// U32 multiplication over every ordered pair drawn from `data`. Expected
// products come from C++ unsigned multiplication, which wraps modulo 2^32.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
  const std::vector<uint32> data = {0,          1,          0xDEADBEEF, 1234,
                                    0x1a243514, 0xFFFFFFFF, 0x80808080};

  std::vector<uint32> a_data, b_data, expected;
  for (const uint32 lhs : data) {
    for (const uint32 rhs : data) {
      a_data.push_back(lhs);
      b_data.push_back(rhs);
      expected.push_back(lhs * rhs);
    }
  }

  XlaBuilder builder(TestName());
  const XlaOp a = ConstantR1<uint32>(&builder, a_data);
  const XlaOp b = ConstantR1<uint32>(&builder, b_data);
  Mul(a, b);

  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
774 
// Element-wise complex multiplication: (a+bi)(c+di) = (ac-bd) + (ad+bc)i.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<complex64>(
      &builder, {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}});
  auto b = ConstantR1<complex64>(&builder,
                                 {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}});
  Mul(a, b);

  // All components are float literals, matching complex64's float parts.
  ComputeAndCompareR1<complex64>(
      &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0f}}, {},
      error_spec_);
}
787 
// Multiplying two zero-element C64 vectors must yield a zero-element result.
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) {
  const std::vector<complex64> empty;

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<complex64>(&builder, empty);
  auto rhs = ConstantR1<complex64>(&builder, empty);
  Mul(lhs, rhs);

  ComputeAndCompareR1<complex64>(&builder, empty, {}, error_spec_);
}
796 
// Truth table of logical AND over rank-1 PRED operands.
XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) {
  XlaBuilder builder(TestName());
  // Together the operands enumerate all four input combinations.
  auto lhs = ConstantR1<bool>(&builder, {false, false, true, true});
  auto rhs = ConstantR1<bool>(&builder, {false, true, false, true});
  And(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{false, false, false, true},
                            /*arguments=*/{});
}
805 
// Truth table of logical AND laid out as a 2x2 PRED matrix.
XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) {
  const Array2D<bool> want({{false, false}, {false, true}});

  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  auto rhs = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  And(lhs, rhs);

  ComputeAndCompareR2<bool>(&builder, want, {});
}
815 
// AND of two zero-element PRED vectors must yield a zero-element result.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<bool>(&builder, {});
  auto rhs = ConstantR1<bool>(&builder, {});
  And(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
824 
// Bitwise AND on S32, including negative (two's-complement) operands.
XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) {
  const std::vector<int32> lhs_values = {0, -1, -8};
  const std::vector<int32> rhs_values = {5, -7, 12};
  const std::vector<int32> want = {0, -7, 8};

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<int32>(&builder, lhs_values);
  auto rhs = ConstantR1<int32>(&builder, rhs_values);
  And(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, want, {});
}
833 
// Bitwise AND on a 2x2 S32 matrix with mixed-sign operands.
XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) {
  const Array2D<int32> want({{0, -6}, {4, 5}});

  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<int32>(&builder, {{0, -5}, {-1, 5}});
  auto rhs = ConstantR2<int32>(&builder, {{1, -6}, {4, 5}});
  And(lhs, rhs);

  ComputeAndCompareR2<int32>(&builder, want, {});
}
843 
// AND of two zero-element S32 vectors must yield a zero-element result.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) {
  const std::vector<int32> empty;

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<int32>(&builder, empty);
  auto rhs = ConstantR1<int32>(&builder, empty);
  And(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, empty, {});
}
852 
// Bitwise AND on rank-1 U32 operands. The operands must be uint32 (not int32)
// for this test to exercise the unsigned path its name refers to. The last
// element uses values above INT32_MAX so the high bit is covered.
XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<uint32>(&builder, {0, 1, 8, 0xFFFFFFFFu});
  auto b = ConstantR1<uint32>(&builder, {5, 7, 12, 0x80000001u});
  And(a, b);

  ComputeAndCompareR1<uint32>(&builder, {0, 1, 8, 0x80000001u}, {});
}
861 
// Bitwise AND on a 2x2 U32 matrix.
XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) {
  const Array2D<uint32> want({{0, 0}, {3, 0}});

  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<uint32>(&builder, {{0, 1}, {3, 8}});
  auto rhs = ConstantR2<uint32>(&builder, {{1, 0}, {7, 6}});
  And(lhs, rhs);

  ComputeAndCompareR2<uint32>(&builder, want, {});
}
871 
// AND of two zero-element U32 vectors must yield a zero-element result.
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) {
  const std::vector<uint32> empty;

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<uint32>(&builder, empty);
  auto rhs = ConstantR1<uint32>(&builder, empty);
  And(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, empty, {});
}
880 
// Truth table of logical OR over rank-1 PRED operands.
XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) {
  XlaBuilder builder(TestName());
  // Together the operands enumerate all four input combinations.
  auto lhs = ConstantR1<bool>(&builder, {false, false, true, true});
  auto rhs = ConstantR1<bool>(&builder, {false, true, false, true});
  Or(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{false, true, true, true},
                            /*arguments=*/{});
}
889 
// Truth table of logical OR laid out as a 2x2 PRED matrix.
XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) {
  const Array2D<bool> want({{false, true}, {true, true}});

  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  auto rhs = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  Or(lhs, rhs);

  ComputeAndCompareR2<bool>(&builder, want, {});
}
899 
// OR of two zero-element PRED vectors must yield a zero-element result.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<bool>(&builder, {});
  auto rhs = ConstantR1<bool>(&builder, {});
  Or(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
908 
// Bitwise OR on S32, including negative (two's-complement) operands.
XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) {
  const std::vector<int32> lhs_values = {0, -1, 8};
  const std::vector<int32> rhs_values = {5, -7, 4};
  const std::vector<int32> want = {5, -1, 12};

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<int32>(&builder, lhs_values);
  auto rhs = ConstantR1<int32>(&builder, rhs_values);
  Or(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, want, {});
}
917 
// Bitwise OR on a 2x2 S32 matrix with mixed-sign operands.
XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) {
  const Array2D<int32> want({{5, -1}, {12, 9}});

  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<int32>(&builder, {{0, -1}, {8, 8}});
  auto rhs = ConstantR2<int32>(&builder, {{5, -7}, {4, 1}});
  Or(lhs, rhs);

  ComputeAndCompareR2<int32>(&builder, want, {});
}
927 
// OR of two zero-element S32 vectors must yield a zero-element result.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) {
  const std::vector<int32> empty;

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<int32>(&builder, empty);
  auto rhs = ConstantR1<int32>(&builder, empty);
  Or(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, empty, {});
}
936 
// Bitwise OR on rank-1 U32 operands.
XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) {
  const std::vector<uint32> lhs_values = {0, 1, 8};
  const std::vector<uint32> rhs_values = {5, 7, 4};
  const std::vector<uint32> want = {5, 7, 12};

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<uint32>(&builder, lhs_values);
  auto rhs = ConstantR1<uint32>(&builder, rhs_values);
  Or(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, want, {});
}
945 
// Bitwise OR on a 2x2 U32 matrix.
XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) {
  const Array2D<uint32> want({{5, 7}, {12, 9}});

  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<uint32>(&builder, {{0, 1}, {8, 8}});
  auto rhs = ConstantR2<uint32>(&builder, {{5, 7}, {4, 1}});
  Or(lhs, rhs);

  ComputeAndCompareR2<uint32>(&builder, want, {});
}
955 
// OR of two zero-element U32 vectors must yield a zero-element result.
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) {
  const std::vector<uint32> empty;

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<uint32>(&builder, empty);
  auto rhs = ConstantR1<uint32>(&builder, empty);
  Or(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, empty, {});
}
964 
// Truth table of logical XOR over rank-1 PRED operands.
XLA_TEST_F(ArrayElementwiseOpTest, XorPredR1) {
  XlaBuilder builder(TestName());
  // Together the operands enumerate all four input combinations.
  auto lhs = ConstantR1<bool>(&builder, {false, false, true, true});
  auto rhs = ConstantR1<bool>(&builder, {false, true, false, true});
  Xor(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{false, true, true, false},
                            /*arguments=*/{});
}
973 
// Truth table of logical XOR laid out as a 2x2 PRED matrix.
XLA_TEST_F(ArrayElementwiseOpTest, XorPredR2) {
  const Array2D<bool> want({{false, true}, {true, false}});

  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
  auto rhs = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
  Xor(lhs, rhs);

  ComputeAndCompareR2<bool>(&builder, want, {});
}
983 
// XOR of two zero-element PRED vectors must yield a zero-element result.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementPredR1) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<bool>(&builder, {});
  auto rhs = ConstantR1<bool>(&builder, {});
  Xor(lhs, rhs);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
992 
// Bitwise XOR on S32, including negative (two's-complement) operands.
XLA_TEST_F(ArrayElementwiseOpTest, XorS32R1) {
  const std::vector<int32> lhs_values = {0, -1, 8};
  const std::vector<int32> rhs_values = {5, -7, 4};
  const std::vector<int32> want = {5, 6, 12};

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<int32>(&builder, lhs_values);
  auto rhs = ConstantR1<int32>(&builder, rhs_values);
  Xor(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, want, {});
}
1001 
// Bitwise XOR on a 2x2 S32 matrix with mixed-sign operands.
XLA_TEST_F(ArrayElementwiseOpTest, XorS32R2) {
  const Array2D<int32> want({{5, 6}, {12, 9}});

  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<int32>(&builder, {{0, -1}, {8, 8}});
  auto rhs = ConstantR2<int32>(&builder, {{5, -7}, {4, 1}});
  Xor(lhs, rhs);

  ComputeAndCompareR2<int32>(&builder, want, {});
}
1011 
// XOR of two zero-element S32 vectors must yield a zero-element result.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementS32R1) {
  const std::vector<int32> empty;

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<int32>(&builder, empty);
  auto rhs = ConstantR1<int32>(&builder, empty);
  Xor(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, empty, {});
}
1020 
// Bitwise XOR on rank-1 U32 operands.
XLA_TEST_F(ArrayElementwiseOpTest, XorU32R1) {
  const std::vector<uint32> lhs_values = {0, 1, 8};
  const std::vector<uint32> rhs_values = {5, 7, 4};
  const std::vector<uint32> want = {5, 6, 12};

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<uint32>(&builder, lhs_values);
  auto rhs = ConstantR1<uint32>(&builder, rhs_values);
  Xor(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, want, {});
}
1029 
// Bitwise XOR on a 2x2 U32 matrix.
XLA_TEST_F(ArrayElementwiseOpTest, XorU32R2) {
  const Array2D<uint32> want({{5, 6}, {12, 9}});

  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<uint32>(&builder, {{0, 1}, {8, 8}});
  auto rhs = ConstantR2<uint32>(&builder, {{5, 7}, {4, 1}});
  Xor(lhs, rhs);

  ComputeAndCompareR2<uint32>(&builder, want, {});
}
1039 
// XOR of two zero-element U32 vectors must yield a zero-element result.
XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementU32R1) {
  const std::vector<uint32> empty;

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<uint32>(&builder, empty);
  auto rhs = ConstantR1<uint32>(&builder, empty);
  Xor(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, empty, {});
}
// Logical NOT flips every element of a rank-1 PRED vector.
XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<bool>(&builder, {false, true, true, false});
  Not(operand);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{true, false, false, true},
                            /*arguments=*/{});
}
1055 
// Logical NOT flips every element of a 2x2 PRED matrix.
XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) {
  const Array2D<bool> want({{true, false}, {false, true}});

  XlaBuilder builder(TestName());
  auto operand = ConstantR2<bool>(&builder, {{false, true}, {true, false}});
  Not(operand);

  ComputeAndCompareR2<bool>(&builder, want, {});
}
1064 
// NOT of a zero-element PRED vector must yield a zero-element result.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) {
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<bool>(&builder, {});
  Not(operand);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1072 
// Bitwise NOT on S32; in two's complement ~x == -x - 1.
XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) {
  const std::vector<int32> operand_values = {-1, 0, 1};
  const std::vector<int32> complement = {0, -1, -2};

  XlaBuilder builder(TestName());
  auto operand = ConstantR1<int32>(&builder, operand_values);
  Not(operand);

  ComputeAndCompareR1<int32>(&builder, complement, {});
}
1080 
// Bitwise NOT on a 2x2 S32 matrix; in two's complement ~x == -x - 1.
XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) {
  const Array2D<int32> want({{0, -1}, {-2, -9}});

  XlaBuilder builder(TestName());
  auto operand = ConstantR2<int32>(&builder, {{-1, 0}, {1, 8}});
  Not(operand);

  ComputeAndCompareR2<int32>(&builder, want, {});
}
1089 
// NOT of a zero-element S32 vector must yield a zero-element result.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) {
  const std::vector<int32> empty;

  XlaBuilder builder(TestName());
  auto operand = ConstantR1<int32>(&builder, empty);
  Not(operand);

  ComputeAndCompareR1<int32>(&builder, empty, {});
}
1097 
// Bitwise NOT on U32 maps 0 to all-ones and all-ones to 0.
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) {
  const std::vector<uint32> operand_values = {0, 0xFFFFFFFFu};
  const std::vector<uint32> complement = {0xFFFFFFFFu, 0};

  XlaBuilder builder(TestName());
  auto operand = ConstantR1<uint32>(&builder, operand_values);
  Not(operand);

  ComputeAndCompareR1<uint32>(&builder, complement, {});
}
1105 
// Bitwise NOT on a 2x2 U32 matrix near both ends of the unsigned range.
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) {
  const Array2D<uint32> want({{0xFFFFFFFFu, 0}, {0xFFFFFFFEu, 1}});

  XlaBuilder builder(TestName());
  auto operand =
      ConstantR2<uint32>(&builder, {{0, 0xFFFFFFFFu}, {1, 0xFFFFFFFEu}});
  Not(operand);

  ComputeAndCompareR2<uint32>(&builder, want, {});
}
1114 
// NOT of a zero-element U32 vector must yield a zero-element result.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
  const std::vector<uint32> empty;

  XlaBuilder builder(TestName());
  auto operand = ConstantR1<uint32>(&builder, empty);
  Not(operand);

  ComputeAndCompareR1<uint32>(&builder, empty, {});
}
1122 
// PopulationCount on S32 counts set bits; -15 (0xFFFFFFF1) has 29 of them.
XLA_TEST_F(ArrayElementwiseOpTest, PopcntR1) {
  const std::vector<int32> operand_values = {0, 1, -15, 341};
  const std::vector<int32> bit_counts = {0, 1, 29, 5};

  XlaBuilder builder(TestName());
  auto operand = ConstantR1<int32>(&builder, operand_values);
  PopulationCount(operand);
  ComputeAndCompareR1<int32>(&builder, bit_counts, {});
}
1129 
// PopulationCount on a 2x2 S32 matrix.
XLA_TEST_F(ArrayElementwiseOpTest, PopcntR2) {
  const Array2D<int32> bit_counts({{0, 1}, {29, 5}});

  XlaBuilder builder(TestName());
  auto operand = ConstantR2<int32>(&builder, {{0, 1}, {-15, 341}});
  PopulationCount(operand);
  ComputeAndCompareR2<int32>(&builder, bit_counts, {});
}
1137 
// PopulationCount on S64: -1 has all 64 bits set, INT64_MAX has 63.
XLA_TEST_F(ArrayElementwiseOpTest, PopcntS64) {
  const int64 largest = std::numeric_limits<int64>::max();
  const Array2D<int64> bit_counts({{0, 64}, {63, 62}});

  XlaBuilder builder(TestName());
  auto operand = ConstantR2<int64>(&builder, {{0, -1}, {largest, largest - 1}});
  PopulationCount(operand);
  ComputeAndCompareR2<int64>(&builder, bit_counts, {});
}
1145 
// ShiftLeft on S32. The last three shift amounts (32, 100, -1) are outside
// [0, 32), and the expected results for them are all zero.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
  const std::vector<int32> values = {static_cast<int32>(0x12345678),
                                     static_cast<int32>(0xF0001000),
                                     1,
                                     3,
                                     77,
                                     1,
                                     -3,
                                     77};
  const std::vector<int32> amounts = {4, 8, 2, 7, 15, 32, 100, -1};
  const std::vector<int32> shifted = {static_cast<int32>(0x23456780),
                                      0x00100000,
                                      0x4,
                                      0x180,
                                      2523136,
                                      0,
                                      0,
                                      0};

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<int32>(&builder, values);
  auto rhs = ConstantR1<int32>(&builder, amounts);
  ShiftLeft(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, shifted, {});
}
1159 
// ShiftRightArithmetic on S32 replicates the sign bit. The expected results
// for out-of-range shift amounts are 0 for non-negative inputs and -1 for a
// negative input.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) {
  const std::vector<int32> values = {static_cast<int32>(0x92345678),
                                     static_cast<int32>(0x10001000),
                                     1,
                                     3,
                                     77,
                                     1,
                                     -3,
                                     77};
  const std::vector<int32> amounts = {4, 8, 2, 7, 2, 32, 100, -1};
  const std::vector<int32> shifted = {static_cast<int32>(0xF9234567),
                                      static_cast<int32>(0x00100010),
                                      0,
                                      0,
                                      19,
                                      0,
                                      -1,
                                      0};

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<int32>(&builder, values);
  auto rhs = ConstantR1<int32>(&builder, amounts);
  ShiftRightArithmetic(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, shifted, {});
}
1174 
// ShiftRightLogical on S32 shifts in zero bits. The expected results for
// out-of-range shift amounts are all zero.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) {
  const std::vector<int32> values = {static_cast<int32>(0x92345678),
                                     static_cast<int32>(0x10001000),
                                     1,
                                     3,
                                     77,
                                     1,
                                     -3,
                                     77};
  const std::vector<int32> amounts = {4, 8, 2, 7, 5, 32, 100, -1};
  const std::vector<int32> shifted = {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0};

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<int32>(&builder, values);
  auto rhs = ConstantR1<int32>(&builder, amounts);
  ShiftRightLogical(lhs, rhs);

  ComputeAndCompareR1<int32>(&builder, shifted, {});
}
1186 
// ShiftLeft on U32. The expected results for shift amounts of 32 or more
// are zero.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) {
  const std::vector<uint32> values = {0x12345678, 0xF0001000, 1, 3,
                                      77,         1,          ~3u, 77};
  const std::vector<uint32> amounts = {4, 8, 2, 7, 15, 32, 100, ~0u};
  const std::vector<uint32> shifted = {0x23456780, 0x00100000, 0x4, 0x180,
                                       2523136,    0,          0,   0};

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<uint32>(&builder, values);
  auto rhs = ConstantR1<uint32>(&builder, amounts);
  ShiftLeft(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, shifted, {});
}
1197 
// ShiftRightArithmetic on U32 still replicates the top bit, so values with
// the high bit set fill with ones (see the first and seventh expected
// entries).
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) {
  const std::vector<uint32> values = {0x92345678, 0x10001000, 1, 3,
                                      77,         1,          ~3u, 77};
  const std::vector<uint32> amounts = {4, 8, 2, 7, 2, 32, 100, ~0u};
  const std::vector<uint32> shifted = {0xF9234567, 0x00100010, 0,   0,
                                       19,         0,          ~0u, 0};

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<uint32>(&builder, values);
  auto rhs = ConstantR1<uint32>(&builder, amounts);
  ShiftRightArithmetic(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, shifted, {});
}
1208 
// ShiftRightLogical on U32 shifts in zero bits. The expected results for
// shift amounts of 32 or more are zero.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) {
  const std::vector<uint32> values = {0x92345678, 0x10001000, 1, 3,
                                      77,         1,          ~3u, 77};
  const std::vector<uint32> amounts = {4, 8, 2, 7, 5, 32, 100, ~0u};
  const std::vector<uint32> shifted = {0x09234567, 0x00100010, 0, 0,
                                       2,          0,          0, 0};

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1<uint32>(&builder, values);
  auto rhs = ConstantR1<uint32>(&builder, amounts);
  ShiftRightLogical(lhs, rhs);

  ComputeAndCompareR1<uint32>(&builder, shifted, {});
}
1219 
// IEEE equality on F32: a NaN on either side compares unequal.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
  // Fast-math stays off because the inputs contain NaNs.
  SetFastMathDisabled(true);
  const std::vector<float> left = {-2.5f, 25.5f, 2.25f, NAN, 6.0f};
  const std::vector<float> right = {10.0f, 5.0f, 2.25f, 10.0f, NAN};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, left);
  auto y = ConstantR1<float>(&builder, right);
  Eq(x, y);

  ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
}
1229 
// Total-order equality on F32: unlike IEEE Eq, NAN compared with NAN is
// expected to be equal (fourth element).
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32sTO) {
  SetFastMathDisabled(true);
  const std::vector<float> left = {-2.5f, 25.5f, 2.25f, NAN, 6.0f};
  const std::vector<float> right = {10.0f, 5.0f, 2.25f, NAN, NAN};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, left);
  auto y = ConstantR1<float>(&builder, right);
  EqTotalOrder(x, y);

  ComputeAndCompareR1<bool>(&builder, {false, false, true, true, false}, {});
}
1239 
// Comparing two zero-element F32 vectors must yield a zero-element PRED result.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
  const std::vector<float> empty;

  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, empty);
  auto y = ConstantR1<float>(&builder, empty);
  Eq(x, y);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1248 
// IEEE >= on F32: any comparison involving NaN is false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
  // Fast-math stays off because the inputs contain NaNs.
  SetFastMathDisabled(true);
  const std::vector<float> left = {-2.5f, 25.5f, 2.25f, NAN, 6.0f};
  const std::vector<float> right = {10.0f, 5.0f, 1.0f, 10.0f, NAN};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, left);
  auto y = ConstantR1<float>(&builder, right);
  Ge(x, y);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
}
1258 
// Total-order >= on F32: the expected values place +NaN above every number
// and -NaN below every number.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32sTO) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  // For portability, need to represent NAN using the following call.
  // The C++ standard does not specify if quiet_NaN() sets the sign bit of
  // its result. The call to std::fabs will ensure that it is not set.
  const float positive_nan = std::fabs(std::numeric_limits<float>::quiet_NaN());
  const float negative_nan = -positive_nan;
  const std::vector<float> left = {-2.5f,        25.5f, 2.25f,
                                   positive_nan, 6.0f,  6.0f};
  const std::vector<float> right = {10.0f, 5.0f,         1.0f,
                                    10.0f, positive_nan, negative_nan};
  auto x = ConstantR1<float>(&builder, left);
  auto y = ConstantR1<float>(&builder, right);
  GeTotalOrder(x, y);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, true, false, true},
                            {});
}
1274 
// IEEE > on F32: any comparison involving NaN is false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
  // Fast-math stays off because the inputs contain NaNs.
  SetFastMathDisabled(true);
  const std::vector<float> left = {-2.5f, 25.5f, 2.25f, NAN, 6.0f};
  const std::vector<float> right = {10.0f, 5.0f, 1.0f, 10.0f, NAN};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, left);
  auto y = ConstantR1<float>(&builder, right);
  Gt(x, y);

  ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
}
1284 
// IEEE <= on F32, including an equal pair; NaN comparisons are false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
  // Fast-math stays off because the inputs contain NaNs.
  SetFastMathDisabled(true);
  const std::vector<float> left = {-2.5f, 5.0f, 2.25f, NAN, 6.0f};
  const std::vector<float> right = {10.0f, 5.0f, 1.0f, 10.0f, NAN};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, left);
  auto y = ConstantR1<float>(&builder, right);
  Le(x, y);

  ComputeAndCompareR1<bool>(&builder, {true, true, false, false, false}, {});
}
1294 
// IEEE < on F32: any comparison involving NaN is false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
  // Fast-math stays off because the inputs contain NaNs.
  SetFastMathDisabled(true);
  const std::vector<float> left = {-2.5f, 25.5f, 2.25f, NAN, 6.0f};
  const std::vector<float> right = {10.0f, 5.0f, 1.0f, 10.0f, NAN};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, left);
  auto y = ConstantR1<float>(&builder, right);
  Lt(x, y);

  ComputeAndCompareR1<bool>(&builder, {true, false, false, false, false}, {});
}
1304 
// S32 equality over the cross product of {INT32_MIN, 0, INT32_MAX}-centred
// pairs.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  const std::vector<int32> left = {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax};
  const std::vector<int32> right = {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(&builder, left);
  auto y = ConstantR1<int32>(&builder, right);
  Eq(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, false, true, false, false, false, true},
      {});
}
1318 
// Comparing two zero-element S32 vectors must yield a zero-element PRED result.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
  const std::vector<int32> empty;

  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(&builder, empty);
  auto y = ConstantR1<int32>(&builder, empty);
  Eq(x, y);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1327 
// Equality on C64: the expected values require both components to match,
// and a NaN in either component makes the pair unequal.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) {
  // Fast-math stays off because the inputs contain NaNs.
  SetFastMathDisabled(true);
  const std::vector<complex64> left = {{-2.5f, 10.0f},
                                       {1.0f, 25.5f},
                                       {2.25f, -3.0f},
                                       {NAN, 0.0f},
                                       {1.0f, 6.0f}};
  const std::vector<complex64> right = {{0.0f, 10.0f},
                                        {1.0f, 5.0f},
                                        {2.25f, -3.0f},
                                        {10.0f, 0.0f},
                                        {1.0f, NAN}};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<complex64>(&builder, left);
  auto y = ConstantR1<complex64>(&builder, right);
  Eq(x, y);

  ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
}
1345 
// Comparing two zero-element C64 vectors must yield a zero-element PRED result.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) {
  const std::vector<complex64> empty;

  XlaBuilder builder(TestName());
  auto x = ConstantR1<complex64>(&builder, empty);
  auto y = ConstantR1<complex64>(&builder, empty);
  Eq(x, y);

  ComputeAndCompareR1<bool>(&builder, /*expected=*/{}, /*arguments=*/{});
}
1354 
// Inequality on C64: the expected values are the exact complement of
// CompareEqC64s, including the NaN cases.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) {
  // Disable fast-math because we're operating on NaNs.
  SetFastMathDisabled(true);

  const std::vector<complex64> left = {{-2.5f, 10.0f},
                                       {1.0f, 25.5f},
                                       {2.25f, -3.0f},
                                       {NAN, 0.0f},
                                       {1.0f, 6.0f}};
  const std::vector<complex64> right = {{0.0f, 10.0f},
                                        {1.0f, 5.0f},
                                        {2.25f, -3.0f},
                                        {10.0f, 0.0f},
                                        {1.0f, NAN}};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<complex64>(&builder, left);
  auto y = ConstantR1<complex64>(&builder, right);
  Ne(x, y);

  ComputeAndCompareR1<bool>(&builder, {true, true, false, true, true}, {});
}
1374 
// IEEE != on F32: any comparison involving NaN is true.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
  // Disable fast-math because we're operating on NaNs.
  SetFastMathDisabled(true);

  const std::vector<float> left = {-2.5f, 25.5f, 2.25f, NAN, 6.0f};
  const std::vector<float> right = {10.0f, 25.5f, 1.0f, 10.0f, NAN};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, left);
  auto y = ConstantR1<float>(&builder, right);
  Ne(x, y);

  ComputeAndCompareR1<bool>(&builder, {true, false, true, true, true}, {});
}
1386 
// S32 inequality over pairs built from INT32_MIN, 0 and INT32_MAX.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  const std::vector<int32> left = {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax};
  const std::vector<int32> right = {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(&builder, left);
  auto y = ConstantR1<int32>(&builder, right);
  Ne(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, true, false, true, true, true, false}, {});
}
1399 
// Signed >= on S32 over pairs built from INT32_MIN, 0 and INT32_MAX.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  const std::vector<int32> left = {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax};
  const std::vector<int32> right = {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(&builder, left);
  auto y = ConstantR1<int32>(&builder, right);
  Ge(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, true, true, false, true, true, true}, {});
}
1412 
// Signed > on S32 over pairs built from INT32_MIN, 0 and INT32_MAX.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  const std::vector<int32> left = {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax};
  const std::vector<int32> right = {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(&builder, left);
  auto y = ConstantR1<int32>(&builder, right);
  Gt(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {false, false, false, true, false, false, true, true, false},
      {});
}
1426 
// Signed <= on S32 over pairs built from INT32_MIN, 0 and INT32_MAX.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  const std::vector<int32> left = {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax};
  const std::vector<int32> right = {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(&builder, left);
  auto y = ConstantR1<int32>(&builder, right);
  Le(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {true, true, true, false, true, true, false, false, true}, {});
}
1439 
// Signed < on S32 over pairs built from INT32_MIN, 0 and INT32_MAX.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  const std::vector<int32> left = {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax};
  const std::vector<int32> right = {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(&builder, left);
  auto y = ConstantR1<int32>(&builder, right);
  Lt(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, false, false, true, false, false, false},
      {});
}
1453 
// U32 equality over pairs built from 0, 5 and UINT32_MAX.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
  constexpr uint32 kMax = std::numeric_limits<uint32>::max();
  const std::vector<uint32> left = {0, 0, 0, 5, 5, 5, kMax, kMax, kMax};
  const std::vector<uint32> right = {0, 1, kMax, 4, 5, 6, 0, 1, kMax};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<uint32>(&builder, left);
  auto y = ConstantR1<uint32>(&builder, right);
  Eq(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, false, true, false, false, false, true},
      {});
}
1465 
// U32 inequality over pairs built from 0, 5 and UINT32_MAX.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
  constexpr uint32 kMax = std::numeric_limits<uint32>::max();
  const std::vector<uint32> left = {0, 0, 0, 5, 5, 5, kMax, kMax, kMax};
  const std::vector<uint32> right = {0, 1, kMax, 4, 5, 6, 0, 1, kMax};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<uint32>(&builder, left);
  auto y = ConstantR1<uint32>(&builder, right);
  Ne(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {false, true, true, true, false, true, true, true, false}, {});
}
1476 
// Unsigned >= on U32: UINT32_MAX must compare above every other value.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
  constexpr uint32 kMax = std::numeric_limits<uint32>::max();
  const std::vector<uint32> left = {0, 0, 0, 5, 5, 5, kMax, kMax, kMax};
  const std::vector<uint32> right = {0, 1, kMax, 4, 5, 6, 0, 1, kMax};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<uint32>(&builder, left);
  auto y = ConstantR1<uint32>(&builder, right);
  Ge(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {true, false, false, true, true, false, true, true, true}, {});
}
1487 
// Unsigned > on U32: UINT32_MAX must compare above every other value.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
  constexpr uint32 kMax = std::numeric_limits<uint32>::max();
  const std::vector<uint32> left = {0, 0, 0, 5, 5, 5, kMax, kMax, kMax};
  const std::vector<uint32> right = {0, 1, kMax, 4, 5, 6, 0, 1, kMax};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<uint32>(&builder, left);
  auto y = ConstantR1<uint32>(&builder, right);
  Gt(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {false, false, false, true, false, false, true, true, false},
      {});
}
1499 
// Unsigned <= on U32: UINT32_MAX must compare above every other value.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
  constexpr uint32 kMax = std::numeric_limits<uint32>::max();
  const std::vector<uint32> left = {0, 0, 0, 5, 5, 5, kMax, kMax, kMax};
  const std::vector<uint32> right = {0, 1, kMax, 4, 5, 6, 0, 1, kMax};

  XlaBuilder builder(TestName());
  auto x = ConstantR1<uint32>(&builder, left);
  auto y = ConstantR1<uint32>(&builder, right);
  Le(x, y);

  ComputeAndCompareR1<bool>(
      &builder, {true, true, true, false, true, true, false, false, true}, {});
}
1510 
// Element-wise `<` on uint32, covering 0 and the type's maximum value on
// both operands.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
  constexpr uint32 kMax = std::numeric_limits<uint32>::max();
  XlaBuilder b(TestName());
  auto x = ConstantR1<uint32>(&b, {0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
  auto y = ConstantR1<uint32>(&b, {0, 1, kMax, 4, 5, 6, 0, 1, kMax});
  Lt(x, y);

  // One expected result per (x[i], y[i]) pair.
  ComputeAndCompareR1<bool>(
      &b, {false, true, true, false, false, true, false, false, false},
      /*arguments=*/{});
}
1522 
// Pow on f32 with fast-math disabled (presumably so NaN handling stays
// IEEE-conformant). Covers 0^0, a negative exponent, NaN as base and as
// exponent, and negative bases raised to integral exponents.
XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto base = ConstantR1<float>(
      &builder, {0.0f, 4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f});
  auto exponent = ConstantR1<float>(
      &builder, {0.0f, 2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f});
  Pow(base, exponent);

  const std::vector<float> expected = {1.0f, 16.0f, 0.25f, 8.0f,
                                       NAN,  NAN,   -8.0f, 16.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1536 
// Pow on f32 with non-integral exponents: a negative base gives NaN, and
// zero raised to a negative power gives +infinity.
XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto base = ConstantR1<float>(&builder, {-2.0f, -0.6f, -0.6f, 0.0f});
  auto exponent = ConstantR1<float>(&builder, {0.5f, 0.6f, -0.6f, -0.6f});
  Pow(base, exponent);

  const std::vector<float> expected = {NAN, NAN, NAN, INFINITY};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1547 
// Complex Pow: negative real bases with fractional exponents (which give
// non-real results), zero bases, and 0^0.
XLA_TEST_F(ArrayElementwiseOpTest, PowC64s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto base =
      ConstantR1<complex64>(&builder, {-2.0f, -0.6f, -0.6f, 0.0f, 0.0f, 0.0f});
  auto exponent =
      ConstantR1<complex64>(&builder, {0.5f, 0.6f, -0.6f, 0.5f, 0.6f, 0.0f});
  Pow(base, exponent);

  const std::vector<complex64> expected = {
      {0, 1.41421356},
      {-2.27443288e-01, 0.69999846},
      {-4.19847531e-01, -1.29215783},
      {0, 0},
      {0, 0},
      {1, 0},
  };
  ComputeAndCompareR1<complex64>(&builder, expected, {}, error_spec_);
}
1568 
// Pow over two empty f32 vectors produces an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto base = ConstantR1<float>(&builder, {});
  auto exponent = ConstantR1<float>(&builder, {});
  Pow(base, exponent);

  const std::vector<float> nothing;
  ComputeAndCompareR1<float>(&builder, nothing, {}, error_spec_);
}
1577 
1578 // Some Pow cases that can be implemented more efficiently.
XLA_TEST_F(ArrayElementwiseOpTest,PowSpecialF32)1579 XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
1580   XlaBuilder b(TestName());
1581 
1582   std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
1583   std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
1584 
1585   Literal param_literal = LiteralUtil::CreateR1<float>(values);
1586   std::unique_ptr<GlobalData> param_data =
1587       client_->TransferToServer(param_literal).ConsumeValueOrDie();
1588 
1589   auto sum = ConstantR0<float>(&b, 0.0f);
1590   auto param = Parameter(&b, 0, param_literal.shape(), "param");
1591   for (float exponent : exponents) {
1592     sum = Add(sum, Pow(param, ConstantR0<float>(&b, exponent)));
1593   }
1594 
1595   std::vector<float> expected;
1596   for (auto value : values) {
1597     float sum = 0.0f;
1598     for (float exponent : exponents) {
1599       sum += std::pow(value, exponent);
1600     }
1601     expected.push_back(sum);
1602   }
1603 
1604   ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_);
1605 }
1606 
// Checks Pow(Exp(x), y) against the host reference std::pow(std::exp(x), y).
// Both inputs are passed as runtime parameters rather than constants.
XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
  XlaBuilder b(TestName());

  const std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  const std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Pow(Exp(param0), param1);

  // Unsigned index to match std::vector::size() (avoids a signed/unsigned
  // comparison).
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = std::pow(std::exp(values0[i]), values1[i]);
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1631 
// Checks Log(Pow(x, y)) against the host reference std::log(std::pow(x, y)),
// including negative bases raised to integral exponents. Both inputs are
// passed as runtime parameters.
XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
  XlaBuilder b(TestName());

  const std::vector<float> values0 = {1.0f, -10.0f, -2.0f, 2.0f,
                                      3.2f, 4.0f,   0.5f,  5.7f};
  const std::vector<float> values1 = {0.0f, 10.0f, -4.0f, 1.0f,
                                      2.0f, 0.5f,  -1.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Log(Pow(param0, param1));

  // Unsigned index to match std::vector::size() (avoids a signed/unsigned
  // comparison).
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = std::log(std::pow(values0[i], values1[i]));
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1658 
// Checks Mul(Exp(x), Exp(y)) against the host reference exp(x) * exp(y).
// Both inputs are passed as runtime parameters.
XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
  XlaBuilder b(TestName());

  const std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  const std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Mul(Exp(param0), Exp(param1));

  // Unsigned index to match std::vector::size() (avoids a signed/unsigned
  // comparison).
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = std::exp(values0[i]) * std::exp(values1[i]);
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1683 
// Checks Div(x, Exp(y)) against the host reference x / exp(y). Both inputs
// are passed as runtime parameters.
XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
  XlaBuilder b(TestName());

  const std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
  const std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();
  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  Div(param0, Exp(param1));

  // Unsigned index to match std::vector::size() (avoids a signed/unsigned
  // comparison).
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = values0[i] / std::exp(values1[i]);
  }

  ComputeAndCompareR1<float>(&b, expected, {data0.get(), data1.get()},
                             error_spec_);
}
1708 
// Checks the left-nested division (x / y) / z against the host reference.
// All three inputs are passed as runtime parameters.
XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
  XlaBuilder b(TestName());

  const std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  const std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  const std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();

  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();

  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
  std::unique_ptr<GlobalData> data2 =
      client_->TransferToServer(literal2).ConsumeValueOrDie();
  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
  Div(Div(param0, param1), param2);

  // Unsigned index to match std::vector::size() (avoids a signed/unsigned
  // comparison).
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = (values0[i] / values1[i]) / values2[i];
  }

  ComputeAndCompareR1<float>(
      &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
}
1740 
// Checks the right-nested division x / (y / z) against the host reference.
// All three inputs are passed as runtime parameters.
XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
  XlaBuilder b(TestName());

  const std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  const std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  const std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();

  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();

  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
  std::unique_ptr<GlobalData> data2 =
      client_->TransferToServer(literal2).ConsumeValueOrDie();

  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
  Div(param0, Div(param1, param2));

  // Unsigned index to match std::vector::size() (avoids a signed/unsigned
  // comparison).
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = values0[i] / (values1[i] / values2[i]);
  }

  ComputeAndCompareR1<float>(
      &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
}
1773 
// Checks Div(x, Pow(y, z)) against the host reference x / std::pow(y, z).
// All three inputs are passed as runtime parameters.
XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
  XlaBuilder b(TestName());

  const std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  const std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
  const std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();

  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();

  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
  std::unique_ptr<GlobalData> data2 =
      client_->TransferToServer(literal2).ConsumeValueOrDie();

  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
  Div(param0, Pow(param1, param2));

  // Unsigned index to match std::vector::size() (avoids a signed/unsigned
  // comparison).
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = values0[i] / std::pow(values1[i], values2[i]);
  }

  ComputeAndCompareR1<float>(
      &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_);
}
1806 
// Checks (w / x) / (y / z) against the host reference. All four inputs are
// passed as runtime parameters.
XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
  XlaBuilder b(TestName());

  const std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
  const std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
  const std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
  const std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};

  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
  std::unique_ptr<GlobalData> data0 =
      client_->TransferToServer(literal0).ConsumeValueOrDie();

  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
  std::unique_ptr<GlobalData> data1 =
      client_->TransferToServer(literal1).ConsumeValueOrDie();

  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
  std::unique_ptr<GlobalData> data2 =
      client_->TransferToServer(literal2).ConsumeValueOrDie();

  Literal literal3 = LiteralUtil::CreateR1<float>(values3);
  std::unique_ptr<GlobalData> data3 =
      client_->TransferToServer(literal3).ConsumeValueOrDie();

  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
  // Parameter 3 gets its own name; it was previously a duplicate "param2".
  auto param3 = Parameter(&b, 3, literal3.shape(), "param3");
  Div(Div(param0, param1), Div(param2, param3));

  // Unsigned index to match std::vector::size() (avoids a signed/unsigned
  // comparison).
  std::vector<float> expected(values0.size());
  for (size_t i = 0; i < values0.size(); ++i) {
    expected[i] = (values0[i] / values1[i]) / (values2[i] / values3[i]);
  }

  ComputeAndCompareR1<float>(
      &b, expected, {data0.get(), data1.get(), data2.get(), data3.get()},
      error_spec_);
}
1846 
// Squares `count` evenly spaced values i / count (i in [0, count)) via
// Pow(x, 2). `count` comes from the test parameter.
TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
  const int count = GetParam();
  XlaBuilder builder(TestName());

  std::vector<float> inputs;
  std::vector<float> squares;
  inputs.reserve(count);
  squares.reserve(count);
  for (int i = 0; i < count; ++i) {
    const float v = i / static_cast<float>(count);
    inputs.push_back(v);
    squares.push_back(v * v);
  }

  auto x = ConstantR1<float>(&builder, inputs);
  Pow(x, ConstantR0<float>(&builder, 2.0f));

  ComputeAndCompareR1<float>(&builder, squares, {}, error_spec_);
}
1866 
// Pow(x, 2) on a 2x2x2x2 array whose elements are i / num_elements, compared
// against the host-computed element-wise squares.
XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) {
  XlaBuilder builder(TestName());
  Array4D<float> input(2, 2, 2, 2);
  const int64_t n = input.num_elements();

  std::vector<float> flat_input;
  std::vector<float> flat_squares;
  for (int64_t i = 0; i < n; ++i) {
    const float v = static_cast<float>(i) / n;
    flat_input.push_back(v);
    flat_squares.push_back(v * v);
  }
  input.SetValues(flat_input);
  Array4D<float> expected(2, 2, 2, 2, flat_squares);

  auto x = ConstantR4FromArray4D<float>(&builder, input);
  Pow(x, ConstantR0<float>(&builder, 2.0f));

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
1886 
// Pow(x, 2) on a rank-4 array with a zero-sized dimension yields an empty
// array of the same shape.
XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) {
  XlaBuilder builder(TestName());
  const Array4D<float> input(2, 2, 0, 2);
  const Array4D<float> expected(2, 2, 0, 2);

  auto base = ConstantR4FromArray4D<float>(&builder, input);
  Pow(base, ConstantR0<float>(&builder, 2.0f));

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
1897 
// Min on f32 with fast-math disabled. A NaN in either operand is expected to
// produce NaN.
XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, {1.0f, 1.0f, 2.25f, NAN, 6.0f});
  auto y = ConstantR1<float>(&builder, {2.0f, -5.0f, 1.0f, 10.0f, NAN});
  Min(x, y);

  const std::vector<float> expected = {1.0f, -5.0f, 1.0f, NAN, NAN};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1908 
// Min of two empty f32 vectors is an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, {});
  auto y = ConstantR1<float>(&builder, {});
  Min(x, y);
  const std::vector<float> nothing;
  ComputeAndCompareR1<float>(&builder, nothing, {}, error_spec_);
}
1916 
// Min on f64 with fast-math disabled, checked with the strict error spec. A NaN
// in either operand is expected to produce NaN.
XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto x = ConstantR1<double>(&builder, {1.0, 1.0, 2.25, NAN, 6.0});
  auto y = ConstantR1<double>(&builder, {2.0, -5.0, 1.0, 10.0, NAN});
  Min(x, y);

  const std::vector<double> expected = {1.0, -5.0, 1.0, NAN, NAN};
  ComputeAndCompareR1<double>(&builder, expected, {}, strict_error_spec_);
}
1927 
// Max on f32 with fast-math disabled. A NaN in either operand is expected to
// produce NaN.
XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, {1.0f, 1.0f, 2.25f, NAN, 6.0f});
  auto y = ConstantR1<float>(&builder, {2.0f, -5.0f, 1.0f, 10.0f, NAN});
  Max(x, y);

  const std::vector<float> expected = {2.0f, 1.0f, 2.25f, NAN, NAN};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
1938 
// Max of two empty f32 vectors is an empty vector.
XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto x = ConstantR1<float>(&builder, {});
  auto y = ConstantR1<float>(&builder, {});
  Max(x, y);
  const std::vector<float> nothing;
  ComputeAndCompareR1<float>(&builder, nothing, {}, error_spec_);
}
1946 
// Max on f64 with fast-math disabled, checked with the strict error spec. A NaN
// in either operand is expected to produce NaN.
XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto x = ConstantR1<double>(&builder, {1.0, 1.0, 2.25, NAN, 6.0});
  auto y = ConstantR1<double>(&builder, {2.0, -5.0, 1.0, 10.0, NAN});
  Max(x, y);

  const std::vector<double> expected = {2.0, 1.0, 2.25, NAN, NAN};
  ComputeAndCompareR1<double>(&builder, expected, {}, strict_error_spec_);
}
1957 
// Max on s32 across the type's minimum and maximum, zero, and small values of
// both signs.
XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  auto a = ConstantR1<int32>(
      &builder, {kMin, kMin, kMin, -1, -1, 0, 0, 0, 1, 1, kMax, kMax, kMax});
  auto c = ConstantR1<int32>(
      &builder, {kMin, kMax, 0, -10, 0, -1, 0, 1, 0, 10, 0, kMax, kMin});
  Max(a, c);

  const std::vector<int32> expected = {kMin, kMax, 0,    -1,   0,
                                       0,    0,    1,    1,    10,
                                       kMax, kMax, kMax};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
1972 
// Min on s32 across the type's minimum and maximum, zero, and small values of
// both signs.
XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) {
  constexpr int32 kMin = std::numeric_limits<int32>::min();
  constexpr int32 kMax = std::numeric_limits<int32>::max();
  XlaBuilder builder(TestName());
  auto a = ConstantR1<int32>(
      &builder, {kMin, kMin, kMin, -1, -1, 0, 0, 0, 1, 1, kMax, kMax, kMax});
  auto c = ConstantR1<int32>(
      &builder, {kMin, kMax, 0, -10, 0, -1, 0, 1, 0, 10, 0, kMax, kMin});
  Min(a, c);

  const std::vector<int32> expected = {kMin, kMin, kMin, -10,  -1,
                                       -1,   0,    0,    0,    1,
                                       0,    kMax, kMin};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
1987 
// Max on u32, including 0 and the type's maximum value.
XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) {
  constexpr uint32 kMax = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());
  auto a = ConstantR1<uint32>(&builder, {0, 0, 1, 1, 1, kMax, kMax, kMax});
  auto c = ConstantR1<uint32>(&builder, {0, 1, 0, 1, 10, 0, 234234, kMax});
  Max(a, c);

  const std::vector<uint32> expected = {0, 1, 1, 1, 10, kMax, kMax, kMax};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
1998 
// Min on u32, including 0 and the type's maximum value.
XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) {
  constexpr uint32 kMax = std::numeric_limits<uint32>::max();
  XlaBuilder builder(TestName());
  auto a = ConstantR1<uint32>(&builder, {0, 0, 1, 1, 1, kMax, kMax, kMax});
  auto c = ConstantR1<uint32>(&builder, {0, 1, 0, 1, 10, 0, 234234, kMax});
  Min(a, c);

  const std::vector<uint32> expected = {0, 0, 0, 1, 1, 0, 234234, kMax};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
2009 
// Max of two ten-element f32 vectors with alternating signs; the first pair is
// (-0.0, -0.0).
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<float>(
      &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0});
  auto c = ConstantR1<float>(
      &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0});
  Max(a, c);

  const std::vector<float> expected = {-0.0, 1.0, 2.0, 3.0, 4.0,
                                       5.0,  6.0, 7.0, 8.0, 9.0};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2022 
// Max of a one-element vector with an empty vector. The expected result is
// empty, i.e. the one-element operand is broadcast to the empty shape.
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) {
  XlaBuilder builder(TestName());
  auto single = ConstantR1<float>(&builder, {3.5});
  auto empty = ConstantR1<float>(&builder, {});
  Max(single, empty);

  const std::vector<float> nothing;
  ComputeAndCompareR1<float>(&builder, nothing, {}, error_spec_);
}
2031 
// Max of a one-element R1 (`u` holds a single value, despite the "S0" in the
// test name) against an empty 0x2 matrix, broadcasting along each matrix
// dimension in turn. Either way the result is an empty 0x2 matrix.
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) {
  for (int dim : {0, 1}) {
    XlaBuilder builder(TestName());
    const Array2D<float> empty(0, 2);
    auto u = ConstantR1<float>(&builder, {3.5});
    auto v = ConstantR2FromArray2D<float>(&builder, empty);
    Max(u, v, /*broadcast_dimensions=*/{dim});

    ComputeAndCompareR2<float>(&builder, empty, {}, error_spec_);
  }
}
2042 
// Max of a length-3 vector broadcast along dimension 1 (the columns) of a
// 2x3 matrix.
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
  XlaBuilder builder(TestName());
  auto row = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  auto matrix = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  Max(row, matrix, /*broadcast_dimensions=*/{1});

  const Array2D<float> expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
2053 
// Max of an empty vector broadcast along dimension 1 of a 2x0 matrix yields
// an empty 2x0 matrix.
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto row = ConstantR1<float>(&builder, {});
  auto matrix = ConstantR2<float>(&builder, {{}, {}});
  Max(row, matrix, /*broadcast_dimensions=*/{1});

  const Array2D<float> expected({{}, {}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
2063 
// Max of a rank-3 s32 array with a scalar 2 (empty broadcast_dimensions), so
// every element below 2 is raised to 2.
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) {
  XlaBuilder builder(TestName());
  auto two = ConstantR0<int32>(&builder, 2);
  const Array3D<int32> input(
      {{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}});
  auto cube = ConstantR3FromArray3D<int32>(&builder, input);
  Max(cube, two, /*broadcast_dimensions=*/{});

  const Array3D<int32> expected(
      {{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}});
  ComputeAndCompareR3<int32>(&builder, expected, {});
}
2074 
// Max of an empty 2x0x3 s32 array with a scalar yields an empty array of the
// same shape.
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) {
  XlaBuilder builder(TestName());
  auto two = ConstantR0<int32>(&builder, 2);
  const Array3D<int32> input(2, 0, 3);
  auto cube = ConstantR3FromArray3D<int32>(&builder, input);
  Max(cube, two, /*broadcast_dimensions=*/{});

  const Array3D<int32> expected(2, 0, 3);
  ComputeAndCompareR3<int32>(&builder, expected, {});
}
2085 
// Min of a 2x3 matrix with a length-2 vector broadcast along dimension 0, so
// each row is clamped from above by the matching vector element.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
  XlaBuilder builder(TestName());
  auto matrix = ConstantR2<float>(
      &builder, {{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});
  auto column = ConstantR1<float>(&builder, {-10.2f, 16.4f});
  Min(matrix, column, /*broadcast_dimensions=*/{0});

  const Array2D<float> expected(
      {{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
2096 
// Min of an empty 2x0 matrix with a length-2 vector broadcast along dimension
// 0 yields an empty 2x0 matrix.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto matrix = ConstantR2<float>(&builder, {{}, {}});
  auto column = ConstantR1<float>(&builder, {-10.2f, 16.4f});
  Min(matrix, column, /*broadcast_dimensions=*/{0});

  const Array2D<float> expected({{}, {}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
2106 
// Min of a 2x3 matrix with a 2x2x1x3 array. The matrix dimensions map onto
// dimensions 1 and 3 of the rank-4 operand.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) {
  XlaBuilder builder(TestName());
  auto small =
      ConstantR2<float>(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  auto big = ConstantR4FromArray4D<float>(
      &builder, {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}},
                 {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}});
  Min(small, big, /*broadcast_dimensions=*/{1, 3});

  const Array4D<float> expected(
      {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}},
       {{{-12.2f, 64.29f, 6.1f}}, {{-0.01f, 32.2f, 2.5f}}}});
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
2121 
// Min of a 2x3 matrix with an empty 2x2x0x3 array (matrix mapped onto
// dimensions 1 and 3) yields an empty array of the rank-4 operand's shape.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto small =
      ConstantR2<float>(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  const Array4D<float> empty_input(2, 2, 0, 3);
  auto big = ConstantR4FromArray4D<float>(&builder, empty_input);
  Min(small, big, /*broadcast_dimensions=*/{1, 3});

  const Array4D<float> expected(2, 2, 0, 3);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
2133 
// Min of 0..9 with 9..0 on s32.
XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) {
  XlaBuilder builder(TestName());
  auto ascending = ConstantR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
  auto descending = ConstantR1<int32>(&builder, {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
  Min(ascending, descending);

  const std::vector<int32> expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2143 
// Max of 0..9 with 9..0 on s32.
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) {
  XlaBuilder builder(TestName());
  auto ascending = ConstantR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
  auto descending = ConstantR1<int32>(&builder, {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
  Max(ascending, descending);

  const std::vector<int32> expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2153 
// Rem on s32 with mixed signs. The expected results carry the sign of the
// dividend (e.g. -3 rem 10 == -3, 1 rem -10 == 1).
XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) {
  XlaBuilder builder(TestName());
  auto dividend = ConstantR1<int32>(&builder, {-3, 26, 2, -1, 1});
  auto divisor = ConstantR1<int32>(&builder, {10, 5, 1, 10, -10});
  Rem(dividend, divisor);

  const std::vector<int32> expected = {-3, 1, 0, -1, 1};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2162 
// Clamp(min, x, max) on f32 vectors with no NaNs: each element of `argument`
// is limited to [minimum[i], maximum[i]].
XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
  XlaBuilder builder(TestName());
  auto minimum = ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
  auto argument =
      ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
  // All literals use the `f` suffix; the last one was previously a stray
  // double literal (`123.0`) in an f32 list.
  auto maximum =
      ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0f});
  Clamp(minimum, argument, maximum);

  ComputeAndCompareR1<float>(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {},
                             error_spec_);
}
2174 
// Clamp on f32 with fast-math disabled. A NaN bound (min or max) is expected
// to produce NaN.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto lo = ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, NAN});
  auto x = ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
  auto hi = ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, NAN, 123.0f});
  Clamp(lo, x, hi);

  const std::vector<float> expected = {2.0f, 0.5f, 1.0f, NAN, NAN};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2187 
// Clamp of an f32 vector between scalar bounds [0, 5].
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
  XlaBuilder builder(TestName());
  auto lo = ConstantR0<float>(&builder, 0.0f);
  auto x = ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
  auto hi = ConstantR0<float>(&builder, 5.0f);
  Clamp(lo, x, hi);

  const std::vector<float> expected = {2.0f, 5.0f, 0.0f, 1.0f, 4.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2198 
// Clamp on f32 with every vector/scalar combination of bounds. The four clamp
// results are summed so a single output checks all combinations.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
  XlaBuilder builder(TestName());
  auto min_scalar = ConstantR0<float>(&builder, 0.0f);
  auto min_vector =
      ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
  auto arg_vector =
      ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
  auto max_scalar = ConstantR0<float>(&builder, 3.0f);
  // All literals use the `f` suffix; the last one was previously a stray
  // double literal (`123.0`) in an f32 list.
  auto max_vector =
      ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0f});
  // Perform clamp with broadcasted scalar and vector.
  Add(Add(Clamp(min_vector, arg_vector, max_scalar),
          Clamp(min_scalar, arg_vector, max_vector)),
      Add(Clamp(min_vector, arg_vector, max_vector),
          Clamp(min_scalar, arg_vector, max_scalar)));

  ComputeAndCompareR1<float>(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {},
                             error_spec_);
}
2218 
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) {
  // Signed-integer clamp with per-element bounds. The cases include clamping
  // up to the lower bound, down to the upper bound, and negative bounds.
  XlaBuilder builder(TestName());
  auto lo = ConstantR1<int32>(&builder, {1, -6, 1, 2, 0, -5});
  auto x = ConstantR1<int32>(&builder, {2, 10, -5, 1, 4, 10});
  auto hi = ConstantR1<int32>(&builder, {3, 0, 25, 5, 123, -1});
  Clamp(lo, x, hi);

  const std::vector<int32> expected = {2, 0, 1, 2, 4, -1};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2228 
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) {
  // Signed-integer variant of ClampF32ScalarVector. The four scalar/vector
  // bound combinations are summed into a single vector.
  XlaBuilder builder(TestName());
  auto lo_scalar = ConstantR0<int32>(&builder, 0);
  auto lo_vector = ConstantR1<int32>(&builder, {1, -6, 1, 2, 0});
  auto operand = ConstantR1<int32>(&builder, {2, 10, -5, 1, 4});
  auto hi_scalar = ConstantR0<int32>(&builder, 3);
  auto hi_vector = ConstantR1<int32>(&builder, {3, 1, 25, 5, 123});
  // Perform clamp with broadcasted scalar and vector.
  Add(Add(Clamp(lo_vector, operand, hi_scalar),
          Clamp(lo_scalar, operand, hi_vector)),
      Add(Clamp(lo_vector, operand, hi_vector),
          Clamp(lo_scalar, operand, hi_scalar)));

  const std::vector<int32> expected = {8, 8, 2, 6, 14};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
2244 
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) {
  // Unsigned clamp with per-element bounds. The last lane uses bounds at the
  // very top of the uint32 range (~0u - 4 .. ~0u).
  XlaBuilder builder(TestName());
  auto lo = ConstantR1<uint32>(&builder, {1, 2, 1, 2, 0, ~0u - 4});
  auto x = ConstantR1<uint32>(&builder, {2, 10, 5, 1, 4, 10});
  auto hi = ConstantR1<uint32>(&builder, {3, 5, 25, 5, 123, ~0u});
  Clamp(lo, x, hi);

  const std::vector<uint32> expected = {2, 5, 5, 2, 4, ~0u - 4};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
2254 
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
  // Unsigned variant of the scalar/vector clamp test. The four bound
  // combinations are summed into a single vector.
  XlaBuilder builder(TestName());
  auto lo_scalar = ConstantR0<uint32>(&builder, 0);
  auto lo_vector = ConstantR1<uint32>(&builder, {1, 0, 1, 2, 0});
  auto operand = ConstantR1<uint32>(&builder, {2, 10, 0, 1, 4});
  auto hi_scalar = ConstantR0<uint32>(&builder, 3);
  auto hi_vector = ConstantR1<uint32>(&builder, {3, 1, 25, 5, 123});
  // Perform clamp with broadcasted scalar and vector.
  Add(Add(Clamp(lo_vector, operand, hi_scalar),
          Clamp(lo_scalar, operand, hi_vector)),
      Add(Clamp(lo_vector, operand, hi_vector),
          Clamp(lo_scalar, operand, hi_scalar)));

  const std::vector<uint32> expected = {8, 8, 2, 6, 14};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
2270 
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
  // Adds two runtime parameters, transferred to the server up front, rather
  // than two constants embedded in the computation.
  XlaBuilder builder(TestName());

  Literal lhs_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  Literal rhs_literal = LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  auto lhs = Parameter(&builder, 0, lhs_literal.shape(), "param0");
  auto rhs = Parameter(&builder, 1, rhs_literal.shape(), "param1");
  Add(lhs, rhs);

  const std::vector<float> expected = {8.3f, 4.5f, 6.7f, 11.1f};
  ComputeAndCompareR1<float>(&builder, expected,
                             {lhs_data.get(), rhs_data.get()}, error_spec_);
}
2292 
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
  // Adding two empty (0 x 7 x 0) parameters must produce an equally empty
  // result of the same shape.
  XlaBuilder builder(TestName());
  const Array3D<float> empty(0, 7, 0);

  Literal lhs_literal = LiteralUtil::CreateR3FromArray3D<float>(empty);
  std::unique_ptr<GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  Literal rhs_literal = LiteralUtil::CreateR3FromArray3D<float>(empty);
  std::unique_ptr<GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  auto lhs = Parameter(&builder, 0, lhs_literal.shape(), "param0");
  auto rhs = Parameter(&builder, 1, rhs_literal.shape(), "param1");
  Add(lhs, rhs);

  ComputeAndCompareR3<float>(&builder, empty,
                             {lhs_data.get(), rhs_data.get()}, error_spec_);
}
2314 
XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
  // Mixes a compile-time constant operand with a runtime parameter operand.
  XlaBuilder builder(TestName());

  Literal param_literal =
      LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
  std::unique_ptr<GlobalData> param_data =
      client_->TransferToServer(param_literal).ConsumeValueOrDie();

  auto constant = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
  auto param = Parameter(&builder, 0, param_literal.shape(), "param0");
  Add(constant, param);

  const std::vector<float> expected = {2.2f, 4.4f, 6.6f, 9.9f};
  ComputeAndCompareR1<float>(&builder, expected, {param_data.get()},
                             error_spec_);
}
2330 
XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) {
  // Inputs approximate pi, 0, pi/2 and -pi/4.
  XlaBuilder builder(TestName());
  auto angles =
      ConstantR1<float>(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f});
  Cos(angles);

  const std::vector<float> expected = {-1.0f, 1.0f, 0.0f, 0.707107f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2339 
XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) {
  // Inputs approximate pi, 0, pi/2 and -pi/4.
  XlaBuilder builder(TestName());
  auto angles =
      ConstantR1<float>(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f});
  Sin(angles);

  const std::vector<float> expected = {0.0f, 0.0f, 1.0f, -0.707107f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2348 
XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) {
  // Evaluates Atan2 over the full cross product of the y and x samples below,
  // including signed zeros and infinities. No explicit expected values are
  // given; ComputeAndCompare presumably derives them from a reference backend.
  XlaBuilder builder(TestName());
  const float inf = std::numeric_limits<float>::infinity();
  const std::vector<float> y_samples = {+0.0f, -0.0f, inf,   -inf, 5.0f,
                                        -3.0f, 2.0f,  -8.0f, 1.0f};
  const std::vector<float> x_samples = {+0.0f, -0.0f, inf,  -inf,
                                        6.0f,  -4.0f, 2.0f, 8.0f};
  std::vector<float> ys;
  std::vector<float> xs;
  for (float y_sample : y_samples) {
    for (float x_sample : x_samples) {
      ys.push_back(y_sample);
      xs.push_back(x_sample);
    }
  }
  auto y = ConstantR1<float>(&builder, ys);
  auto x = ConstantR1<float>(&builder, xs);
  Atan2(y, x);

  ComputeAndCompare(&builder, {}, error_spec_);
}
2366 
XLA_TEST_F(ArrayElementwiseOpTest, Atan2C64s) {
  // Complex64 variant of Atan2F32s. Every sample is real-valued (the
  // imaginary part is zero) and the full y x x cross product is evaluated.
  XlaBuilder builder(TestName());
  const float inf = std::numeric_limits<float>::infinity();
  const std::vector<float> y_samples = {+0.0f, -0.0f, inf,   -inf, 5.0f,
                                        -3.0f, 2.0f,  -8.0f, 1.0f};
  const std::vector<float> x_samples = {+0.0f, -0.0f, inf,  -inf,
                                        6.0f,  -4.0f, 2.0f, 8.0f};
  std::vector<std::complex<float>> ys;
  std::vector<std::complex<float>> xs;
  for (float y_sample : y_samples) {
    for (float x_sample : x_samples) {
      ys.emplace_back(y_sample);
      xs.emplace_back(x_sample);
    }
  }
  auto y = ConstantR1<std::complex<float>>(&builder, ys);
  auto x = ConstantR1<std::complex<float>>(&builder, xs);
  Atan2(y, x);

  ComputeAndCompare(&builder, {}, error_spec_);
}
2384 
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) {
  // Small (non-vectorized) tanh check against precomputed values.
  XlaBuilder builder(TestName());
  auto a = ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f});
  Tanh(a);

  // All expected literals are float (the last one was previously a double).
  ComputeAndCompareR1<float>(&builder, {-0.986614f, 0.996260f, 0.978026f}, {},
                             error_spec_);
}
2393 
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
  // Same operation as TanhF32s, but with an input large enough to exercise
  // the vectorized tanh implementation on XLA CPU. The input is fed in as a
  // parameter rather than a constant.
  XlaBuilder builder(TestName());
  Literal tanh_input = LiteralUtil::CreateR1<float>(
      {1.02,  -0.32, 0.85,  0.90,  1.23,  -0.91, -0.49, 0.80,  -0.67, 0.16,
       -0.07, 0.39,  -0.41, 0.04,  1.36,  1.25,  0.41,  0.65,  -1.08, 0.32,
       -1.45, -0.77, -1.09, 0.91,  -1.03, -0.30, -1.11, -1.17, 1.50,  -0.85,
       0.04,  1.02,  0.34,  -0.61, 0.41,  0.07,  -0.02, 1.42,  -0.62, 0.81,
       0.08,  0.81,  -0.30, 1.17,  -0.65, -0.44, 0.92,  1.26,  -1.29, 1.35,
       0.08,  -1.24, -0.92, 0.49,  1.17,  -0.45, -1.31, -1.44, -0.13, -1.31,
       -0.79, 1.41,  1.21,  1.05});
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> device_input,
                          client_->TransferToServer(tanh_input));

  auto param = Parameter(&builder, 0, tanh_input.shape(), "input");
  Tanh(param);

  const std::vector<float> expected = {
      0.77009583,  -0.30665702, 0.69070244,  0.71401149,  0.84400684,
      -0.71985596, -0.45764771, 0.66664988,  -0.58278900, 0.16050975,
      -0.06770509, 0.36843640,  -0.38476998, 0.04018109,  0.87562293,
      0.84788644,  0.38603750,  0.57294142,  -0.79140943, 0.31032649,
      -0.89590985, -0.64770776, -0.79625875, 0.72234446,  -0.77389336,
      -0.28871772, -0.80428445, -0.82541436, 0.90456349,  -0.68856895,
      0.03877772,  0.76877952,  0.32561871,  -0.54546672, 0.39072621,
      0.07273290,  -0.01924866, 0.88924897,  -0.55283129, 0.67183107,
      0.08006320,  0.66944766,  -0.29068485, 0.82573754,  -0.57170743,
      -0.41581789, 0.72739530,  0.85025692,  -0.85931867, 0.87357593,
      0.07782833,  -0.84597743, -0.72748238, 0.45396307,  0.82449573,
      -0.42462519, -0.86363792, -0.89368379, -0.12621804, -0.86445558,
      -0.65565848, 0.88789743,  0.83566397,  0.78287679};
  // The tolerance is unusually loose because tanh is approximated with a
  // rational interpolant.
  const ErrorSpec tanh_tolerance(0.004, 0.004);
  ComputeAndCompareR1<float>(&builder, expected, {device_input.get()},
                             tanh_tolerance);
}
2433 
XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
  // The input is large enough to exercise the vectorized exp implementation
  // on XLA CPU. Expected values come from std::exp on the same floats.
  XlaBuilder builder(TestName());

  // For scale: exp(89) saturates float32, and exp(-10) is already below the
  // error spec.
  Literal exp_input = LiteralUtil::CreateR1<float>(
      {1.02,   -0.32,  0.85,   0.9,    1.23,   -0.91,  -0.49, 0.8,    -1.31,
       -1.44,  -0.13,  -1.31,  -0.79,  1.41,   1.21,   1.05,  -195.6, -194.5,
       -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5,  -17.4,
       -16.3,  -15.2,  -14.1,  -13.0,  -11.9,  -10.8,  -9.7,  -8.6,   -7.5,
       -6.4,   -5.3,   -4.2,   -3.1,   -2.0,   -0.9,   0.2,   1.3,    2.4,
       3.5,    4.6,    5.7,    6.8,    7.9,    9.0,    10.1,  11.2,   12.3,
       13.4,   14.5,   15.6,   16.7,   17.8,   18.9,   20.0,  21.1,   22.2,
       23.3,   24.4,   25.5,   26.6,   27.7,   28.8,   29.9,  31.0,   32.1,
       68.4,   69.5,   70.6,   71.7,   72.8,   73.9,   75.0,  76.1,   77.2,
       78.3,   79.4,   80.5,   81.6,   82.7,   83.8,   84.9,  85.2,   86.3,
       86.4,   86.5,   87.6,   87.7,   87.8,   87.9});
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> device_input,
                          client_->TransferToServer(exp_input));

  auto param = Parameter(&builder, 0, exp_input.shape(), "input");
  Exp(param);

  const int64_t element_count = exp_input.shape().dimensions(0);
  std::vector<float> expected(element_count);
  for (int64_t idx = 0; idx < element_count; ++idx) {
    expected[idx] = std::exp(exp_input.Get<float>({idx}));
  }

  ComputeAndCompareR1<float>(&builder, expected, {device_input.get()},
                             error_spec_);
}
2469 
XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
  // The input tensor is large enough to exercise the vectorized log
  // implementation on XLA CPU. Expected values come from std::log on the
  // same float inputs.
  //
  // The inputs span many orders of magnitude (up to ~1.7e35). The first eight
  // are negative, so std::log yields NaN for them. This presumably relies on
  // ComputeAndCompareR1 treating matching NaNs as equal.
  XlaBuilder builder(TestName());

  Literal input_literal = LiteralUtil::CreateR1<float>(
      {-1.29,    -1.41,    -1.25,    -13.5,    -11.7,    -17.9,    -198,
       -167,     1.29,     1.41,     1.25,     13.5,     11.7,     17.9,
       198,      167,      1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04,  1.84e+04,
       1.74e+04, 1.89e+05, 1.9e+05,  1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07,
       1.66e+07, 1e+07,    1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09,
       1.44e+10, 1.5e+10,  1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12,
       1.4e+12,  1.03e+13, 1.6e+13,  1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15,
       1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17,
       2e+18,    1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20,
       1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21,  1.35e+22, 1.84e+22, 1.02e+22,
       1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25,
       1.62e+25, 1.2e+26,  1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28,
       1.5e+28,  1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30,  1.81e+30, 1.34e+30,
       1.7e+31,  1.44e+31, 1.1e+31,  1.4e+32,  1.67e+32, 1.96e+33, 1.11e+33,
       1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
                          client_->TransferToServer(input_literal));

  auto input = Parameter(&builder, 0, input_literal.shape(), "input");
  Log(input);

  std::vector<float> expected_result;
  int64_t input_size = input_literal.shape().dimensions(0);
  expected_result.reserve(input_size);
  for (int64_t i = 0; i < input_size; i++) {
    expected_result.push_back(std::log(input_literal.Get<float>({i})));
  }

  ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
                             error_spec_);
}
2507 
XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) {
  // Count-leading-zeros on uint32. Zero gives 32; otherwise the result is
  // 31 minus the index of the highest set bit (e.g. 0x10 -> 27).
  XlaBuilder builder(TestName());
  auto operand = ConstantR1<uint32>(
      &builder, {0, 1, 0x10, 0x10000, 0x700000, 0x12345678, 0xF2345678});
  Clz(operand);

  const std::vector<uint32> expected = {32, 31, 27, 15, 9, 3, 0};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
2516 
XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) {
  // Count-leading-zeros on int64. -1 has its sign bit set, so it has zero
  // leading zeros.
  XlaBuilder builder(TestName());
  auto operand =
      ConstantR1<int64>(&builder, {0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1});
  Clz(operand);

  const std::vector<int64> expected = {64, 63, 32, 1, 0};
  ComputeAndCompareR1<int64>(&builder, expected, {});
}
2525 
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
  // Computes (a + b) + c:
  //
  //   a ---- (add) ---- (add)
  //         /          /
  //   b ---/          /
  //   c -------------/
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});

  auto a_plus_b = Add(a, b);
  Add(a_plus_b, c);

  const std::vector<float> expected = {-0.1f, -10.1f, -0.1f, -20.1f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2543 
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
  // Computes a + (b + c):
  //
  //   b ---- (add) ---- (add)
  //         /          /
  //   c ---/          /
  //   a -------------/
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});

  auto b_plus_c = Add(b, c);
  Add(a, b_plus_c);

  const std::vector<float> expected = {89.9f, -10.1f, -0.1f, -20.1f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2561 
XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
  // Computes (-a) + (-b):
  //
  //   a ---- (neg) ---- (add)
  //                    /
  //   b ---- (neg) ---/
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});

  auto minus_a = Neg(a);
  auto minus_b = Neg(b);
  Add(minus_a, minus_b);

  const std::vector<float> expected = {-93.2f, -5.4f, -7.6f, -9.8f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}
2578 
XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
  // Computes (a + b) + (c + d):
  //
  //   a ---- (add) ----+
  //         /          |
  //   b ---/           +---- (add)
  //                    |
  //   c ---- (add) ----+
  //         /
  //   d ---/
  //
  // (The diagram avoids a trailing backslash on any comment line. A trailing
  // backslash splices the next line into the comment and triggers -Wcomment.)
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});
  auto d = ConstantR1<float>(&builder, {-19.0f, 10.0f, -40.0f, 20.2f});

  auto add_ab = Add(a, b);
  auto add_cd = Add(c, d);
  Add(add_ab, add_cd);

  ComputeAndCompareR1<float>(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {},
                             error_spec_);
}
2601 
2602 XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
2603   XlaBuilder builder(TestName());
2604   auto a = ConstantR2<float>(&builder,
2605                              {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2606   auto b = ConstantR2<float>(&builder,
2607                              {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
2608   Add(a, b);
2609 
2610   Array2D<float> expected_array(
2611       {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
2612   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2613 }
2614 
XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
  // scalar + matrix: the rank-0 operand comes first and is broadcast over
  // every element of the 2 x 3 matrix.
  XlaBuilder builder(TestName());
  auto matrix = ConstantR2<float>(
      &builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto three = ConstantR0<float>(&builder, 3.0f);
  Add(three, matrix);

  const Array2D<float> expected(
      {{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
2626 
2627 XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
2628   // Add a matrix + scalar.
2629   XlaBuilder builder(TestName());
2630   auto a = ConstantR2<float>(&builder,
2631                              {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
2632   auto scalar = ConstantR0<float>(&builder, 3.0f);
2633   Add(a, scalar);
2634 
2635   Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
2636   ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
2637 }
2638 
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) {
  // Test simple broadcasting of a R1F32 over R2F32. The vector's size (3)
  // matches only dim 1 of the (2, 3) matrix. With broadcast_dimensions={1},
  // v[j] is added to m[i][j], i.e. the vector is added to every row.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<float>(&builder, {20.0f, 40.0f, 60.0f});
  // clang-format off
  auto m = ConstantR2<float>(&builder, {
    {-2.5f, 3.14f, 1.0f},
    {2.25f, -10.0f, 3.33f}});
  // clang-format on
  Add(v, m, /*broadcast_dimensions=*/{1});
  Array2D<float> expected_array(
      {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}
2654 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
  // Test broadcasting in Eq comparison. Both ways of broadcasting a length-2
  // vector against a 2 x 2 matrix are checked, and both results are returned
  // as one tuple.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32>(&builder, {42, 73});
  auto mat = ConstantR2<int32>(&builder, {{42, 73}, {42, 52}});

  // broadcast_dimensions={1}: vec[j] is compared with mat[i][j].
  auto eq_bcast_dim1 = Eq(vec, mat, /*broadcast_dimensions=*/{1});
  // broadcast_dimensions={0}: vec[i] is compared with mat[i][j].
  auto eq_bcast_dim0 = Eq(vec, mat, /*broadcast_dimensions=*/{0});
  Tuple(&builder, {eq_bcast_dim1, eq_bcast_dim0});

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}),
       LiteralUtil::CreateR2<bool>({{true, false}, {false, false}})});
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}
2672 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
  // Test broadcasting in Ne comparison. vec[j] is compared with every row
  // mat[i][j], and the result is checked through its string form.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32>(&builder, {42, 73});
  auto mat = ConstantR2<int32>(&builder, {{42, 73}, {42, 52}});
  Ne(vec, mat, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,2] {
  { 0, 0 },
  { 0, 1 }
})";
  const auto actual = ExecuteToString(&builder, {});
  EXPECT_EQ(expected, actual);
}
2686 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
  // Test broadcasting in Ge comparison: vec[j] >= mat[i][j] for each row i.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto mat = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Ge(vec, mat, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,4] {
  { 1, 1, 0, 0 },
  { 0, 0, 0, 1 }
})";
  const auto actual = ExecuteToString(&builder, {});
  EXPECT_EQ(expected, actual);
}
2700 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
  // Test broadcasting in Gt comparison: vec[j] > mat[i][j] for each row i.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto mat = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Gt(vec, mat, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,4] {
  { 0, 1, 0, 0 },
  { 0, 0, 0, 0 }
})";
  const auto actual = ExecuteToString(&builder, {});
  EXPECT_EQ(expected, actual);
}
2714 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
  // Test broadcasting in Le comparison: vec[j] <= mat[i][j] for each row i.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto mat = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Le(vec, mat, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,4] {
  { 1, 0, 1, 1 },
  { 1, 1, 1, 1 }
})";
  const auto actual = ExecuteToString(&builder, {});
  EXPECT_EQ(expected, actual);
}
2728 
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
  // Test broadcasting in Lt comparison: vec[j] < mat[i][j] for each row i.
  XlaBuilder builder(TestName());
  auto vec = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto mat = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Lt(vec, mat, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,4] {
  { 0, 0, 1, 1 },
  { 1, 1, 1, 0 }
})";
  const auto actual = ExecuteToString(&builder, {});
  EXPECT_EQ(expected, actual);
}
2742 
XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
  // Broadcasts an R1F32 over an R2F32 with the matrix as the FIRST operand.
  // Each row is scaled element-wise by the vector.
  XlaBuilder builder(TestName());
  auto matrix =
      ConstantR2<float>(&builder, {{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}});
  auto scale = ConstantR1<float>(&builder, {2.0f, 4.0f, 6.0f});
  Mul(matrix, scale, /*broadcast_dimensions=*/{1});

  const Array2D<float> expected(
      {{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
2754 
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
  // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
  // Shapes, as given by the row-major initializer lists below:
  //   m:      f32[2,3]
  //   md:     f32[1,3]  (its single row is broadcast over both rows of m)
  //   result: f32[2,3]
  // NOTE(review): md's degenerate dimension is dim 0 in row-major order. The
  // "Dim1" in the test name appears to count dimensions in reverse order.
  XlaBuilder builder(TestName());
  auto m = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto md = ConstantR2<float>(&builder, {{10.0f, 20.0f, 30.0f}});
  Add(m, md);
  Array2D<float> expected_array(
      {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}
2769 
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) {
  // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
  // Shapes, as given by the row-major initializer lists below:
  //   m:      f32[2,3]
  //   md:     f32[2,1]  (md[i][0] is added to every element of row i of m)
  //   result: f32[2,3]
  // NOTE(review): md's degenerate dimension is dim 1 in row-major order. The
  // "Dim0" in the test name appears to count dimensions in reverse order.
  XlaBuilder builder(TestName());
  auto m = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto md = ConstantR2<float>(&builder, {{10.0f}, {20.0f}});
  Add(m, md);
  Array2D<float> expected_array(
      {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}
2784 
XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) {
  // Tests broadcasting for two degenerate arrays. This kind of broadcasting
  // effectively creates an "outer product" operation (outer sum here).
  // This is taken from the Numpy docs example at:
  // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html
  XlaBuilder builder(TestName());
  // Shapes, as given by the row-major initializer lists below:
  //   a:      f32[4,1]  (a column)
  //   b:      f32[1,3]  (a row)
  //   result: f32[4,3], where result[i][j] = a[i][0] + b[0][j]
  auto a = ConstantR2<float>(&builder, {{0.0f}, {10.0f}, {20.0f}, {30.0f}});
  auto b = ConstantR2<float>(&builder, {{1.0f, 2.0f, 3.0f}});
  Add(a, b);
  Array2D<float> expected_array({{1.0f, 2.0f, 3.0f},
                                 {11.0f, 12.0f, 13.0f},
                                 {21.0f, 22.0f, 23.0f},
                                 {31.0f, 32.0f, 33.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}
2803 
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) {
  // Add together a (2,2) array and a (2) array, using dimension 1 for
  // broadcasting (though there are two ways to broadcast these shapes).
  // With broadcast_dimensions={1}, v[j] is added to m[i][j], i.e. the vector
  // is added to every row. Compare Add1DTo2DF32TwoWaysOver0.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<float>(&builder, {20.0f, 40.0f});
  auto m = ConstantR2<float>(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
  Add(v, m, /*broadcast_dimensions=*/{1});
  Array2D<float> expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}
2814 
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
  // Add together a (2,2) array and a (2) array, using dimension 0 for
  // broadcasting (though there are two ways to broadcast these shapes).
  // With broadcast_dimensions={0}, v[i] is added to every element of row i
  // of m. Compare Add1DTo2DF32TwoWaysOver1.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<float>(&builder, {20.0f, 40.0f});
  auto m = ConstantR2<float>(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
  Add(v, m, /*broadcast_dimensions=*/{0});
  Array2D<float> expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}
2825 
2826 XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
2827   // Binary add of two R3s together
2828   XlaBuilder builder(TestName());
2829   Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
2830                        {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
2831   auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
2832 
2833   Array3D<float> b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}},
2834                        {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}});
2835   auto b = ConstantR3FromArray3D<float>(&builder, b_3d);
2836   Add(a, b);
2837 
2838   Array3D<float> expected_3d(
2839       {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}},
2840        {{21.0f, 24.0f}, {27.0f, 30.0f}, {33.0f, 36.0f}}});
2841   ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
2842 }
2843 
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) {
  // Add together a (2, 3, 2) array with a (2) array, using dimension 2 for
  // broadcasting (though there are two ways to broadcast these shapes).
  // With broadcast_dimensions={2}, v[k] is added to a[i][j][k]. Compare
  // Add1DTo3DTwoWaysOver0.
  XlaBuilder builder(TestName());
  // clang-format off
  Array3D<float> a_3d({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  // clang-format on
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
  auto v = ConstantR1<float>(&builder, {10.0f, 20.0f});
  Add(a, v, /*broadcast_dimensions=*/{2});

  Array3D<float> expected_3d(
      {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}},
       {{17.0f, 28.0f}, {19.0f, 30.0f}, {21.0f, 32.0f}}});
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}
2867 
XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) {
  // Add together a (2, 3, 2) array with a (2) array, using dimension 0 for
  // broadcasting (though there are two ways to broadcast these shapes).
  // With broadcast_dimensions={0}, v[i] is added to every element of slab
  // a[i]. Compare Add1DTo3DTwoWaysOver2.
  XlaBuilder builder(TestName());
  // clang-format off
  Array3D<float> a_3d({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  // clang-format on
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
  auto v = ConstantR1<float>(&builder, {10.0f, 20.0f});
  Add(a, v, /*broadcast_dimensions=*/{0});

  // clang-format off
  Array3D<float> expected_3d({
    {{11.0f, 12.0f},
     {13.0f, 14.0f},
     {15.0f, 16.0f}},
    {{27.0f, 28.0f},
     {29.0f, 30.0f},
     {31.0f, 32.0f}},
  });
  // clang-format on
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}
2898 
XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) {
  // Add together a (2, 3, 2) array with a (2, 3) array, using dimensions
  // {0, 1} for broadcasting: m(i, j) is added to both a(i, j, 0) and
  // a(i, j, 1).
  XlaBuilder builder(TestName());
  // clang-format off
  Array3D<float> a_3d({
    {{1.0f, 2.0f},
     {3.0f, 4.0f},
     {5.0f, 6.0f}},
    {{7.0f, 8.0f},
     {9.0f, 10.0f},
     {11.0f, 12.0f}},
  });
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
  auto m = ConstantR2<float>(&builder, {
    {10.0f, 20.0f, 30.0f},
    {40.0f, 50.0f, 60.0f},
  });
  Add(a, m, /*broadcast_dimensions=*/{0, 1});

  Array3D<float> expected_3d({
    {{11.0f, 12.0f},
     {23.0f, 24.0f},
     {35.0f, 36.0f}},
    {{47.0f, 48.0f},
     {59.0f, 60.0f},
     {71.0f, 72.0f}},
  });
  // clang-format on
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}
2930 
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
  // Comparison between two 3D arrays of compatible shapes:
  // (2, 3, 2) and (1, 3, 2). The size-1 leading dimension of `b` is
  // implicitly broadcast, so the result is a (2, 3, 2) shape of PREDs.
  // NOTE(review): despite the test name, the degenerate dimension of `b` is
  // dimension 0, not dimension 2.
  XlaBuilder builder(TestName());
  Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
                       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);

  Array3D<float> b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}});
  auto b = ConstantR3FromArray3D<float>(&builder, b_3d);

  Gt(a, b);

  // The result is checked through its textual form; each slab of `a` is
  // compared elementwise against the single slab of `b`.
  const string expected = R"(pred[2,3,2] {
{
  { 0, 1 },
  { 0, 0 },
  { 0, 0 }
},
{
  { 0, 1 },
  { 1, 0 },
  { 0, 1 }
}
})";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}
2960 
2961 XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
2962   XlaBuilder builder(TestName());
2963 
2964   std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
2965   std::unique_ptr<Array4D<float>> operand_b_4d(new Array4D<float>(2, 3, 4, 5));
2966   std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5));
2967   float value = 0.0;
2968   for (int64_t p = 0; p < 2; ++p) {
2969     for (int64_t z = 0; z < 3; ++z) {
2970       for (int64_t y = 0; y < 4; ++y) {
2971         for (int64_t x = 0; x < 5; ++x) {
2972           (*operand_a_4d)(p, z, y, x) = value;
2973           (*operand_b_4d)(p, z, y, x) = 2.0 * value;
2974           (*expected_4d)(p, z, y, x) = 3.0 * value;
2975           value += 0.1;
2976         }
2977       }
2978     }
2979   }
2980 
2981   auto a = ConstantR4FromArray4D<float>(&builder, *operand_a_4d);
2982   auto b = ConstantR4FromArray4D<float>(&builder, *operand_b_4d);
2983   Add(a, b);
2984 
2985   ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
2986 }
2987 
XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
  // Adds a (3) vector to a (2, 3, 4, 5) array, broadcasting the vector along
  // dimension 1: expected(p, z, y, x) = a(p, z, y, x) + b[z].
  XlaBuilder builder(TestName());

  constexpr int64_t kN0 = 2;
  constexpr int64_t kN1 = 3;
  constexpr int64_t kN2 = 4;
  constexpr int64_t kN3 = 5;

  // The arrays are held by value; no heap-allocated owner is needed.
  Array4D<float> operand_a_4d(kN0, kN1, kN2, kN3);
  Array4D<float> expected_4d(kN0, kN1, kN2, kN3);
  std::vector<float> operand_b_1d(kN1);
  std::iota(operand_b_1d.begin(), operand_b_1d.end(), 1.0f);  // {1, 2, 3}

  float value = 0.0f;
  for (int64_t p = 0; p < kN0; ++p) {
    for (int64_t z = 0; z < kN1; ++z) {
      for (int64_t y = 0; y < kN2; ++y) {
        for (int64_t x = 0; x < kN3; ++x) {
          operand_a_4d(p, z, y, x) = value;
          expected_4d(p, z, y, x) = value + operand_b_1d[z];
          value += 0.1f;
        }
      }
    }
  }

  auto a = ConstantR4FromArray4D<float>(&builder, operand_a_4d);
  auto b = ConstantR1<float>(&builder, operand_b_1d);
  Add(a, b, /*broadcast_dimensions=*/{1});

  ComputeAndCompareR4<float>(&builder, expected_4d, {}, error_spec_);
}
3015 
XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
  // Adds a 16-element ramp {1, ..., 16} along dimension 1 of a 16x16x2x2
  // array of ones. The array is passed as a literal with the explicit
  // layout {0, 1, 2, 3}.
  constexpr int kDim0 = 16;
  constexpr int kDim1 = 16;
  constexpr int kDim2 = 2;
  constexpr int kDim3 = 2;

  // `values` starts as the operand and is later updated in place to become
  // the expected result.
  Array4D<float> values(kDim0, kDim1, kDim2, kDim3);
  values.Fill(1.0);
  std::vector<float> ramp(kDim1);
  std::iota(ramp.begin(), ramp.end(), 1.0);

  XlaBuilder builder(TestName());
  Literal operand_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      values, LayoutUtil::MakeLayout({0, 1, 2, 3}));
  auto operand = ConstantLiteral(&builder, operand_literal);
  auto addend = ConstantR1<float>(&builder, ramp);
  Add(operand, addend, {1});

  // Mirror the broadcast on the host: every element in row i1 gains ramp[i1].
  for (int a = 0; a < kDim0; ++a) {
    for (int b = 0; b < kDim1; ++b) {
      const float bump = ramp[b];
      for (int c = 0; c < kDim2; ++c) {
        for (int d = 0; d < kDim3; ++d) {
          values(a, b, c, d) += bump;
        }
      }
    }
  }
  ComputeAndCompareR4<float>(&builder, values, {}, error_spec_);
}
3044 
3045 // Show that we can't add two opaques.
XLA_TEST_F(ArrayElementwiseOpTest,CannotAddOpaques)3046 XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
3047   XlaBuilder builder(TestName());
3048   auto shape = ShapeUtil::MakeOpaqueShape();
3049   auto x = Parameter(&builder, 0, shape, "x");
3050   Add(x, x);
3051   auto computation_status = builder.Build();
3052   ASSERT_FALSE(computation_status.ok());
3053   EXPECT_THAT(computation_status.status().ToString(),
3054               ::testing::ContainsRegex(
3055                   "Expected array argument for lhs of binary operation"));
3056 }
3057 
XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) {
  // Two (2, 3) operands combined with the identity broadcast mapping {0, 1}
  // must behave like a plain elementwise add.
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<float>(&builder, {{-2.5f, 3.14f, 1.0f},
                                          {2.25f, -10.0f, 3.33f}});
  auto rhs = ConstantR2<float>(&builder, {{-1.5f, 8.14f, 42.0f},
                                          {-1.0f, -4.0f, 5.55f}});
  Add(lhs, rhs, /*broadcast_dimensions=*/{0, 1});

  const Array2D<float> expected({{-4.0f, 11.28f, 43.0f},
                                 {1.25f, -14.0f, 8.88f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
3070 
XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) {
  // For same-rank operands, a permuted broadcast mapping ({1, 0}) is an
  // error. It must surface when the computation is built.
  XlaBuilder builder(TestName());
  auto lhs = ConstantR2<float>(&builder, {{-2.5f, 3.14f, 1.0f},
                                          {2.25f, -10.0f, 3.33f}});
  auto rhs = ConstantR2<float>(&builder, {{-1.5f, 8.14f, 42.0f},
                                          {-1.0f, -4.0f, 5.55f}});
  Add(lhs, rhs, /*broadcast_dimensions=*/{1, 0});

  const auto build_result = builder.Build();
  ASSERT_FALSE(build_result.ok());
  EXPECT_THAT(build_result.status().error_message(),
              ::testing::ContainsRegex("must.*be the identity"));
}
3084 
3085 // Regression test for b/31927799. "slice - y" is fused and requires implicit
3086 // broadcast.
XLA_TEST_F(ArrayElementwiseOpTest,ImplicitBroadcastInFusedExpressions)3087 XLA_TEST_F(ArrayElementwiseOpTest, ImplicitBroadcastInFusedExpressions) {
3088   XlaBuilder builder(TestName());
3089   auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
3090   auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
3091   auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
3092   auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
3093 
3094   auto x = Parameter(&builder, 0, x_literal.shape(), "x");
3095   auto y = Parameter(&builder, 1, y_literal.shape(), "y");
3096   auto slice = Slice(x, {1}, {2}, {1});
3097   Sub(slice, y);
3098 
3099   ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()},
3100                              error_spec_);
3101 }
3102 
3103 INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount,
3104                         ArrayElementwiseOpTestParamCount,
3105                         ::testing::Values(127, 128, 129, 17 * 4096));
3106 
3107 }  // namespace
3108 }  // namespace xla
3109