• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cmath>
17 #include <limits>
18 #include <memory>
19 
20 #include "absl/strings/str_cat.h"
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/xla/client/global_data.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/client/xla_computation.h"
26 #include "tensorflow/compiler/xla/literal.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
32 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
33 #include "tensorflow/compiler/xla/tests/test_macros.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/platform/test.h"
36 
37 namespace xla {
38 namespace {
39 
40 class ScalarComputationsTest : public ClientLibraryTestBase {
41  public:
42   ErrorSpec error_spec_{0.0001};
43 
44  protected:
45   // A template for building and running a binary comparison test.
46   template <typename NativeT>
TestCompare(NativeT lhs,NativeT rhs,bool expected,const std::function<XlaOp (const XlaOp &,const XlaOp &,absl::Span<const int64_t>)> & op)47   void TestCompare(NativeT lhs, NativeT rhs, bool expected,
48                    const std::function<XlaOp(const XlaOp&, const XlaOp&,
49                                              absl::Span<const int64_t>)>& op) {
50     XlaBuilder builder(TestName());
51     XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
52     XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
53     op(lhs_op, rhs_op, {});
54     ComputeAndCompareR0<bool>(&builder, expected, {});
55   }
56 
57   template <typename NativeT>
TestMinMax(NativeT lhs,NativeT rhs,NativeT expected,const std::function<XlaOp (const XlaOp &,const XlaOp &,absl::Span<const int64_t>)> & op)58   void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected,
59                   const std::function<XlaOp(const XlaOp&, const XlaOp&,
60                                             absl::Span<const int64_t>)>& op) {
61     XlaBuilder builder(TestName());
62     XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
63     XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
64     op(lhs_op, rhs_op, {});
65     ComputeAndCompareR0<NativeT>(&builder, expected, {});
66   }
67 };
68 
// The computation is a single constant, so that constant is its result.
XLA_TEST_F(ScalarComputationsTest, ReturnScalarF32) {
  XlaBuilder b(TestName());
  ConstantR0<float>(&b, 2.1f);

  ComputeAndCompareR0<float>(&b, 2.1f, {}, error_spec_);
}

// Unary negation of a positive F32 constant.
XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) {
  XlaBuilder b(TestName());
  const XlaOp operand = ConstantR0<float>(&b, 2.1f);
  Neg(operand);

  ComputeAndCompareR0<float>(&b, -2.1f, {}, error_spec_);
}

// Unary negation of a positive S32 constant.
XLA_TEST_F(ScalarComputationsTest, NegateScalarS32) {
  XlaBuilder b(TestName());
  const XlaOp operand = ConstantR0<int32_t>(&b, 2);
  Neg(operand);

  ComputeAndCompareR0<int32_t>(&b, -2, {});
}
89 
// Each Add test below adds two scalar constants and checks the sum.
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF32) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<float>(&b, 2.1f);
  const XlaOp rhs = ConstantR0<float>(&b, 5.5f);
  Add(lhs, rhs);

  ComputeAndCompareR0<float>(&b, 7.6f, {}, error_spec_);
}

XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS32) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<int32_t>(&b, 2);
  const XlaOp rhs = ConstantR0<int32_t>(&b, 5);
  Add(lhs, rhs);

  ComputeAndCompareR0<int32_t>(&b, 7, {});
}

XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU32) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<uint32_t>(&b, 35);
  const XlaOp rhs = ConstantR0<uint32_t>(&b, 57);
  Add(lhs, rhs);

  ComputeAndCompareR0<uint32_t>(&b, 92, {});
}

XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU8) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<uint8_t>(&b, 35);
  const XlaOp rhs = ConstantR0<uint8_t>(&b, 57);
  Add(lhs, rhs);

  ComputeAndCompareR0<uint8_t>(&b, 92, {});
}

// 2^63 + (2^63 + 1) wraps modulo 2^64; the host computes the expectation with
// the same well-defined unsigned wrap-around.
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) {
  XlaBuilder b(TestName());
  const uint64_t lhs_value = uint64_t{1} << 63;
  const uint64_t rhs_value = lhs_value + 1;
  Add(ConstantR0<uint64_t>(&b, lhs_value), ConstantR0<uint64_t>(&b, rhs_value));

  ComputeAndCompareR0<uint64_t>(&b, lhs_value + rhs_value, {});
}

// 2^62 + (2^62 - 1) is exactly INT64_MAX, exercising the high bits of S64
// without overflowing.
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) {
  XlaBuilder b(TestName());
  const int64_t lhs_value = int64_t{1} << 62;
  const int64_t rhs_value = lhs_value - 1;
  Add(ConstantR0<int64_t>(&b, lhs_value), ConstantR0<int64_t>(&b, rhs_value));

  ComputeAndCompareR0<int64_t>(&b, lhs_value + rhs_value, {});
}

XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<double>(&b, 0.25);
  const XlaOp rhs = ConstantR0<double>(&b, 3.5);
  Add(lhs, rhs);

  ComputeAndCompareR0<double>(&b, 3.75, {});
}
142 
// Subtraction whose result is negative, for F32 and S32.
XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) {
  XlaBuilder b(TestName());
  const XlaOp minuend = ConstantR0<float>(&b, 2.1f);
  const XlaOp subtrahend = ConstantR0<float>(&b, 5.5f);
  Sub(minuend, subtrahend);

  ComputeAndCompareR0<float>(&b, -3.4f, {}, error_spec_);
}

XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
  XlaBuilder b(TestName());
  const XlaOp minuend = ConstantR0<int32_t>(&b, 2);
  const XlaOp subtrahend = ConstantR0<int32_t>(&b, 5);
  Sub(minuend, subtrahend);

  ComputeAndCompareR0<int32_t>(&b, -3, {});
}
156 
// Converts an S64 parameter to F32 and checks the result matches the host's
// int64_t -> float conversion. 3 * 2^35 needs more than 32 bits but is
// exactly representable in F32.
XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
  XlaBuilder builder(TestName());
  auto a = Parameter(&builder, 0, ShapeUtil::MakeShape(S64, {}), "a");
  ConvertElementType(a, F32);

  const int64_t value = int64_t{3} << 35;
  Literal a_literal = LiteralUtil::CreateR0<int64_t>(value);
  // A failed transfer should fail this test cleanly rather than abort the
  // whole binary via StatusOr::value().
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> a_data,
                          client_->TransferToServer(a_literal));
  ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
                             {a_data.get()});
}
169 
// (2.1 * 5.5) * 0.5 in F32.
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
  XlaBuilder b(TestName());
  const XlaOp product =
      Mul(ConstantR0<float>(&b, 2.1f), ConstantR0<float>(&b, 5.5f));
  Mul(product, ConstantR0<float>(&b, 0.5f));

  ComputeAndCompareR0<float>(&b, 5.775f, {}, error_spec_);
}

// pi * e * 0.5772156649015328 (the Euler-Mascheroni constant) in F64, checked
// with a tolerance of a few ulps.
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF64) {
  XlaBuilder b(TestName());
  const XlaOp product = Mul(ConstantR0<double>(&b, 3.1415926535897932),
                            ConstantR0<double>(&b, 2.7182818284590452));
  Mul(product, ConstantR0<double>(&b, 0.5772156649015328));

  ComputeAndCompareR0<double>(&b, 4.929268367422896, {}, ErrorSpec{3.6e-15});
}
187 
// Multiplies every ordered pair from a set of interesting S32 values,
// including pairs whose product overflows, and checks two's-complement
// wrap-around.
XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsS32) {
  const std::vector<int32_t> data = {0,
                                     1,
                                     -1,
                                     1234,
                                     0x1a243514,
                                     std::numeric_limits<int32_t>::max(),
                                     std::numeric_limits<int32_t>::min()};

  for (int32_t x : data) {
    for (int32_t y : data) {
      // Identify the failing pair in gtest output.
      SCOPED_TRACE(absl::StrCat("x=", x, " y=", y));
      XlaBuilder builder(TestName());
      Mul(ConstantR0<int32_t>(&builder, x), ConstantR0<int32_t>(&builder, y));

      // Signed integer overflow is undefined behavior in C++, so multiply as
      // uint32_t (which wraps modulo 2^32) and convert back. The conversion is
      // explicit because narrowing an out-of-range uint32_t to int32_t is
      // implementation-defined before C++20.
      const int32_t expected = static_cast<int32_t>(
          static_cast<uint32_t>(x) * static_cast<uint32_t>(y));

      ComputeAndCompareR0<int32_t>(&builder, expected, {});
    }
  }
}

// Multiplies every ordered pair from a set of interesting U32 values; the
// host's unsigned multiplication wraps modulo 2^32, matching U32 semantics.
XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) {
  const std::vector<uint32_t> data = {0,          1,          0xDEADBEEF, 1234,
                                      0x1a243514, 0xFFFFFFFF, 0x80808080};

  for (uint32_t x : data) {
    for (uint32_t y : data) {
      // Identify the failing pair in gtest output.
      SCOPED_TRACE(absl::StrCat("x=", x, " y=", y));
      XlaBuilder builder(TestName());
      Mul(ConstantR0<uint32_t>(&builder, x), ConstantR0<uint32_t>(&builder, y));

      const uint32_t expected = x * y;
      ComputeAndCompareR0<uint32_t>(&builder, expected, {});
    }
  }
}
226 
// (2 * 5) * 1 in S32.
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
  XlaBuilder b(TestName());
  const XlaOp product =
      Mul(ConstantR0<int32_t>(&b, 2), ConstantR0<int32_t>(&b, 5));
  Mul(product, ConstantR0<int32_t>(&b, 1));

  ComputeAndCompareR0<int32_t>(&b, 10, {});
}
234 
// Same product as MulThreeScalarsF32, but the three operands are passed in as
// runtime parameters instead of being embedded as constants.
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
  XlaBuilder builder(TestName());
  Literal a_literal = LiteralUtil::CreateR0<float>(2.1f);
  Literal b_literal = LiteralUtil::CreateR0<float>(5.5f);
  Literal c_literal = LiteralUtil::CreateR0<float>(0.5f);

  // A failed transfer should fail this test cleanly rather than abort the
  // whole binary via StatusOr::value().
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> a_data,
                          client_->TransferToServer(a_literal));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> b_data,
                          client_->TransferToServer(b_literal));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> c_data,
                          client_->TransferToServer(c_literal));

  XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a");
  XlaOp b = Parameter(&builder, 1, b_literal.shape(), "b");
  XlaOp c = Parameter(&builder, 2, c_literal.shape(), "c");
  Mul(Mul(a, b), c);

  ComputeAndCompareR0<float>(&builder, 5.775f,
                             {a_data.get(), b_data.get(), c_data.get()},
                             error_spec_);
}
257 
// 5.0 / 2.5 in F32.
XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) {
  XlaBuilder b(TestName());
  const XlaOp dividend = ConstantR0<float>(&b, 5.0f);
  const XlaOp divisor = ConstantR0<float>(&b, 2.5f);
  Div(dividend, divisor);

  ComputeAndCompareR0<float>(&b, 2.0f, {}, error_spec_);
}

