• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cmath>
17 #include <limits>
18 #include <memory>
19 
20 #include "absl/strings/str_cat.h"
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/xla/client/global_data.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/client/xla_computation.h"
26 #include "tensorflow/compiler/xla/literal.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
32 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
33 #include "tensorflow/compiler/xla/tests/test_macros.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/platform/test.h"
36 #include "tensorflow/core/platform/types.h"
37 
38 namespace xla {
39 namespace {
40 
41 class ScalarComputationsTest : public ClientLibraryTestBase {
42  public:
43   ErrorSpec error_spec_{0.0001};
44 
45  protected:
46   // A template for building and running a binary comparison test.
47   template <typename NativeT>
TestCompare(NativeT lhs,NativeT rhs,bool expected,const std::function<XlaOp (const XlaOp &,const XlaOp &,absl::Span<const int64>)> & op)48   void TestCompare(NativeT lhs, NativeT rhs, bool expected,
49                    const std::function<XlaOp(const XlaOp&, const XlaOp&,
50                                              absl::Span<const int64>)>& op) {
51     XlaBuilder builder(TestName());
52     XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
53     XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
54     op(lhs_op, rhs_op, {});
55     ComputeAndCompareR0<bool>(&builder, expected, {});
56   }
57 
58   template <typename NativeT>
TestMinMax(NativeT lhs,NativeT rhs,NativeT expected,const std::function<XlaOp (const XlaOp &,const XlaOp &,absl::Span<const int64>)> & op)59   void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected,
60                   const std::function<XlaOp(const XlaOp&, const XlaOp&,
61                                             absl::Span<const int64>)>& op) {
62     XlaBuilder builder(TestName());
63     XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
64     XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
65     op(lhs_op, rhs_op, {});
66     ComputeAndCompareR0<NativeT>(&builder, expected, {});
67   }
68 };
69 
// A computation consisting of a single f32 constant returns that constant.
XLA_TEST_F(ScalarComputationsTest, ReturnScalarF32) {
  XlaBuilder b(TestName());
  const float value = 2.1f;
  ConstantR0<float>(&b, value);

  ComputeAndCompareR0<float>(&b, value, {}, error_spec_);
}
76 
// Neg on an f32 scalar flips its sign.
XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) {
  XlaBuilder b(TestName());
  XlaOp operand = ConstantR0<float>(&b, 2.1f);
  Neg(operand);

  ComputeAndCompareR0<float>(&b, -2.1f, {}, error_spec_);
}
83 
// Neg on an s32 scalar flips its sign.
XLA_TEST_F(ScalarComputationsTest, NegateScalarS32) {
  XlaBuilder b(TestName());
  XlaOp operand = ConstantR0<int32>(&b, 2);
  Neg(operand);

  ComputeAndCompareR0<int32>(&b, -2, {});
}
90 
// 2.1 + 5.5 == 7.6 (within error_spec_).
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF32) {
  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR0<float>(&b, 2.1f);
  XlaOp rhs = ConstantR0<float>(&b, 5.5f);
  Add(lhs, rhs);

  ComputeAndCompareR0<float>(&b, 7.6f, {}, error_spec_);
}
97 
// 2 + 5 == 7 in s32.
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS32) {
  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR0<int32>(&b, 2);
  XlaOp rhs = ConstantR0<int32>(&b, 5);
  Add(lhs, rhs);

  ComputeAndCompareR0<int32>(&b, 7, {});
}
104 
// 35 + 57 == 92 in u32.
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU32) {
  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR0<uint32>(&b, 35);
  XlaOp rhs = ConstantR0<uint32>(&b, 57);
  Add(lhs, rhs);

  ComputeAndCompareR0<uint32>(&b, 92, {});
}
111 
// 35 + 57 == 92 in u8 (no wrap-around, the sum fits in 8 bits).
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU8) {
  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR0<uint8>(&b, 35);
  XlaOp rhs = ConstantR0<uint8>(&b, 57);
  Add(lhs, rhs);

  ComputeAndCompareR0<uint8>(&b, 92, {});
}
118 
// Both u64 operands have the top bit set, so their sum wraps modulo 2^64. The
// host computes the reference value with the same unsigned wrap-around.
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) {
  XlaBuilder b(TestName());
  const uint64 lhs = uint64{1} << 63;
  const uint64 rhs = lhs + 1;
  Add(ConstantR0<uint64>(&b, lhs), ConstantR0<uint64>(&b, rhs));

  const uint64 expected = lhs + rhs;
  ComputeAndCompareR0<uint64>(&b, expected, {});
}
127 
// 2^62 + (2^62 - 1) == 2^63 - 1, the largest s64 value, without overflowing.
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) {
  XlaBuilder b(TestName());
  const int64_t lhs = static_cast<int64>(1) << 62;
  const int64_t rhs = lhs - 1;
  Add(ConstantR0<int64>(&b, lhs), ConstantR0<int64>(&b, rhs));

  const int64_t expected = lhs + rhs;
  ComputeAndCompareR0<int64>(&b, expected, {});
}
136 
// 0.25 + 3.5 == 3.75; all three values are exact in binary, so no tolerance.
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) {
  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR0<double>(&b, 0.25);
  XlaOp rhs = ConstantR0<double>(&b, 3.5);
  Add(lhs, rhs);

  ComputeAndCompareR0<double>(&b, 3.75, {});
}
143 
// 2.1 - 5.5 == -3.4 (within error_spec_).
XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) {
  XlaBuilder b(TestName());
  XlaOp minuend = ConstantR0<float>(&b, 2.1f);
  XlaOp subtrahend = ConstantR0<float>(&b, 5.5f);
  Sub(minuend, subtrahend);

  ComputeAndCompareR0<float>(&b, -3.4f, {}, error_spec_);
}
150 
// 2 - 5 == -3 in s32.
XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
  XlaBuilder b(TestName());
  XlaOp minuend = ConstantR0<int32>(&b, 2);
  XlaOp subtrahend = ConstantR0<int32>(&b, 5);
  Sub(minuend, subtrahend);

  ComputeAndCompareR0<int32>(&b, -3, {});
}
157 
// Converts a runtime s64 parameter to f32 and checks the converted value.
XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
  XlaBuilder builder(TestName());
  auto a = Parameter(&builder, 0, ShapeUtil::MakeShape(S64, {}), "a");
  ConvertElementType(a, F32);

  // 3 * 2^35 has only two significant bits, so it is exactly representable as
  // an f32 and an exact comparison is safe.
  int64_t value = 3LL << 35;
  Literal a_literal = LiteralUtil::CreateR0<int64>(value);
  // Surface a transfer failure as a test failure (as the other tests in this
  // file do) instead of aborting the whole binary via ConsumeValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> a_data,
                          client_->TransferToServer(a_literal));
  ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
                             {a_data.get()});
}
170 
// (2.1 * 5.5) * 0.5 == 5.775 (within error_spec_).
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
  XlaBuilder b(TestName());
  XlaOp product =
      Mul(ConstantR0<float>(&b, 2.1f), ConstantR0<float>(&b, 5.5f));
  Mul(product, ConstantR0<float>(&b, 0.5f));

  ComputeAndCompareR0<float>(&b, 5.775f, {}, error_spec_);
}
178 
// pi * e * (Euler-Mascheroni gamma) in f64, checked to a tight tolerance.
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF64) {
  XlaBuilder b(TestName());
  XlaOp pi_times_e = Mul(ConstantR0<double>(&b, 3.1415926535897932),
                         ConstantR0<double>(&b, 2.7182818284590452));
  Mul(pi_times_e, ConstantR0<double>(&b, 0.5772156649015328));

  const ErrorSpec tight_spec{3.6e-15};
  ComputeAndCompareR0<double>(&b, 4.929268367422896, {}, tight_spec);
}
188 
// Multiplies every pair from a set of interesting s32 values, including the
// extremes, and expects two's-complement wrap-around on overflow.
XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsS32) {
  const std::vector<int32> operands = {0,
                                       1,
                                       -1,
                                       1234,
                                       0x1a243514,
                                       std::numeric_limits<int32>::max(),
                                       std::numeric_limits<int32>::min()};

  for (const int32_t lhs : operands) {
    for (const int32_t rhs : operands) {
      XlaBuilder b(TestName());
      Mul(ConstantR0<int32>(&b, lhs), ConstantR0<int32>(&b, rhs));

      // Signed overflow is undefined behavior in C++, so form the reference
      // product on the unsigned reinterpretations and convert it back.
      const uint32 wrapped =
          static_cast<uint32>(lhs) * static_cast<uint32>(rhs);
      const int32_t expected = static_cast<int32_t>(wrapped);

      ComputeAndCompareR0<int32>(&b, expected, {});
    }
  }
}
212 
// Multiplies every pair from a set of u32 values, many of whose products
// overflow 32 bits, and expects wrap-around modulo 2^32.
XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) {
  const std::vector<uint32> operands = {0,          1,          0xDEADBEEF,
                                        1234,       0x1a243514, 0xFFFFFFFF,
                                        0x80808080};

  for (const uint32 lhs : operands) {
    for (const uint32 rhs : operands) {
      XlaBuilder b(TestName());
      Mul(ConstantR0<uint32>(&b, lhs), ConstantR0<uint32>(&b, rhs));

      // Unsigned host multiplication wraps the same way, so it is the
      // reference value.
      const uint32 expected = lhs * rhs;
      ComputeAndCompareR0<uint32>(&b, expected, {});
    }
  }
}
227 
// Chains two s32 multiplies: (2 * 5) * -3 == -30.
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
  XlaBuilder builder(TestName());
  // The outer factor is neither 0 nor 1, so the result depends on the second
  // Mul actually being computed; making it negative also covers sign handling.
  Mul(Mul(ConstantR0<int32>(&builder, 2), ConstantR0<int32>(&builder, 5)),
      ConstantR0<int32>(&builder, -3));

  ComputeAndCompareR0<int32>(&builder, -30, {});
}
235 
// Same product as MulThreeScalarsF32, but the three factors are runtime
// parameters transferred to the device instead of constants.
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
  XlaBuilder builder(TestName());
  Literal a_literal = LiteralUtil::CreateR0<float>(2.1f);
  Literal b_literal = LiteralUtil::CreateR0<float>(5.5f);
  Literal c_literal = LiteralUtil::CreateR0<float>(0.5f);

  // Report transfer failures as test failures rather than aborting the binary
  // via ConsumeValueOrDie(), matching the other tests in this file.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> a_data,
                          client_->TransferToServer(a_literal));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> b_data,
                          client_->TransferToServer(b_literal));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> c_data,
                          client_->TransferToServer(c_literal));

  XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a");
  XlaOp b = Parameter(&builder, 1, b_literal.shape(), "b");
  XlaOp c = Parameter(&builder, 2, c_literal.shape(), "c");
  Mul(Mul(a, b), c);

  ComputeAndCompareR0<float>(&builder, 5.775f,
                             {a_data.get(), b_data.get(), c_data.get()},
                             error_spec_);
}
258 
// 5.0 / 2.5 == 2.0 in f32.
XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) {
  XlaBuilder b(TestName());
  XlaOp dividend = ConstantR0<float>(&b, 5.0f);
  XlaOp divisor = ConstantR0<float>(&b, 2.5f);
  Div(dividend, divisor);

  ComputeAndCompareR0<float>(&b, 2.0f, {}, error_spec_);
}
265 
// 2.5 rem 5.0 == 2.5: a dividend smaller than the divisor is returned as-is.
XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) {
  XlaBuilder b(TestName());
  XlaOp dividend = ConstantR0<float>(&b, 2.5f);
  XlaOp divisor = ConstantR0<float>(&b, 5.0f);
  Rem(dividend, divisor);

  ComputeAndCompareR0<float>(&b, 2.5f, {}, error_spec_);
}
272 
273 struct DivS32Params {
274   int32 dividend;
275   int32 divisor;
276   int32 quotient;
277   int32 remainder;
278 };
279 
PrintTo(const DivS32Params & p,std::ostream * os)280 void PrintTo(const DivS32Params& p, std::ostream* os) {
281   *os << "{" << p.dividend << ", " << p.divisor << ", " << p.quotient << ", "
282       << p.remainder << "}";
283 }
284 
285 class DivS32Test : public ClientLibraryTestBase,
286                    public ::testing::WithParamInterface<DivS32Params> {};
287 
// Div on two s32 constants must yield the table's quotient.
XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) {
  const DivS32Params& p = GetParam();
  XlaBuilder b(TestName());
  XlaOp dividend = ConstantR0<int32>(&b, p.dividend);
  XlaOp divisor = ConstantR0<int32>(&b, p.divisor);
  Div(dividend, divisor);

  ComputeAndCompareR0<int32>(&b, p.quotient, {});
}
296 
// Rem on two s32 constants must yield the table's remainder.
XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) {
  const DivS32Params& p = GetParam();
  XlaBuilder b(TestName());
  XlaOp dividend = ConstantR0<int32>(&b, p.dividend);
  XlaOp divisor = ConstantR0<int32>(&b, p.divisor);
  Rem(dividend, divisor);

  ComputeAndCompareR0<int32>(&b, p.remainder, {});
}
305 
// Like DivideTwoScalarsS32, but both operands are runtime parameters
// (presumably so the division cannot be constant-folded away).
XLA_TEST_P(DivS32Test, DivideTwoScalarsNonConstS32) {
  const DivS32Params& p = GetParam();
  XlaBuilder b(TestName());

  XlaOp dividend;
  XlaOp divisor;
  auto dividend_data =
      CreateR0Parameter<int32>(p.dividend, 0, "dividend", &b, &dividend);
  auto divisor_data =
      CreateR0Parameter<int32>(p.divisor, 1, "divisor", &b, &divisor);
  Div(dividend, divisor);

  ComputeAndCompareR0<int32>(&b, p.quotient,
                             {dividend_data.get(), divisor_data.get()});
}
320 
// Like RemainderTwoScalarsS32, but both operands are runtime parameters.
XLA_TEST_P(DivS32Test, RemainderTwoScalarsNonConstDivisorS32) {
  const DivS32Params& p = GetParam();
  XlaBuilder b(TestName());

  XlaOp dividend;
  XlaOp divisor;
  auto dividend_data =
      CreateR0Parameter<int32>(p.dividend, 0, "dividend", &b, &dividend);
  auto divisor_data =
      CreateR0Parameter<int32>(p.divisor, 1, "divisor", &b, &divisor);
  Rem(dividend, divisor);

  ComputeAndCompareR0<int32>(&b, p.remainder,
                             {dividend_data.get(), divisor_data.get()});
}
335 
336 INSTANTIATE_TEST_CASE_P(
337     DivS32Test_Instantiation, DivS32Test,
338     ::testing::Values(
339         // Positive divisors.
340         DivS32Params{5, 2, 2, 1},      //
341         DivS32Params{-5, 2, -2, -1},   //
342         DivS32Params{17, 3, 5, 2},     //
343         DivS32Params{-17, 3, -5, -2},  //
344         // Negative divisors.
345         DivS32Params{5, -2, -2, 1},    //
346         DivS32Params{-5, -2, 2, -1},   //
347         DivS32Params{17, -3, -5, 2},   //
348         DivS32Params{-17, -3, 5, -2},  //
349         // Large positive divisors.
350         DivS32Params{INT32_MIN, 7919, -271181, -1309},             //
351         DivS32Params{INT32_MIN, INT32_MAX, -1, -1},                //
352         DivS32Params{INT32_MIN + 1, INT32_MAX, -1, 0},             //
353         DivS32Params{INT32_MIN + 2, INT32_MAX, 0, INT32_MIN + 2},  //
354         DivS32Params{INT32_MIN, 0x40000000, -2, 0},                //
355         DivS32Params{INT32_MIN + 1, 0x40000000, -1, -0x3fffffff},  //
356         // Large negative divisors.
357         DivS32Params{INT32_MIN, INT32_MIN, 1, 0},                  //
358         DivS32Params{INT32_MIN, INT32_MIN + 1, 1, -1},             //
359         DivS32Params{INT32_MIN + 1, INT32_MIN, 0, INT32_MIN + 1},  //
360         DivS32Params{INT32_MAX, INT32_MIN, 0, INT32_MAX},          //
361         DivS32Params{INT32_MAX, INT32_MIN + 1, -1, 0},             //
362         DivS32Params{INT32_MIN, -0x40000000, 2, 0},                //
363         DivS32Params{INT32_MIN + 1, -0x40000000, 1, -0x3fffffff}));
364 
// Checks u32 Div against host division for every pair of interesting values,
// skipping a zero divisor (which has no host reference result).
XLA_TEST_F(ScalarComputationsTest, DivU32s) {
  // clang-format off
  // Some interesting values to test.
  std::vector<uint32> vals = {
    0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX};
  // clang-format on

  // Build the two-parameter computation once and reuse it for every pair.
  XlaComputation div_computation;
  {
    XlaBuilder builder(TestName());

    XlaOp dividend =
        Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {}), "dividend");
    XlaOp divisor =
        Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "divisor");
    Div(dividend, divisor);
    TF_ASSERT_OK_AND_ASSIGN(div_computation, builder.Build());
  }

  for (uint32 divisor : vals) {
    if (divisor == 0) {
      continue;
    }
    // The divisor does not change in the inner loop, so transfer it once.
    Literal divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
    TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
                            client_->TransferToServer(divisor_literal));
    for (uint32 dividend : vals) {
      Literal dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
      TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
                              client_->TransferToServer(dividend_literal));
      // Report an execution error as a test failure instead of aborting via
      // ConsumeValueOrDie().
      TF_ASSERT_OK_AND_ASSIGN(
          Literal actual_literal,
          client_->ExecuteAndTransfer(div_computation,
                                      {dividend_data.get(), divisor_data.get()},
                                      &execution_options_));
      Literal expected_literal =
          LiteralUtil::CreateR0<uint32>(dividend / divisor);
      // Name the failing pair; otherwise a mismatch inside the loop does not
      // say which operands produced it.
      EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal))
          << "dividend=" << dividend << " divisor=" << divisor;
    }
  }
}
406 
// Checks u32 Rem against host remainder for every pair of interesting values,
// skipping a zero divisor (which has no host reference result).
XLA_TEST_F(ScalarComputationsTest, RemU32s) {
  // clang-format off
  // Some interesting values to test.
  std::vector<uint32> vals = {
    0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX};
  // clang-format on

  // Build the two-parameter computation once and reuse it for every pair.
  XlaComputation rem_computation;
  {
    XlaBuilder builder(TestName());

    XlaOp dividend =
        Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {}), "dividend");
    XlaOp divisor =
        Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "divisor");
    Rem(dividend, divisor);
    TF_ASSERT_OK_AND_ASSIGN(rem_computation, builder.Build());
  }

  for (uint32 divisor : vals) {
    if (divisor == 0) {
      continue;
    }
    // The divisor does not change in the inner loop, so transfer it once.
    Literal divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
    TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
                            client_->TransferToServer(divisor_literal));
    for (uint32 dividend : vals) {
      Literal dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
      TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
                              client_->TransferToServer(dividend_literal));
      // Report an execution error as a test failure instead of aborting via
      // ConsumeValueOrDie().
      TF_ASSERT_OK_AND_ASSIGN(
          Literal actual_literal,
          client_->ExecuteAndTransfer(rem_computation,
                                      {dividend_data.get(), divisor_data.get()},
                                      &execution_options_));
      Literal expected_literal =
          LiteralUtil::CreateR0<uint32>(dividend % divisor);
      // Name the failing pair; otherwise a mismatch inside the loop does not
      // say which operands produced it.
      EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal))
          << "dividend=" << dividend << " divisor=" << divisor;
    }
  }
}
448 
// Rem with a runtime dividend and a constant divisor:
// 87919 == 1 * 80000 + 7919.
XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
  XlaBuilder b(TestName());
  XlaOp dividend = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "x");
  Rem(dividend, ConstantR0<int32>(&b, 80000));

  Literal dividend_literal = LiteralUtil::CreateR0<int32>(87919);
  TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
                          client_->TransferToServer(dividend_literal));
  ComputeAndCompareR0<int32>(&b, 7919, {dividend_data.get()});
}
458 
// Guards against U32 being treated as S32: 0xFFFFFFFE / 2 must be 0x7FFFFFFF,
// whereas a signed interpretation would compute -2 / 2 = -1 (0xFFFFFFFF).
XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) {
  XlaBuilder b(TestName());
  XlaOp dividend = ConstantR0<uint32>(&b, 0xFFFFFFFE);
  XlaOp divisor = ConstantR0<uint32>(&b, 2);
  Div(dividend, divisor);

  ComputeAndCompareR0<uint32>(&b, 0x7FFFFFFF, {});
}
468 
// 11 rem 3 == 2 in u32.
XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) {
  XlaBuilder b(TestName());
  XlaOp dividend = ConstantR0<uint32>(&b, 11);
  XlaOp divisor = ConstantR0<uint32>(&b, 3);
  Rem(dividend, divisor);

  ComputeAndCompareR0<uint32>(&b, 2, {});
}
475 
// Logical And over the full boolean truth table.
XLA_TEST_F(ScalarComputationsTest, AndBool) {
  for (const bool lhs : {false, true}) {
    for (const bool rhs : {false, true}) {
      XlaBuilder b(TestName());
      And(ConstantR0<bool>(&b, lhs), ConstantR0<bool>(&b, rhs));

      const bool expected = lhs && rhs;
      ComputeAndCompareR0<bool>(&b, expected, {});
    }
  }
}
486 
// Bitwise And on s32, including a negative (sign-extended) mask.
XLA_TEST_F(ScalarComputationsTest, AndS32) {
  for (const int32_t lhs : {0, 8}) {
    for (const int32_t rhs : {1, -16}) {
      XlaBuilder b(TestName());
      And(ConstantR0<int32>(&b, lhs), ConstantR0<int32>(&b, rhs));

      const int32_t expected = lhs & rhs;
      ComputeAndCompareR0<int32>(&b, expected, {});
    }
  }
}
497 
// Bitwise And on u32.
XLA_TEST_F(ScalarComputationsTest, AndU32) {
  for (const uint32 lhs : {0, 8}) {
    for (const uint32 rhs : {1, 16}) {
      XlaBuilder b(TestName());
      And(ConstantR0<uint32>(&b, lhs), ConstantR0<uint32>(&b, rhs));

      const uint32 expected = lhs & rhs;
      ComputeAndCompareR0<uint32>(&b, expected, {});
    }
  }
}
508 
// Logical Or over the full boolean truth table.
XLA_TEST_F(ScalarComputationsTest, OrBool) {
  for (const bool lhs : {false, true}) {
    for (const bool rhs : {false, true}) {
      XlaBuilder b(TestName());
      Or(ConstantR0<bool>(&b, lhs), ConstantR0<bool>(&b, rhs));

      const bool expected = lhs || rhs;
      ComputeAndCompareR0<bool>(&b, expected, {});
    }
  }
}
519 
// Bitwise Or on s32, including a negative operand.
XLA_TEST_F(ScalarComputationsTest, OrS32) {
  for (const int32_t lhs : {0, 8}) {
    for (const int32_t rhs : {1, -16}) {
      XlaBuilder b(TestName());
      Or(ConstantR0<int32>(&b, lhs), ConstantR0<int32>(&b, rhs));

      const int32_t expected = lhs | rhs;
      ComputeAndCompareR0<int32>(&b, expected, {});
    }
  }
}
530 
// Bitwise Or on u32.
XLA_TEST_F(ScalarComputationsTest, OrU32) {
  for (const uint32 lhs : {0, 8}) {
    for (const uint32 rhs : {1, 16}) {
      XlaBuilder b(TestName());
      Or(ConstantR0<uint32>(&b, lhs), ConstantR0<uint32>(&b, rhs));

      const uint32 expected = lhs | rhs;
      ComputeAndCompareR0<uint32>(&b, expected, {});
    }
  }
}
541 
// Not on a pred scalar is logical negation.
XLA_TEST_F(ScalarComputationsTest, NotBool) {
  for (const bool value : {false, true}) {
    XlaBuilder b(TestName());
    Not(ConstantR0<bool>(&b, value));

    const bool expected = !value;
    ComputeAndCompareR0<bool>(&b, expected, {});
  }
}
550 
// Not on an s32 scalar is bitwise complement.
XLA_TEST_F(ScalarComputationsTest, NotS32) {
  for (const int32_t value : {-1, 0, 1}) {
    XlaBuilder b(TestName());
    Not(ConstantR0<int32>(&b, value));

    const int32_t expected = ~value;
    ComputeAndCompareR0<int32>(&b, expected, {});
  }
}
559 
// Not on a u32 scalar is bitwise complement.
XLA_TEST_F(ScalarComputationsTest, NotU32) {
  for (const uint32 value : {0, 1, 2}) {
    XlaBuilder b(TestName());
    Not(ConstantR0<uint32>(&b, value));

    const uint32 expected = ~value;
    ComputeAndCompareR0<uint32>(&b, expected, {});
  }
}
568 
// Select with a true predicate yields the on-true operand.
XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) {
  XlaBuilder b(TestName());
  XlaOp predicate = ConstantR0<bool>(&b, true);
  XlaOp on_true = ConstantR0<float>(&b, 123.0f);
  XlaOp on_false = ConstantR0<float>(&b, 42.0f);
  Select(predicate, on_true, on_false);

  ComputeAndCompareR0<float>(&b, 123.0f, {}, error_spec_);
}
577 
// Select with a false predicate yields the on-false operand.
XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) {
  XlaBuilder b(TestName());
  XlaOp predicate = ConstantR0<bool>(&b, false);
  XlaOp on_true = ConstantR0<float>(&b, 123.0f);
  XlaOp on_false = ConstantR0<float>(&b, 42.0f);
  Select(predicate, on_true, on_false);

  ComputeAndCompareR0<float>(&b, 42.0f, {}, error_spec_);
}
586 
587 // This test is an explicit version of what is happening in the following
588 // templatized comparison tests.
XLA_TEST_F(ScalarComputationsTest,CompareGtScalar)589 XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) {
590   XlaBuilder builder(TestName());
591   Gt(ConstantR0<float>(&builder, 2.0f), ConstantR0<float>(&builder, 1.0f));
592 
593   ComputeAndCompareR0<bool>(&builder, true, {});
594 }
595 
// S32 comparisons, each a single TestCompare(lhs, rhs, expected, op) call.
XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) {
  TestCompare<int32>(2, 1, /*expected=*/false, &Eq);
}
XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) {
  TestCompare<int32>(3, 3, /*expected=*/true, &Eq);
}