// 2.5 rem 5.0: the dividend is smaller than the divisor, so it is returned
// unchanged.
XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) {
  XlaBuilder b(TestName());
  const XlaOp dividend = ConstantR0<float>(&b, 2.5f);
  const XlaOp divisor = ConstantR0<float>(&b, 5.0f);
  Rem(dividend, divisor);

  ComputeAndCompareR0<float>(&b, 2.5f, {}, error_spec_);
}
271 
// One row of the signed 32-bit division table: `dividend / divisor` is
// expected to yield `quotient`, and `dividend % divisor` to yield `remainder`.
struct DivS32Params {
  std::int32_t dividend;
  std::int32_t divisor;
  std::int32_t quotient;
  std::int32_t remainder;
};

// gtest printer hook: renders a row as "{dividend, divisor, quotient,
// remainder}" when gtest prints parameter values.
void PrintTo(const DivS32Params& p, std::ostream* os) {
  const char* const sep = ", ";
  *os << '{' << p.dividend << sep << p.divisor << sep << p.quotient << sep
      << p.remainder << '}';
}
283 
284 class DivS32Test : public ClientLibraryTestBase,
285                    public ::testing::WithParamInterface<DivS32Params> {};
286 
// Both operands are constants in the computation; checks the quotient.
XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) {
  const DivS32Params& params = GetParam();
  XlaBuilder b(TestName());
  const XlaOp dividend = ConstantR0<int32_t>(&b, params.dividend);
  const XlaOp divisor = ConstantR0<int32_t>(&b, params.divisor);
  Div(dividend, divisor);

  ComputeAndCompareR0<int32_t>(&b, params.quotient, {});
}

// Both operands are constants in the computation; checks the remainder.
XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) {
  const DivS32Params& params = GetParam();
  XlaBuilder b(TestName());
  const XlaOp dividend = ConstantR0<int32_t>(&b, params.dividend);
  const XlaOp divisor = ConstantR0<int32_t>(&b, params.divisor);
  Rem(dividend, divisor);

  ComputeAndCompareR0<int32_t>(&b, params.remainder, {});
}
304 
// Like DivideTwoScalarsS32, but both operands are runtime parameters.
XLA_TEST_P(DivS32Test, DivideTwoScalarsNonConstS32) {
  const DivS32Params& params = GetParam();
  XlaBuilder b(TestName());
  XlaOp dividend_op;
  XlaOp divisor_op;
  auto dividend_data = CreateR0Parameter<int32_t>(params.dividend, 0,
                                                  "dividend", &b, &dividend_op);
  auto divisor_data = CreateR0Parameter<int32_t>(params.divisor, 1, "divisor",
                                                 &b, &divisor_op);
  Div(dividend_op, divisor_op);

  ComputeAndCompareR0<int32_t>(&b, params.quotient,
                               {dividend_data.get(), divisor_data.get()});
}

// Like RemainderTwoScalarsS32, but both operands are runtime parameters.
XLA_TEST_P(DivS32Test, RemainderTwoScalarsNonConstDivisorS32) {
  const DivS32Params& params = GetParam();
  XlaBuilder b(TestName());
  XlaOp dividend_op;
  XlaOp divisor_op;
  auto dividend_data = CreateR0Parameter<int32_t>(params.dividend, 0,
                                                  "dividend", &b, &dividend_op);
  auto divisor_data = CreateR0Parameter<int32_t>(params.divisor, 1, "divisor",
                                                 &b, &divisor_op);
  Rem(dividend_op, divisor_op);

  ComputeAndCompareR0<int32_t>(&b, params.remainder,
                               {dividend_data.get(), divisor_data.get()});
}
334 
335 INSTANTIATE_TEST_CASE_P(
336     DivS32Test_Instantiation, DivS32Test,
337     ::testing::Values(
338         // Positive divisors.
339         DivS32Params{5, 2, 2, 1},      //
340         DivS32Params{-5, 2, -2, -1},   //
341         DivS32Params{17, 3, 5, 2},     //
342         DivS32Params{-17, 3, -5, -2},  //
343         // Negative divisors.
344         DivS32Params{5, -2, -2, 1},    //
345         DivS32Params{-5, -2, 2, -1},   //
346         DivS32Params{17, -3, -5, 2},   //
347         DivS32Params{-17, -3, 5, -2},  //
348         // Large positive divisors.
349         DivS32Params{INT32_MIN, 7919, -271181, -1309},             //
350         DivS32Params{INT32_MIN, INT32_MAX, -1, -1},                //
351         DivS32Params{INT32_MIN + 1, INT32_MAX, -1, 0},             //
352         DivS32Params{INT32_MIN + 2, INT32_MAX, 0, INT32_MIN + 2},  //
353         DivS32Params{INT32_MIN, 0x40000000, -2, 0},                //
354         DivS32Params{INT32_MIN + 1, 0x40000000, -1, -0x3fffffff},  //
355         // Large negative divisors.
356         DivS32Params{INT32_MIN, INT32_MIN, 1, 0},                  //
357         DivS32Params{INT32_MIN, INT32_MIN + 1, 1, -1},             //
358         DivS32Params{INT32_MIN + 1, INT32_MIN, 0, INT32_MIN + 1},  //
359         DivS32Params{INT32_MAX, INT32_MIN, 0, INT32_MAX},          //
360         DivS32Params{INT32_MAX, INT32_MIN + 1, -1, 0},             //
361         DivS32Params{INT32_MIN, -0x40000000, 2, 0},                //
362         DivS32Params{INT32_MIN + 1, -0x40000000, 1, -0x3fffffff}));
363 
// Builds a parameterized U32 division once, then runs it for every
// (dividend, nonzero divisor) pair drawn from a set of interesting values and
// compares against the host's unsigned `/`.
XLA_TEST_F(ScalarComputationsTest, DivU32s) {
  // clang-format off
  // Some interesting values to test.
  std::vector<uint32_t> vals = {
    0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX};
  // clang-format on

  XlaComputation div_computation;
  {
    XlaBuilder builder(TestName());

    XlaOp dividend =
        Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {}), "dividend");
    XlaOp divisor =
        Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "divisor");
    Div(dividend, divisor);
    TF_ASSERT_OK_AND_ASSIGN(div_computation, builder.Build());
  }

  for (uint32_t divisor : vals) {
    // There is no host-side reference for division by zero; skip it.
    if (divisor == 0) {
      continue;
    }
    // The divisor does not change across the inner loop; transfer it once.
    Literal divisor_literal = LiteralUtil::CreateR0<uint32_t>(divisor);
    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> divisor_data,
                            client_->TransferToServer(divisor_literal));
    for (uint32_t dividend : vals) {
      Literal dividend_literal = LiteralUtil::CreateR0<uint32_t>(dividend);
      TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> dividend_data,
                              client_->TransferToServer(dividend_literal));
      // Fail the test cleanly (rather than abort via .value()) if execution
      // fails.
      TF_ASSERT_OK_AND_ASSIGN(
          Literal actual_literal,
          client_->ExecuteAndTransfer(div_computation,
                                      {dividend_data.get(), divisor_data.get()},
                                      &execution_options_));
      Literal expected_literal =
          LiteralUtil::CreateR0<uint32_t>(dividend / divisor);
      EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal))
          << "dividend=" << dividend << " divisor=" << divisor;
    }
  }
}
405 
// Builds a parameterized U32 remainder once, then runs it for every
// (dividend, nonzero divisor) pair drawn from a set of interesting values and
// compares against the host's unsigned `%`.
XLA_TEST_F(ScalarComputationsTest, RemU32s) {
  // clang-format off
  // Some interesting values to test.
  std::vector<uint32_t> vals = {
    0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX};
  // clang-format on

  XlaComputation rem_computation;
  {
    XlaBuilder builder(TestName());

    XlaOp dividend =
        Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {}), "dividend");
    XlaOp divisor =
        Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "divisor");
    Rem(dividend, divisor);
    TF_ASSERT_OK_AND_ASSIGN(rem_computation, builder.Build());
  }

  for (uint32_t divisor : vals) {
    // There is no host-side reference for a zero divisor; skip it.
    if (divisor == 0) {
      continue;
    }
    // The divisor does not change across the inner loop; transfer it once.
    Literal divisor_literal = LiteralUtil::CreateR0<uint32_t>(divisor);
    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> divisor_data,
                            client_->TransferToServer(divisor_literal));
    for (uint32_t dividend : vals) {
      Literal dividend_literal = LiteralUtil::CreateR0<uint32_t>(dividend);
      TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> dividend_data,
                              client_->TransferToServer(dividend_literal));
      // Fail the test cleanly (rather than abort via .value()) if execution
      // fails.
      TF_ASSERT_OK_AND_ASSIGN(
          Literal actual_literal,
          client_->ExecuteAndTransfer(rem_computation,
                                      {dividend_data.get(), divisor_data.get()},
                                      &execution_options_));
      Literal expected_literal =
          LiteralUtil::CreateR0<uint32_t>(dividend % divisor);
      EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal))
          << "dividend=" << dividend << " divisor=" << divisor;
    }
  }
}
447 
// Runtime dividend, constant divisor: 87919 % 80000 == 7919.
XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
  XlaBuilder b(TestName());
  const XlaOp dividend = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "x");
  const XlaOp divisor = ConstantR0<int32_t>(&b, 80000);
  Rem(dividend, divisor);

  Literal dividend_literal = LiteralUtil::CreateR0<int32_t>(87919);
  TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
                          client_->TransferToServer(dividend_literal));
  ComputeAndCompareR0<int32_t>(&b, 7919, {dividend_data.get()});
}
457 
// Checks that 0xFFFFFFFE / 2 == 0x7FFFFFFF. If XLA wrongly treated U32 as S32,
// the result would be -2 / 2 == -1 (0xFFFFFFFF).
XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) {
  XlaBuilder b(TestName());
  const XlaOp dividend = ConstantR0<uint32_t>(&b, 0xFFFFFFFE);
  const XlaOp divisor = ConstantR0<uint32_t>(&b, 2);
  Div(dividend, divisor);

  ComputeAndCompareR0<uint32_t>(&b, 0x7FFFFFFF, {});
}