XLA_TEST_F(ScalarComputationsTest, CompareNeS32) {
  TestCompare<int32>(2, 1, /*expected=*/true, &Ne);
}

XLA_TEST_F(ScalarComputationsTest, CompareGeS32) {
  TestCompare<int32>(2, 1, /*expected=*/true, &Ge);
}

XLA_TEST_F(ScalarComputationsTest, CompareGtS32) {
  TestCompare<int32>(1, 5, /*expected=*/false, &Gt);
}

XLA_TEST_F(ScalarComputationsTest, CompareLeS32) {
  TestCompare<int32>(2, 1, /*expected=*/false, &Le);
}

XLA_TEST_F(ScalarComputationsTest, CompareLtS32) {
  TestCompare<int32>(9, 7, /*expected=*/false, &Lt);
  // The two ends of the int32 range.
  const int32 lowest = std::numeric_limits<int32>::min();
  const int32 highest = std::numeric_limits<int32>::max();
  TestCompare<int32>(lowest, highest, /*expected=*/true, &Lt);
}
625 
626 // U32 comparisons.
XLA_TEST_F(ScalarComputationsTest,CompareEqU32False)627 XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) {
628   TestCompare<uint32>(2, 1, false, &Eq);
629 }
630 
XLA_TEST_F(ScalarComputationsTest,CompareNeU32)631 XLA_TEST_F(ScalarComputationsTest, CompareNeU32) {
632   TestCompare<uint32>(2, 1, true, &Ne);
633 }
634 
XLA_TEST_F(ScalarComputationsTest,CompareGeU32Greater)635 XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) {
636   TestCompare<uint32>(2, 1, true, &Ge);
637 }
638 
XLA_TEST_F(ScalarComputationsTest,CompareGeU32Equal)639 XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) {
640   TestCompare<uint32>(3, 3, true, &Ge);
641 }
642 
XLA_TEST_F(ScalarComputationsTest,CompareGtU32)643 XLA_TEST_F(ScalarComputationsTest, CompareGtU32) {
644   TestCompare<uint32>(1, 5, false, &Gt);
645   TestCompare<uint32>(5, 5, false, &Gt);
646   TestCompare<uint32>(5, 1, true, &Gt);
647 }
648 
XLA_TEST_F(ScalarComputationsTest,CompareLeU32)649 XLA_TEST_F(ScalarComputationsTest, CompareLeU32) {
650   TestCompare<uint32>(2, 1, false, &Le);
651 }
652 
XLA_TEST_F(ScalarComputationsTest,CompareLtU32)653 XLA_TEST_F(ScalarComputationsTest, CompareLtU32) {
654   TestCompare<uint32>(9, 7, false, &Lt);
655   TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true, &Lt);
656 }
657 
// F32 comparisons on ordinary finite values, built through TestCompare.
XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) {
  TestCompare<float>(2.0, 1.3, /*expected=*/false, &Eq);
}