// 11 % 3 == 2 for U32.
XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) {
  XlaBuilder b(TestName());
  const XlaOp dividend = ConstantR0<uint32_t>(&b, 11);
  const XlaOp divisor = ConstantR0<uint32_t>(&b, 3);
  Rem(dividend, divisor);

  ComputeAndCompareR0<uint32_t>(&b, 2, {});
}
474 
XLA_TEST_F(ScalarComputationsTest,AndBool)475 XLA_TEST_F(ScalarComputationsTest, AndBool) {
476   for (bool x : {false, true}) {
477     for (bool y : {false, true}) {
478       XlaBuilder builder(TestName());
479       And(ConstantR0<bool>(&builder, x), ConstantR0<bool>(&builder, y));
480 
481       ComputeAndCompareR0<bool>(&builder, x && y, {});
482     }
483   }
484 }
485 
XLA_TEST_F(ScalarComputationsTest,AndS32)486 XLA_TEST_F(ScalarComputationsTest, AndS32) {
487   for (int32_t x : {0, 8}) {
488     for (int32_t y : {1, -16}) {
489       XlaBuilder builder(TestName());
490       And(ConstantR0<int32_t>(&builder, x), ConstantR0<int32_t>(&builder, y));
491 
492       ComputeAndCompareR0<int32_t>(&builder, x & y, {});
493     }
494   }
495 }
496 
XLA_TEST_F(ScalarComputationsTest,AndU32)497 XLA_TEST_F(ScalarComputationsTest, AndU32) {
498   for (uint32_t x : {0, 8}) {
499     for (uint32_t y : {1, 16}) {
500       XlaBuilder builder(TestName());
501       And(ConstantR0<uint32_t>(&builder, x), ConstantR0<uint32_t>(&builder, y));
502 
503       ComputeAndCompareR0<uint32_t>(&builder, x & y, {});
504     }
505   }
506 }
507 
XLA_TEST_F(ScalarComputationsTest,OrBool)508 XLA_TEST_F(ScalarComputationsTest, OrBool) {
509   for (bool x : {false, true}) {
510     for (bool y : {false, true}) {
511       XlaBuilder builder(TestName());
512       Or(ConstantR0<bool>(&builder, x), ConstantR0<bool>(&builder, y));
513 
514       ComputeAndCompareR0<bool>(&builder, x || y, {});
515     }
516   }
517 }
518 
XLA_TEST_F(ScalarComputationsTest,OrS32)519 XLA_TEST_F(ScalarComputationsTest, OrS32) {
520   for (int32_t x : {0, 8}) {
521     for (int32_t y : {1, -16}) {
522       XlaBuilder builder(TestName());
523       Or(ConstantR0<int32_t>(&builder, x), ConstantR0<int32_t>(&builder, y));
524 
525       ComputeAndCompareR0<int32_t>(&builder, x | y, {});
526     }
527   }
528 }
529 
XLA_TEST_F(ScalarComputationsTest,OrU32)530 XLA_TEST_F(ScalarComputationsTest, OrU32) {
531   for (uint32_t x : {0, 8}) {
532     for (uint32_t y : {1, 16}) {
533       XlaBuilder builder(TestName());
534       Or(ConstantR0<uint32_t>(&builder, x), ConstantR0<uint32_t>(&builder, y));
535 
536       ComputeAndCompareR0<uint32_t>(&builder, x | y, {});
537     }
538   }
539 }
540 
// Not is logical negation for PRED and bitwise complement for integers.
XLA_TEST_F(ScalarComputationsTest, NotBool) {
  for (bool value : {false, true}) {
    XlaBuilder b(TestName());
    Not(ConstantR0<bool>(&b, value));

    ComputeAndCompareR0<bool>(&b, !value, {});
  }
}

XLA_TEST_F(ScalarComputationsTest, NotS32) {
  for (int32_t value : {-1, 0, 1}) {
    XlaBuilder b(TestName());
    Not(ConstantR0<int32_t>(&b, value));

    ComputeAndCompareR0<int32_t>(&b, ~value, {});
  }
}

XLA_TEST_F(ScalarComputationsTest, NotU32) {
  for (uint32_t value : {0, 1, 2}) {
    XlaBuilder b(TestName());
    Not(ConstantR0<uint32_t>(&b, value));

    ComputeAndCompareR0<uint32_t>(&b, ~value, {});
  }
}
567 
// Select(pred, on_true, on_false) with a true predicate picks on_true.
XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) {
  XlaBuilder b(TestName());
  const XlaOp pred = ConstantR0<bool>(&b, true);
  const XlaOp on_true = ConstantR0<float>(&b, 123.0f);
  const XlaOp on_false = ConstantR0<float>(&b, 42.0f);
  Select(pred, on_true, on_false);

  ComputeAndCompareR0<float>(&b, 123.0f, {}, error_spec_);
}

// Select(pred, on_true, on_false) with a false predicate picks on_false.
XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) {
  XlaBuilder b(TestName());
  const XlaOp pred = ConstantR0<bool>(&b, false);
  const XlaOp on_true = ConstantR0<float>(&b, 123.0f);
  const XlaOp on_false = ConstantR0<float>(&b, 42.0f);
  Select(pred, on_true, on_false);

  ComputeAndCompareR0<float>(&b, 42.0f, {}, error_spec_);
}
585 
586 // This test is an explicit version of what is happening in the following
587 // templatized comparison tests.
XLA_TEST_F(ScalarComputationsTest,CompareGtScalar)588 XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) {
589   XlaBuilder builder(TestName());
590   Gt(ConstantR0<float>(&builder, 2.0f), ConstantR0<float>(&builder, 1.0f));
591 
592   ComputeAndCompareR0<bool>(&builder, true, {});
593 }
594 
// Signed 32-bit comparisons via the TestCompare helper.
XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) {
  TestCompare<int32_t>(2, 1, /*expected=*/false, &Eq);
}
XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) {
  TestCompare<int32_t>(3, 3, /*expected=*/true, &Eq);
}

XLA_TEST_F(ScalarComputationsTest, CompareNeS32) {
  TestCompare<int32_t>(2, 1, /*expected=*/true, &Ne);
}

XLA_TEST_F(ScalarComputationsTest, CompareGeS32) {
  TestCompare<int32_t>(2, 1, /*expected=*/true, &Ge);
}

XLA_TEST_F(ScalarComputationsTest, CompareGtS32) {
  TestCompare<int32_t>(1, 5, /*expected=*/false, &Gt);
}

XLA_TEST_F(ScalarComputationsTest, CompareLeS32) {
  TestCompare<int32_t>(2, 1, /*expected=*/false, &Le);
}

XLA_TEST_F(ScalarComputationsTest, CompareLtS32) {
  TestCompare<int32_t>(9, 7, /*expected=*/false, &Lt);
  // The extremes of the range must compare as signed values.
  constexpr int32_t kLowest = std::numeric_limits<int32_t>::min();
  constexpr int32_t kHighest = std::numeric_limits<int32_t>::max();
  TestCompare<int32_t>(kLowest, kHighest, /*expected=*/true, &Lt);
}
624 
625 // U32 comparisons.
XLA_TEST_F(ScalarComputationsTest,CompareEqU32False)626 XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) {
627   TestCompare<uint32_t>(2, 1, false, &Eq);
628 }
629 
XLA_TEST_F(ScalarComputationsTest,CompareNeU32)630 XLA_TEST_F(ScalarComputationsTest, CompareNeU32) {
631   TestCompare<uint32_t>(2, 1, true, &Ne);
632 }
633 
XLA_TEST_F(ScalarComputationsTest,CompareGeU32Greater)634 XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) {
635   TestCompare<uint32_t>(2, 1, true, &Ge);
636 }
637 
XLA_TEST_F(ScalarComputationsTest,CompareGeU32Equal)638 XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) {
639   TestCompare<uint32_t>(3, 3, true, &Ge);
640 }
641 
XLA_TEST_F(ScalarComputationsTest,CompareGtU32)642 XLA_TEST_F(ScalarComputationsTest, CompareGtU32) {
643   TestCompare<uint32_t>(1, 5, false, &Gt);
644   TestCompare<uint32_t>(5, 5, false, &Gt);
645   TestCompare<uint32_t>(5, 1, true, &Gt);
646 }
647 
XLA_TEST_F(ScalarComputationsTest,CompareLeU32)648 XLA_TEST_F(ScalarComputationsTest, CompareLeU32) {
649   TestCompare<uint32_t>(2, 1, false, &Le);
650 }
651 
XLA_TEST_F(ScalarComputationsTest,CompareLtU32)652 XLA_TEST_F(ScalarComputationsTest, CompareLtU32) {
653   TestCompare<uint32_t>(9, 7, false, &Lt);
654   TestCompare<uint32_t>(0, std::numeric_limits<uint32_t>::max(), true, &Lt);
655 }
656 
// F32 comparisons on ordinary finite values via the TestCompare helper.
XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) {
  TestCompare<float>(2.0f, 1.3f, /*expected=*/false, &Eq);
}

XLA_TEST_F(ScalarComputationsTest, CompareNeF32) {
  TestCompare<float>(2.0f, 1.3f, /*expected=*/true, &Ne);
}

XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) {
  TestCompare<float>(2.0f, 1.9f, /*expected=*/true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) {
  TestCompare<float>(3.5f, 3.5f, /*expected=*/true, &Ge);
}

XLA_TEST_F(ScalarComputationsTest, CompareGtF32) {
  TestCompare<float>(1.0f, 5.2f, /*expected=*/false, &Gt);
}

XLA_TEST_F(ScalarComputationsTest, CompareLeF32) {
  TestCompare<float>(2.0f, 1.2f, /*expected=*/false, &Le);
}

XLA_TEST_F(ScalarComputationsTest, CompareLtF32) {
  TestCompare<float>(9.0f, 7.2f, /*expected=*/false, &Lt);
}
684 
// F32 comparisons involving infinities and signed zeros. Each test name ends
// with its left and right operands; "Minf" means -inf and "Mzero" means -0.0.
XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) {
  TestCompare<float>(-INFINITY, -0.0f, /*expected=*/true, &Lt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) {
  // IEEE 754 comparisons treat -0.0 and 0.0 as equal.
  TestCompare<float>(-0.0f, 0.0f, /*expected=*/false, &Lt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) {
  TestCompare<float>(0.0f, INFINITY, /*expected=*/true, &Lt);
}

XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) {
  TestCompare<float>(-INFINITY, -0.0f, /*expected=*/false, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) {
  // IEEE 754 comparisons treat -0.0 and 0.0 as equal.
  TestCompare<float>(-0.0f, 0.0f, /*expected=*/true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) {
  TestCompare<float>(0.0f, INFINITY, /*expected=*/false, &Ge);
}
708 
// exp(2) ~= 7.3890561 in F32.
XLA_TEST_F(ScalarComputationsTest, ExpScalar) {
  XlaBuilder builder(TestName());
  Exp(ConstantR0<float>(&builder, 2.0f));

  ComputeAndCompareR0<float>(&builder, 7.3890562f, {}, error_spec_);
}

// log(2) ~= 0.6931472 in F32.
XLA_TEST_F(ScalarComputationsTest, LogScalar) {
  // Name the builder after the test, consistent with every other test here
  // (this used to be the hard-coded string "log").
  XlaBuilder builder(TestName());
  Log(ConstantR0<float>(&builder, 2.0f));

  ComputeAndCompareR0<float>(&builder, 0.6931471f, {}, error_spec_);
}
722 
// tanh(2) ~= 0.96402758, checked in F32 and in F64 with the shared tolerance.
XLA_TEST_F(ScalarComputationsTest, TanhScalar) {
  XlaBuilder b(TestName());
  const XlaOp operand = ConstantR0<float>(&b, 2.0f);
  Tanh(operand);

  ComputeAndCompareR0<float>(&b, 0.96402758, {}, error_spec_);
}

XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) {
  XlaBuilder b(TestName());
  const XlaOp operand = ConstantR0<double>(&b, 2.0);
  Tanh(operand);

  ComputeAndCompareR0<double>(&b, 0.96402758, {}, error_spec_);
}
736 
XLA_TEST_F(ScalarComputationsTest, PowScalar) {
  // 2 raised to the 3rd power is 8.
  XlaBuilder b(TestName());
  const XlaOp base = ConstantR0<float>(&b, 2.0f);
  const XlaOp exponent = ConstantR0<float>(&b, 3.0f);
  Pow(base, exponent);

  ComputeAndCompareR0<float>(&b, 8.0, {}, error_spec_);
}
743 
XLA_TEST_F(ScalarComputationsTest, CbrtScalar) {
  // Cube root of 2. No explicit expected value is given: ComputeAndCompare
  // presumably checks against a reference backend (confirm in
  // ClientLibraryTestBase), within error_spec_.
  XlaBuilder b(TestName());
  const XlaOp x = ConstantR0<float>(&b, 2.0f);
  Cbrt(x);

  ComputeAndCompare(&b, {}, error_spec_);
}
750 
XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) {
  // Operand 5 is above the range [-1, 3], so it clamps to the upper bound 3.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<int32_t>(&b, -1);
  const XlaOp operand = ConstantR0<int32_t>(&b, 5);
  const XlaOp upper = ConstantR0<int32_t>(&b, 3);
  Clamp(lower, operand, upper);

  ComputeAndCompareR0<int32_t>(&b, 3, {});
}
759 
XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) {
  // Operand 2 already lies within [-1, 3], so it passes through unchanged.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<int32_t>(&b, -1);
  const XlaOp operand = ConstantR0<int32_t>(&b, 2);
  const XlaOp upper = ConstantR0<int32_t>(&b, 3);
  Clamp(lower, operand, upper);

  ComputeAndCompareR0<int32_t>(&b, 2, {});
}
768 
XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) {
  // Operand -5 is below the range [-1, 3], so it clamps to the lower bound -1.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<int32_t>(&b, -1);
  const XlaOp operand = ConstantR0<int32_t>(&b, -5);
  const XlaOp upper = ConstantR0<int32_t>(&b, 3);
  Clamp(lower, operand, upper);

  ComputeAndCompareR0<int32_t>(&b, -1, {});
}
777 
XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) {
  // Unsigned operand 5 is above [1, 3], so it clamps to the upper bound 3.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<uint32_t>(&b, 1);
  const XlaOp operand = ConstantR0<uint32_t>(&b, 5);
  const XlaOp upper = ConstantR0<uint32_t>(&b, 3);
  Clamp(lower, operand, upper);

  ComputeAndCompareR0<uint32_t>(&b, 3, {});
}
786 
XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) {
  // Unsigned operand 2 lies inside [1, 3] and is returned unchanged.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<uint32_t>(&b, 1);
  const XlaOp operand = ConstantR0<uint32_t>(&b, 2);
  const XlaOp upper = ConstantR0<uint32_t>(&b, 3);
  Clamp(lower, operand, upper);

  ComputeAndCompareR0<uint32_t>(&b, 2, {});
}
795 
XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) {
  // Unsigned operand 0 is below [1, 3], so it clamps to the lower bound 1.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<uint32_t>(&b, 1);
  const XlaOp operand = ConstantR0<uint32_t>(&b, 0);
  const XlaOp upper = ConstantR0<uint32_t>(&b, 3);
  Clamp(lower, operand, upper);

  ComputeAndCompareR0<uint32_t>(&b, 1, {});
}
804 
XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) {
  // 5.0 is above [2.0, 3.0], so it clamps to the upper bound 3.0.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<float>(&b, 2.0f);
  const XlaOp operand = ConstantR0<float>(&b, 5.0f);
  const XlaOp upper = ConstantR0<float>(&b, 3.0f);
  Clamp(lower, operand, upper);

  ComputeAndCompareR0<float>(&b, 3.0, {}, error_spec_);
}
813 
XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) {
  // 2.5 lies inside [2.0, 3.0] and is returned unchanged.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<float>(&b, 2.0f);
  const XlaOp operand = ConstantR0<float>(&b, 2.5f);
  const XlaOp upper = ConstantR0<float>(&b, 3.0f);
  Clamp(lower, operand, upper);

  ComputeAndCompareR0<float>(&b, 2.5, {}, error_spec_);
}
822 
XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) {
  // -5.0 is below [2.0, 3.0], so it clamps to the lower bound 2.0.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<float>(&b, 2.0f);
  const XlaOp operand = ConstantR0<float>(&b, -5.0f);
  const XlaOp upper = ConstantR0<float>(&b, 3.0f);
  Clamp(lower, operand, upper);

  ComputeAndCompareR0<float>(&b, 2.0, {}, error_spec_);
}
831 
XLA_TEST_F(ScalarComputationsTest, MinS32Above) {
  // min(10, 3): the first operand is the larger one, so the result is 3.
  const int32_t lhs = 10;
  const int32_t rhs = 3;
  TestMinMax<int32_t>(lhs, rhs, rhs, &Min);
}
835 
XLA_TEST_F(ScalarComputationsTest, MinS32Below) {
  // min(-100, 3): the negative first operand wins.
  const int32_t lhs = -100;
  const int32_t rhs = 3;
  TestMinMax<int32_t>(lhs, rhs, lhs, &Min);
}
839 
XLA_TEST_F(ScalarComputationsTest, MaxS32Above) {
  // max(10, 3): the first operand is the larger one, so the result is 10.
  const int32_t lhs = 10;
  const int32_t rhs = 3;
  TestMinMax<int32_t>(lhs, rhs, lhs, &Max);
}
843 
XLA_TEST_F(ScalarComputationsTest, MaxS32Below) {
  // max(-100, 3): the positive second operand wins.
  const int32_t lhs = -100;
  const int32_t rhs = 3;
  TestMinMax<int32_t>(lhs, rhs, rhs, &Max);
}
847 
XLA_TEST_F(ScalarComputationsTest, MinU32Above) {
  // min(large, 3) == 3 for an operand that also fits in int32.
  const uint32_t large = std::numeric_limits<int32_t>::max();
  TestMinMax<uint32_t>(large, 3, 3, &Min);

  // uint32 max has its top bit set, so it would read as -1 if a backend
  // compared u32 values as signed; this case catches that mistake, which the
  // int32-max case above cannot.
  const uint32_t largest = std::numeric_limits<uint32_t>::max();
  TestMinMax<uint32_t>(largest, 3, 3, &Min);
}
852 
XLA_TEST_F(ScalarComputationsTest, MinU32Below) {
  // min(0, 5): zero, the smallest unsigned value, wins.
  const uint32_t lhs = 0;
  const uint32_t rhs = 5;
  TestMinMax<uint32_t>(lhs, rhs, lhs, &Min);
}
856 
XLA_TEST_F(ScalarComputationsTest, MaxU32Above) {
  // max(large, 3) == large for an operand that also fits in int32.
  const uint32_t large = std::numeric_limits<int32_t>::max();
  TestMinMax<uint32_t>(large, 3, large, &Max);

  // uint32 max has its top bit set, so it would read as -1 if a backend
  // compared u32 values as signed; this case catches that mistake, which the
  // int32-max case above cannot.
  const uint32_t largest = std::numeric_limits<uint32_t>::max();
  TestMinMax<uint32_t>(largest, 3, largest, &Max);
}
861 
XLA_TEST_F(ScalarComputationsTest, MaxU32Below) {
  // max(0, 5): the second operand wins.
  const uint32_t lhs = 0;
  const uint32_t rhs = 5;
  TestMinMax<uint32_t>(lhs, rhs, rhs, &Max);
}
865 
XLA_TEST_F(ScalarComputationsTest, MinF32Above) {
  // min(10.1, 3.1) == 3.1.
  const float lhs = 10.1f;
  const float rhs = 3.1f;
  TestMinMax<float>(lhs, rhs, rhs, &Min);
}
869 
XLA_TEST_F(ScalarComputationsTest, MinF32Below) {
  // min(-100.1, 3.1) == -100.1.
  const float lhs = -100.1f;
  const float rhs = 3.1f;
  TestMinMax<float>(lhs, rhs, lhs, &Min);
}
873 
XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) {
  // With fast-math disabled, a NaN in either operand position must make
  // Min return NaN.
  SetFastMathDisabled(true);
  const float nan = NAN;
  TestMinMax<float>(nan, 3.1f, nan, &Min);   // NaN on the left.
  TestMinMax<float>(-3.1f, nan, nan, &Min);  // NaN on the right.
}
879 
XLA_TEST_F(ScalarComputationsTest, MaxF32Above) {
  // max(10.1, 3.1) == 10.1.
  const float lhs = 10.1f;
  const float rhs = 3.1f;
  TestMinMax<float>(lhs, rhs, lhs, &Max);
}
883 
XLA_TEST_F(ScalarComputationsTest, MaxF32Below) {
  // max(-100.1, 3.1) == 3.1.
  const float lhs = -100.1f;
  const float rhs = 3.1f;
  TestMinMax<float>(lhs, rhs, rhs, &Max);
}
887 
XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) {
  // With fast-math disabled, a NaN in either operand position must make
  // Max return NaN.
  SetFastMathDisabled(true);
  const float nan = NAN;
  TestMinMax<float>(nan, 3.1f, nan, &Max);   // NaN on the left.
  TestMinMax<float>(-3.1f, nan, nan, &Max);  // NaN on the right.
}
893 
XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) {
  // Evaluates (1 * (3 - 1) * (7 + 0) - 4) / 20 = (14 - 4) / 20 = 0.5.
  XlaBuilder b(TestName());
  const auto f32 = [&b](float value) { return ConstantR0<float>(&b, value); };

  const XlaOp three_minus_one = Sub(f32(3), f32(1));
  const XlaOp seven_plus_zero = Add(f32(7), f32(0));
  const XlaOp product = Mul(f32(1), Mul(three_minus_one, seven_plus_zero));
  Div(Sub(product, f32(4)), f32(20));

  ComputeAndCompareR0<float>(&b, 0.5, {}, error_spec_);
}
905 
XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) {
  // Evaluates 1 * (3 - 1) * (7 + 0) - 4 = 14 - 4 = 10 in int32.
  XlaBuilder b(TestName());
  const auto s32 = [&b](int32_t value) {
    return ConstantR0<int32_t>(&b, value);
  };

  const XlaOp three_minus_one = Sub(s32(3), s32(1));
  const XlaOp seven_plus_zero = Add(s32(7), s32(0));
  const XlaOp product = Mul(s32(1), Mul(three_minus_one, seven_plus_zero));
  Sub(product, s32(4));

  ComputeAndCompareR0<int32_t>(&b, 10, {});
}
916 
XLA_TEST_F(ScalarComputationsTest, RoundScalar) {
  // 1.4 rounds to the nearest integer, 1.0.
  XlaBuilder b(TestName());
  const XlaOp x = ConstantR0<float>(&b, 1.4f);
  Round(x);

  ComputeAndCompareR0<float>(&b, 1.0f, {}, error_spec_);
}
923 
924 }  // namespace
925 }  // namespace xla
926