XLA_TEST_F(ScalarComputationsTest, CompareNeF32) {
  TestCompare<float>(2.0, 1.3, /*expected=*/true, &Ne);
}

XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) {
  TestCompare<float>(2.0, 1.9, /*expected=*/true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) {
  TestCompare<float>(3.5, 3.5, /*expected=*/true, &Ge);
}

XLA_TEST_F(ScalarComputationsTest, CompareGtF32) {
  TestCompare<float>(1.0, 5.2, /*expected=*/false, &Gt);
}

XLA_TEST_F(ScalarComputationsTest, CompareLeF32) {
  TestCompare<float>(2.0, 1.2, /*expected=*/false, &Le);
}

XLA_TEST_F(ScalarComputationsTest, CompareLtF32) {
  TestCompare<float>(9.0, 7.2, /*expected=*/false, &Lt);
}
685 
// F32 comparisons involving exceptional values. Each test name ends with its
// left and right operands, spelling -inf as "Minf" and -0.0 as "Mzero".
XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) {
  TestCompare<float>(-INFINITY, -0.0, /*expected=*/true, &Lt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) {
  // IEEE 754 comparisons treat -0.0 and 0.0 as equal, so -0.0 < 0.0 is false.
  TestCompare<float>(-0.0, 0.0, /*expected=*/false, &Lt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) {
  TestCompare<float>(0.0, INFINITY, /*expected=*/true, &Lt);
}

XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) {
  TestCompare<float>(-INFINITY, -0.0, /*expected=*/false, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) {
  // IEEE 754 comparisons treat -0.0 and 0.0 as equal, so -0.0 >= 0.0 holds.
  TestCompare<float>(-0.0, 0.0, /*expected=*/true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) {
  TestCompare<float>(0.0, INFINITY, /*expected=*/false, &Ge);
}
709 
// exp(2) ~= 7.3890562.
XLA_TEST_F(ScalarComputationsTest, ExpScalar) {
  XlaBuilder b(TestName());
  XlaOp exponent = ConstantR0<float>(&b, 2.0f);
  Exp(exponent);

  ComputeAndCompareR0<float>(&b, 7.3890562, {}, error_spec_);
}
716 
// ln(2) ~= 0.6931471.
XLA_TEST_F(ScalarComputationsTest, LogScalar) {
  // Name the builder after the test, like every other test in this file,
  // rather than the hard-coded "log".
  XlaBuilder builder(TestName());
  Log(ConstantR0<float>(&builder, 2.0f));

  ComputeAndCompareR0<float>(&builder, 0.6931471, {}, error_spec_);
}
723 
// tanh(2) ~= 0.96402758 in f32.
XLA_TEST_F(ScalarComputationsTest, TanhScalar) {
  XlaBuilder b(TestName());
  XlaOp operand = ConstantR0<float>(&b, 2.0f);
  Tanh(operand);

  ComputeAndCompareR0<float>(&b, 0.96402758, {}, error_spec_);
}
730 
// tanh(2) in f64. The expected value has only ~8 significant digits and is
// checked with the file-wide error_spec_.
XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) {
  XlaBuilder b(TestName());
  XlaOp operand = ConstantR0<double>(&b, 2.0);
  Tanh(operand);

  ComputeAndCompareR0<double>(&b, 0.96402758, {}, error_spec_);
}
737 
// 2^3 == 8 in f32.
XLA_TEST_F(ScalarComputationsTest, PowScalar) {
  XlaBuilder b(TestName());
  XlaOp base = ConstantR0<float>(&b, 2.0f);
  XlaOp exponent = ConstantR0<float>(&b, 3.0f);
  Pow(base, exponent);

  ComputeAndCompareR0<float>(&b, 8.0, {}, error_spec_);
}
744 
// Cube root of 2. No expected value is spelled out here; ComputeAndCompare
// checks the result itself (presumably against a reference backend).
XLA_TEST_F(ScalarComputationsTest, CbrtScalar) {
  XlaBuilder b(TestName());
  XlaOp operand = ConstantR0<float>(&b, 2.0f);
  Cbrt(operand);

  ComputeAndCompare(&b, {}, error_spec_);
}
751 
XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) {
  // 5 lies above [-1, 3], so it is clamped down to 3.
  XlaBuilder b(TestName());
  auto lower = ConstantR0<int32>(&b, -1);
  auto value = ConstantR0<int32>(&b, 5);
  auto upper = ConstantR0<int32>(&b, 3);
  Clamp(lower, value, upper);
  ComputeAndCompareR0<int32>(&b, 3, {});
}
760 
XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) {
  // 2 already lies inside [-1, 3], so it passes through unchanged.
  XlaBuilder b(TestName());
  auto lower = ConstantR0<int32>(&b, -1);
  auto value = ConstantR0<int32>(&b, 2);
  auto upper = ConstantR0<int32>(&b, 3);
  Clamp(lower, value, upper);
  ComputeAndCompareR0<int32>(&b, 2, {});
}
769 
XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) {
  // -5 lies below [-1, 3], so it is clamped up to -1.
  XlaBuilder b(TestName());
  auto lower = ConstantR0<int32>(&b, -1);
  auto value = ConstantR0<int32>(&b, -5);
  auto upper = ConstantR0<int32>(&b, 3);
  Clamp(lower, value, upper);
  ComputeAndCompareR0<int32>(&b, -1, {});
}
778 
XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) {
  // 5 lies above [1, 3], so it is clamped down to 3.
  XlaBuilder b(TestName());
  auto lower = ConstantR0<uint32>(&b, 1);
  auto value = ConstantR0<uint32>(&b, 5);
  auto upper = ConstantR0<uint32>(&b, 3);
  Clamp(lower, value, upper);
  ComputeAndCompareR0<uint32>(&b, 3, {});
}
787 
XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) {
  // 2 already lies inside [1, 3], so it passes through unchanged.
  XlaBuilder b(TestName());
  auto lower = ConstantR0<uint32>(&b, 1);
  auto value = ConstantR0<uint32>(&b, 2);
  auto upper = ConstantR0<uint32>(&b, 3);
  Clamp(lower, value, upper);
  ComputeAndCompareR0<uint32>(&b, 2, {});
}
796 
XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) {
  // 0 lies below [1, 3], so it is clamped up to 1.
  XlaBuilder b(TestName());
  auto lower = ConstantR0<uint32>(&b, 1);
  auto value = ConstantR0<uint32>(&b, 0);
  auto upper = ConstantR0<uint32>(&b, 3);
  Clamp(lower, value, upper);
  ComputeAndCompareR0<uint32>(&b, 1, {});
}
805 
XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) {
  // 5.0 lies above [2.0, 3.0], so it is clamped down to 3.0.
  XlaBuilder b(TestName());
  auto lower = ConstantR0<float>(&b, 2.0f);
  auto value = ConstantR0<float>(&b, 5.0f);
  auto upper = ConstantR0<float>(&b, 3.0f);
  Clamp(lower, value, upper);
  ComputeAndCompareR0<float>(&b, 3.0, {}, error_spec_);
}
814 
XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) {
  // 2.5 already lies inside [2.0, 3.0], so it passes through unchanged.
  XlaBuilder b(TestName());
  auto lower = ConstantR0<float>(&b, 2.0f);
  auto value = ConstantR0<float>(&b, 2.5f);
  auto upper = ConstantR0<float>(&b, 3.0f);
  Clamp(lower, value, upper);
  ComputeAndCompareR0<float>(&b, 2.5, {}, error_spec_);
}
823 
XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) {
  // -5.0 lies below [2.0, 3.0], so it is clamped up to 2.0.
  XlaBuilder b(TestName());
  auto lower = ConstantR0<float>(&b, 2.0f);
  auto value = ConstantR0<float>(&b, -5.0f);
  auto upper = ConstantR0<float>(&b, 3.0f);
  Clamp(lower, value, upper);
  ComputeAndCompareR0<float>(&b, 2.0, {}, error_spec_);
}
832 
XLA_TEST_F(ScalarComputationsTest, MinS32Above) {
  // min(10, 3) == 3.
  const int32 expected = 3;
  TestMinMax<int32>(10, 3, expected, &Min);
}
836 
XLA_TEST_F(ScalarComputationsTest, MinS32Below) {
  // min(-100, 3) == -100.
  const int32 negative = -100;
  TestMinMax<int32>(negative, 3, negative, &Min);
}
840 
XLA_TEST_F(ScalarComputationsTest, MaxS32Above) {
  // max(10, 3) == 10.
  const int32 bigger = 10;
  TestMinMax<int32>(bigger, 3, bigger, &Max);
}
844 
XLA_TEST_F(ScalarComputationsTest, MaxS32Below) {
  // max(-100, 3) == 3.
  const int32 expected = 3;
  TestMinMax<int32>(-100, 3, expected, &Max);
}
848 
// Min on u32 operands where the first operand is large.
XLA_TEST_F(ScalarComputationsTest, MinU32Above) {
  // Original case: a large value that still fits in int32.
  const uint32 large = std::numeric_limits<int32>::max();
  TestMinMax<uint32>(large, 3, 3, &Min);
  // A value with the top bit set. If a backend mistakenly compared the
  // operands as signed 32-bit integers, it would read this as negative and
  // return it instead of 3, so this case exercises a true unsigned compare.
  const uint32 huge = std::numeric_limits<uint32>::max();
  TestMinMax<uint32>(huge, 3, 3, &Min);
}
853 
XLA_TEST_F(ScalarComputationsTest, MinU32Below) {
  // min(0, 5) == 0: zero is the smallest u32.
  const uint32 zero = 0;
  TestMinMax<uint32>(zero, 5, zero, &Min);
}
857 
// Max on u32 operands where the first operand is large.
XLA_TEST_F(ScalarComputationsTest, MaxU32Above) {
  // Original case: a large value that still fits in int32.
  const uint32 large = std::numeric_limits<int32>::max();
  TestMinMax<uint32>(large, 3, large, &Max);
  // A value with the top bit set. If a backend mistakenly compared the
  // operands as signed 32-bit integers, it would read this as negative and
  // return 3, so this case exercises a true unsigned compare.
  const uint32 huge = std::numeric_limits<uint32>::max();
  TestMinMax<uint32>(huge, 3, huge, &Max);
}
862 
XLA_TEST_F(ScalarComputationsTest, MaxU32Below) {
  // max(0, 5) == 5.
  const uint32 expected = 5;
  TestMinMax<uint32>(0, expected, expected, &Max);
}
866 
XLA_TEST_F(ScalarComputationsTest, MinF32Above) {
  // min(10.1, 3.1) == 3.1.
  const float smaller = 3.1f;
  TestMinMax<float>(10.1f, smaller, smaller, &Min);
}
870 
XLA_TEST_F(ScalarComputationsTest, MinF32Below) {
  // min(-100.1, 3.1) == -100.1.
  const float smaller = -100.1f;
  TestMinMax<float>(smaller, 3.1f, smaller, &Min);
}
874 
XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) {
  // Fast math is turned off before checking NaN behavior (NaN handling
  // presumably depends on that setting).
  SetFastMathDisabled(true);
  // A NaN operand on either side must produce NaN.
  const float nan = NAN;
  TestMinMax<float>(nan, 3.1f, nan, &Min);
  TestMinMax<float>(-3.1f, nan, nan, &Min);
}
880 
XLA_TEST_F(ScalarComputationsTest, MaxF32Above) {
  // max(10.1, 3.1) == 10.1.
  const float bigger = 10.1f;
  TestMinMax<float>(bigger, 3.1f, bigger, &Max);
}
884 
XLA_TEST_F(ScalarComputationsTest, MaxF32Below) {
  // max(-100.1, 3.1) == 3.1.
  const float bigger = 3.1f;
  TestMinMax<float>(-100.1f, bigger, bigger, &Max);
}
888 
XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) {
  // Fast math is turned off before checking NaN behavior (NaN handling
  // presumably depends on that setting).
  SetFastMathDisabled(true);
  // A NaN operand on either side must produce NaN.
  const float nan = NAN;
  TestMinMax<float>(nan, 3.1f, nan, &Max);
  TestMinMax<float>(-3.1f, nan, nan, &Max);
}
894 
XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) {
  // Evaluates (1 * (3 - 1) * (7 + 0) - 4) / 20 = (14 - 4) / 20 = 0.5,
  // building each sub-expression as a named intermediate.
  XlaBuilder b(TestName());
  auto difference = Sub(ConstantR0<float>(&b, 3), ConstantR0<float>(&b, 1));
  auto sum = Add(ConstantR0<float>(&b, 7), ConstantR0<float>(&b, 0));
  auto product = Mul(ConstantR0<float>(&b, 1), Mul(difference, sum));
  auto numerator = Sub(product, ConstantR0<float>(&b, 4));
  Div(numerator, ConstantR0<float>(&b, 20));

  ComputeAndCompareR0<float>(&b, 0.5, {}, error_spec_);
}
906 
XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) {
  // Evaluates 1 * (3 - 1) * (7 + 0) - 4 = 14 - 4 = 10 in int32,
  // building each sub-expression as a named intermediate.
  XlaBuilder b(TestName());
  auto difference = Sub(ConstantR0<int32>(&b, 3), ConstantR0<int32>(&b, 1));
  auto sum = Add(ConstantR0<int32>(&b, 7), ConstantR0<int32>(&b, 0));
  auto product = Mul(ConstantR0<int32>(&b, 1), Mul(difference, sum));
  Sub(product, ConstantR0<int32>(&b, 4));

  ComputeAndCompareR0<int32>(&b, 10, {});
}
917 
XLA_TEST_F(ScalarComputationsTest, RoundScalar) {
  // 1.4 rounds to the nearest integer, 1.0.
  XlaBuilder b(TestName());
  auto input = ConstantR0<float>(&b, 1.4f);
  Round(input);
  ComputeAndCompareR0<float>(&b, 1.0f, {}, error_spec_);
}
924 
925 }  // namespace
926 }  // namespace xla
927