1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cmath>
17 #include <limits>
18 #include <memory>
19
20 #include "absl/strings/str_cat.h"
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/xla/client/global_data.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/client/xla_computation.h"
26 #include "tensorflow/compiler/xla/literal.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
32 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
33 #include "tensorflow/compiler/xla/tests/test_macros.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/platform/test.h"
36
37 namespace xla {
38 namespace {
39
40 class ScalarComputationsTest : public ClientLibraryTestBase {
41 public:
42 ErrorSpec error_spec_{0.0001};
43
44 protected:
45 // A template for building and running a binary comparison test.
46 template <typename NativeT>
TestCompare(NativeT lhs,NativeT rhs,bool expected,const std::function<XlaOp (const XlaOp &,const XlaOp &,absl::Span<const int64_t>)> & op)47 void TestCompare(NativeT lhs, NativeT rhs, bool expected,
48 const std::function<XlaOp(const XlaOp&, const XlaOp&,
49 absl::Span<const int64_t>)>& op) {
50 XlaBuilder builder(TestName());
51 XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
52 XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
53 op(lhs_op, rhs_op, {});
54 ComputeAndCompareR0<bool>(&builder, expected, {});
55 }
56
57 template <typename NativeT>
TestMinMax(NativeT lhs,NativeT rhs,NativeT expected,const std::function<XlaOp (const XlaOp &,const XlaOp &,absl::Span<const int64_t>)> & op)58 void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected,
59 const std::function<XlaOp(const XlaOp&, const XlaOp&,
60 absl::Span<const int64_t>)>& op) {
61 XlaBuilder builder(TestName());
62 XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
63 XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
64 op(lhs_op, rhs_op, {});
65 ComputeAndCompareR0<NativeT>(&builder, expected, {});
66 }
67 };
68
// A computation consisting only of an F32 constant returns that constant.
XLA_TEST_F(ScalarComputationsTest, ReturnScalarF32) {
  constexpr float kValue = 2.1f;
  XlaBuilder b(TestName());
  ConstantR0<float>(&b, kValue);
  ComputeAndCompareR0<float>(&b, kValue, {}, error_spec_);
}
75
// Neg flips the sign of an F32 scalar.
XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) {
  XlaBuilder b(TestName());
  const XlaOp operand = ConstantR0<float>(&b, 2.1f);
  Neg(operand);
  ComputeAndCompareR0<float>(&b, -2.1f, {}, error_spec_);
}
82
// Neg flips the sign of an S32 scalar.
XLA_TEST_F(ScalarComputationsTest, NegateScalarS32) {
  XlaBuilder b(TestName());
  const XlaOp operand = ConstantR0<int32_t>(&b, 2);
  Neg(operand);
  ComputeAndCompareR0<int32_t>(&b, -2, {});
}
89
// 2.1 + 5.5 in F32, compared within the suite's default tolerance.
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF32) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<float>(&b, 2.1f);
  const XlaOp rhs = ConstantR0<float>(&b, 5.5f);
  Add(lhs, rhs);
  ComputeAndCompareR0<float>(&b, 7.6f, {}, error_spec_);
}
96
// 2 + 5 in S32.
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS32) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<int32_t>(&b, 2);
  const XlaOp rhs = ConstantR0<int32_t>(&b, 5);
  Add(lhs, rhs);
  ComputeAndCompareR0<int32_t>(&b, 7, {});
}
103
// 35 + 57 in U32.
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU32) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<uint32_t>(&b, 35);
  const XlaOp rhs = ConstantR0<uint32_t>(&b, 57);
  Add(lhs, rhs);
  ComputeAndCompareR0<uint32_t>(&b, 92, {});
}
110
// 35 + 57 in U8 (no wrap-around: the sum fits in 8 bits).
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU8) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<uint8_t>(&b, 35);
  const XlaOp rhs = ConstantR0<uint8_t>(&b, 57);
  Add(lhs, rhs);
  ComputeAndCompareR0<uint8_t>(&b, 92, {});
}
117
// Adds 2^63 and 2^63 + 1, so the U64 sum wraps modulo 2^64. Unsigned
// wrap-around is well defined in C++, so the host sum is the reference.
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) {
  constexpr uint64_t kHighBit = uint64_t{1} << 63;
  constexpr uint64_t kHighBitPlusOne = kHighBit + 1;
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<uint64_t>(&b, kHighBit);
  const XlaOp rhs = ConstantR0<uint64_t>(&b, kHighBitPlusOne);
  Add(lhs, rhs);
  ComputeAndCompareR0<uint64_t>(&b, kHighBit + kHighBitPlusOne, {});
}
126
// Adds 2^62 and 2^62 - 1, whose S64 sum is exactly 2^63 - 1 (no overflow).
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) {
  constexpr int64_t kTwoPow62 = int64_t{1} << 62;
  constexpr int64_t kTwoPow62MinusOne = kTwoPow62 - 1;
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<int64_t>(&b, kTwoPow62);
  const XlaOp rhs = ConstantR0<int64_t>(&b, kTwoPow62MinusOne);
  Add(lhs, rhs);
  ComputeAndCompareR0<int64_t>(&b, kTwoPow62 + kTwoPow62MinusOne, {});
}
135
// 0.25 + 3.5 in F64; every value is exactly representable, so no tolerance.
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) {
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<double>(&b, 0.25);
  const XlaOp rhs = ConstantR0<double>(&b, 3.5);
  Add(lhs, rhs);
  ComputeAndCompareR0<double>(&b, 3.75, {});
}
142
// 2.1 - 5.5 in F32.
XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) {
  XlaBuilder b(TestName());
  const XlaOp minuend = ConstantR0<float>(&b, 2.1f);
  const XlaOp subtrahend = ConstantR0<float>(&b, 5.5f);
  Sub(minuend, subtrahend);
  ComputeAndCompareR0<float>(&b, -3.4f, {}, error_spec_);
}
149
// 2 - 5 in S32.
XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
  XlaBuilder b(TestName());
  const XlaOp minuend = ConstantR0<int32_t>(&b, 2);
  const XlaOp subtrahend = ConstantR0<int32_t>(&b, 5);
  Sub(minuend, subtrahend);
  ComputeAndCompareR0<int32_t>(&b, -3, {});
}
156
// Converts a runtime S64 parameter holding 3 * 2^35 to F32 and compares it
// against the host-side conversion of the same value.
XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
  XlaBuilder builder(TestName());
  auto a = Parameter(&builder, 0, ShapeUtil::MakeShape(S64, {}), "a");
  ConvertElementType(a, F32);

  const int64_t value = int64_t{3} << 35;
  Literal a_literal = LiteralUtil::CreateR0<int64_t>(value);
  // Report a failed transfer as a test failure instead of aborting the test
  // binary through StatusOr::value().
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> a_data,
                          client_->TransferToServer(a_literal));
  ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
                             {a_data.get()});
}
169
// (2.1 * 5.5) * 0.5 in F32.
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
  XlaBuilder b(TestName());
  const XlaOp product =
      Mul(ConstantR0<float>(&b, 2.1f), ConstantR0<float>(&b, 5.5f));
  Mul(product, ConstantR0<float>(&b, 0.5f));
  ComputeAndCompareR0<float>(&b, 5.775f, {}, error_spec_);
}
177
// (pi * e) * gamma in F64, checked with a tolerance near double precision.
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF64) {
  XlaBuilder b(TestName());
  const XlaOp pi_times_e =
      Mul(ConstantR0<double>(&b, 3.1415926535897932),
          ConstantR0<double>(&b, 2.7182818284590452));
  Mul(pi_times_e, ConstantR0<double>(&b, 0.5772156649015328));
  ComputeAndCompareR0<double>(&b, 4.929268367422896, {}, ErrorSpec{3.6e-15});
}
187
// Multiplies every ordered pair from a set of S32 boundary values and checks
// the two's-complement wrapped product.
XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsS32) {
  const std::vector<int32_t> data = {0,
                                     1,
                                     -1,
                                     1234,
                                     0x1a243514,
                                     std::numeric_limits<int32_t>::max(),
                                     std::numeric_limits<int32_t>::min()};

  for (int32_t x : data) {
    for (int32_t y : data) {
      // Identify which of the 49 combinations failed in the test output.
      SCOPED_TRACE(absl::StrCat("x=", x, " y=", y));
      XlaBuilder builder(TestName());
      Mul(ConstantR0<int32_t>(&builder, x), ConstantR0<int32_t>(&builder, y));

      // Signed integer overflow is undefined behavior in C++. Convert the
      // inputs to unsigned, multiply unsigned (which wraps), and convert
      // back explicitly.
      const int32_t expected = static_cast<int32_t>(static_cast<uint32_t>(x) *
                                                    static_cast<uint32_t>(y));

      ComputeAndCompareR0<int32_t>(&builder, expected, {});
    }
  }
}
211
// Multiplies every ordered pair from a set of U32 boundary values and checks
// the product modulo 2^32.
XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) {
  const std::vector<uint32_t> data = {0,          1,          0xDEADBEEF, 1234,
                                      0x1a243514, 0xFFFFFFFF, 0x80808080};

  for (uint32_t x : data) {
    for (uint32_t y : data) {
      // Identify which of the 49 combinations failed in the test output.
      SCOPED_TRACE(absl::StrCat("x=", x, " y=", y));
      XlaBuilder builder(TestName());
      Mul(ConstantR0<uint32_t>(&builder, x), ConstantR0<uint32_t>(&builder, y));

      // Unsigned multiplication wraps modulo 2^32 in C++.
      const uint32_t expected = x * y;
      ComputeAndCompareR0<uint32_t>(&builder, expected, {});
    }
  }
}
226
// (2 * 5) * 1 in S32.
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
  XlaBuilder b(TestName());
  const XlaOp product =
      Mul(ConstantR0<int32_t>(&b, 2), ConstantR0<int32_t>(&b, 5));
  Mul(product, ConstantR0<int32_t>(&b, 1));
  ComputeAndCompareR0<int32_t>(&b, 10, {});
}
234
// Same product as MulThreeScalarsF32, but with operands supplied as runtime
// parameters instead of constants.
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
  XlaBuilder builder(TestName());
  Literal a_literal = LiteralUtil::CreateR0<float>(2.1f);
  Literal b_literal = LiteralUtil::CreateR0<float>(5.5f);
  Literal c_literal = LiteralUtil::CreateR0<float>(0.5f);

  // Report failed transfers as test failures instead of aborting the test
  // binary through StatusOr::value().
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> a_data,
                          client_->TransferToServer(a_literal));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> b_data,
                          client_->TransferToServer(b_literal));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> c_data,
                          client_->TransferToServer(c_literal));

  XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a");
  XlaOp b = Parameter(&builder, 1, b_literal.shape(), "b");
  XlaOp c = Parameter(&builder, 2, c_literal.shape(), "c");
  Mul(Mul(a, b), c);

  ComputeAndCompareR0<float>(&builder, 5.775f,
                             {a_data.get(), b_data.get(), c_data.get()},
                             error_spec_);
}
257
// 5.0 / 2.5 in F32.
XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) {
  XlaBuilder b(TestName());
  const XlaOp numerator = ConstantR0<float>(&b, 5.0f);
  const XlaOp denominator = ConstantR0<float>(&b, 2.5f);
  Div(numerator, denominator);
  ComputeAndCompareR0<float>(&b, 2.0f, {}, error_spec_);
}
264
// 2.5 rem 5.0 in F32: the dividend is smaller than the divisor, so it is
// returned unchanged.
XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) {
  XlaBuilder b(TestName());
  const XlaOp dividend = ConstantR0<float>(&b, 2.5f);
  const XlaOp divisor = ConstantR0<float>(&b, 5.0f);
  Rem(dividend, divisor);
  ComputeAndCompareR0<float>(&b, 2.5f, {}, error_spec_);
}
271
// One signed 32-bit division case with its expected Div and Rem results.
struct DivS32Params {
  int32_t dividend;   // Input: numerator.
  int32_t divisor;    // Input: denominator.
  int32_t quotient;   // Expected result of Div(dividend, divisor).
  int32_t remainder;  // Expected result of Rem(dividend, divisor).
};
278
PrintTo(const DivS32Params & p,std::ostream * os)279 void PrintTo(const DivS32Params& p, std::ostream* os) {
280 *os << "{" << p.dividend << ", " << p.divisor << ", " << p.quotient << ", "
281 << p.remainder << "}";
282 }
283
284 class DivS32Test : public ClientLibraryTestBase,
285 public ::testing::WithParamInterface<DivS32Params> {};
286
// Div of two S32 constants must produce the tabulated quotient.
XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) {
  const DivS32Params& params = GetParam();
  XlaBuilder b(TestName());
  const XlaOp dividend = ConstantR0<int32_t>(&b, params.dividend);
  const XlaOp divisor = ConstantR0<int32_t>(&b, params.divisor);
  Div(dividend, divisor);
  ComputeAndCompareR0<int32_t>(&b, params.quotient, {});
}
295
// Rem of two S32 constants must produce the tabulated remainder.
XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) {
  const DivS32Params& params = GetParam();
  XlaBuilder b(TestName());
  const XlaOp dividend = ConstantR0<int32_t>(&b, params.dividend);
  const XlaOp divisor = ConstantR0<int32_t>(&b, params.divisor);
  Rem(dividend, divisor);
  ComputeAndCompareR0<int32_t>(&b, params.remainder, {});
}
304
// Div with both operands passed as runtime parameters, so the backend cannot
// constant-fold the division.
XLA_TEST_P(DivS32Test, DivideTwoScalarsNonConstS32) {
  const DivS32Params& p = GetParam();
  XlaBuilder builder(TestName());
  XlaOp dividend;
  XlaOp divisor;
  // CreateR0Parameter writes the created parameter op through its last
  // argument. The original text had the mis-encoded token `÷nd` here, which
  // does not compile; it must be the address of `dividend`.
  auto dividend_data = CreateR0Parameter<int32_t>(p.dividend, 0, "dividend",
                                                  &builder, &dividend);
  auto divisor_data =
      CreateR0Parameter<int32_t>(p.divisor, 1, "divisor", &builder, &divisor);
  Div(dividend, divisor);

  ComputeAndCompareR0<int32_t>(&builder, p.quotient,
                               {dividend_data.get(), divisor_data.get()});
}
319
// Rem with both operands passed as runtime parameters, so the backend cannot
// constant-fold the remainder.
XLA_TEST_P(DivS32Test, RemainderTwoScalarsNonConstDivisorS32) {
  const DivS32Params& p = GetParam();
  XlaBuilder builder(TestName());
  XlaOp dividend;
  XlaOp divisor;
  // CreateR0Parameter writes the created parameter op through its last
  // argument. The original text had the mis-encoded token `÷nd` here, which
  // does not compile; it must be the address of `dividend`.
  auto dividend_data = CreateR0Parameter<int32_t>(p.dividend, 0, "dividend",
                                                  &builder, &dividend);
  auto divisor_data =
      CreateR0Parameter<int32_t>(p.divisor, 1, "divisor", &builder, &divisor);
  Rem(dividend, divisor);

  ComputeAndCompareR0<int32_t>(&builder, p.remainder,
                               {dividend_data.get(), divisor_data.get()});
}
334
335 INSTANTIATE_TEST_CASE_P(
336 DivS32Test_Instantiation, DivS32Test,
337 ::testing::Values(
338 // Positive divisors.
339 DivS32Params{5, 2, 2, 1}, //
340 DivS32Params{-5, 2, -2, -1}, //
341 DivS32Params{17, 3, 5, 2}, //
342 DivS32Params{-17, 3, -5, -2}, //
343 // Negative divisors.
344 DivS32Params{5, -2, -2, 1}, //
345 DivS32Params{-5, -2, 2, -1}, //
346 DivS32Params{17, -3, -5, 2}, //
347 DivS32Params{-17, -3, 5, -2}, //
348 // Large positive divisors.
349 DivS32Params{INT32_MIN, 7919, -271181, -1309}, //
350 DivS32Params{INT32_MIN, INT32_MAX, -1, -1}, //
351 DivS32Params{INT32_MIN + 1, INT32_MAX, -1, 0}, //
352 DivS32Params{INT32_MIN + 2, INT32_MAX, 0, INT32_MIN + 2}, //
353 DivS32Params{INT32_MIN, 0x40000000, -2, 0}, //
354 DivS32Params{INT32_MIN + 1, 0x40000000, -1, -0x3fffffff}, //
355 // Large negative divisors.
356 DivS32Params{INT32_MIN, INT32_MIN, 1, 0}, //
357 DivS32Params{INT32_MIN, INT32_MIN + 1, 1, -1}, //
358 DivS32Params{INT32_MIN + 1, INT32_MIN, 0, INT32_MIN + 1}, //
359 DivS32Params{INT32_MAX, INT32_MIN, 0, INT32_MAX}, //
360 DivS32Params{INT32_MAX, INT32_MIN + 1, -1, 0}, //
361 DivS32Params{INT32_MIN, -0x40000000, 2, 0}, //
362 DivS32Params{INT32_MIN + 1, -0x40000000, 1, -0x3fffffff}));
363
// Checks U32 Div against host `/` for every non-zero divisor and every
// dividend drawn from a set of boundary values. The computation is built once
// and reused with different parameter data.
XLA_TEST_F(ScalarComputationsTest, DivU32s) {
  // clang-format off
  // Some interesting values to test.
  const std::vector<uint32_t> vals = {
    0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX};
  // clang-format on

  XlaComputation div_computation;
  {
    XlaBuilder builder(TestName());

    XlaOp dividend =
        Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {}), "dividend");
    XlaOp divisor =
        Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "divisor");
    Div(dividend, divisor);
    TF_ASSERT_OK_AND_ASSIGN(div_computation, builder.Build());
  }

  for (uint32_t divisor : vals) {
    // Host-side `/` by zero is undefined behavior, so there is no reference
    // value for a zero divisor; skip it.
    if (divisor == 0) {
      continue;
    }
    // The divisor is the same for every dividend below; transfer it once.
    Literal divisor_literal = LiteralUtil::CreateR0<uint32_t>(divisor);
    TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
                            client_->TransferToServer(divisor_literal));
    for (uint32_t dividend : vals) {
      SCOPED_TRACE(absl::StrCat("dividend=", dividend, " divisor=", divisor));
      Literal dividend_literal = LiteralUtil::CreateR0<uint32_t>(dividend);
      TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
                              client_->TransferToServer(dividend_literal));
      // Report execution errors as test failures instead of aborting the test
      // binary through StatusOr::value().
      TF_ASSERT_OK_AND_ASSIGN(
          Literal actual_literal,
          client_->ExecuteAndTransfer(div_computation,
                                      {dividend_data.get(), divisor_data.get()},
                                      &execution_options_));
      Literal expected_literal =
          LiteralUtil::CreateR0<uint32_t>(dividend / divisor);
      EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal));
    }
  }
}
405
// Checks U32 Rem against host `%` for every non-zero divisor and every
// dividend drawn from a set of boundary values. The computation is built once
// and reused with different parameter data.
XLA_TEST_F(ScalarComputationsTest, RemU32s) {
  // clang-format off
  // Some interesting values to test.
  const std::vector<uint32_t> vals = {
    0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX};
  // clang-format on

  XlaComputation rem_computation;
  {
    XlaBuilder builder(TestName());

    XlaOp dividend =
        Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {}), "dividend");
    XlaOp divisor =
        Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "divisor");
    Rem(dividend, divisor);
    TF_ASSERT_OK_AND_ASSIGN(rem_computation, builder.Build());
  }

  for (uint32_t divisor : vals) {
    // Host-side `%` by zero is undefined behavior, so there is no reference
    // value for a zero divisor; skip it.
    if (divisor == 0) {
      continue;
    }
    // The divisor is the same for every dividend below; transfer it once.
    Literal divisor_literal = LiteralUtil::CreateR0<uint32_t>(divisor);
    TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
                            client_->TransferToServer(divisor_literal));
    for (uint32_t dividend : vals) {
      SCOPED_TRACE(absl::StrCat("dividend=", dividend, " divisor=", divisor));
      Literal dividend_literal = LiteralUtil::CreateR0<uint32_t>(dividend);
      TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
                              client_->TransferToServer(dividend_literal));
      // Report execution errors as test failures instead of aborting the test
      // binary through StatusOr::value().
      TF_ASSERT_OK_AND_ASSIGN(
          Literal actual_literal,
          client_->ExecuteAndTransfer(rem_computation,
                                      {dividend_data.get(), divisor_data.get()},
                                      &execution_options_));
      Literal expected_literal =
          LiteralUtil::CreateR0<uint32_t>(dividend % divisor);
      EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal));
    }
  }
}
447
// Rem with a runtime dividend and a constant divisor: 87919 % 80000 == 7919.
XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
  XlaBuilder b(TestName());
  const XlaOp dividend =
      Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "x");
  const XlaOp divisor = ConstantR0<int32_t>(&b, 80000);
  Rem(dividend, divisor);

  Literal dividend_literal = LiteralUtil::CreateR0<int32_t>(87919);
  TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
                          client_->TransferToServer(dividend_literal));
  ComputeAndCompareR0<int32_t>(&b, 7919, {dividend_data.get()});
}
457
// Verifies that U32 division is unsigned: 0xFFFFFFFE / 2 must be 0x7FFFFFFF.
// If U32 were treated as S32, the result would be -2 / 2 = -1 (0xFFFFFFFF).
XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) {
  XlaBuilder b(TestName());
  const XlaOp numerator = ConstantR0<uint32_t>(&b, 0xFFFFFFFE);
  const XlaOp denominator = ConstantR0<uint32_t>(&b, 2);
  Div(numerator, denominator);
  ComputeAndCompareR0<uint32_t>(&b, 0x7FFFFFFF, {});
}
467
// 11 % 3 in U32.
XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) {
  XlaBuilder b(TestName());
  const XlaOp dividend = ConstantR0<uint32_t>(&b, 11);
  const XlaOp divisor = ConstantR0<uint32_t>(&b, 3);
  Rem(dividend, divisor);
  ComputeAndCompareR0<uint32_t>(&b, 2, {});
}
474
// Exhaustive truth table for And on boolean scalars.
XLA_TEST_F(ScalarComputationsTest, AndBool) {
  const bool kTruthValues[] = {false, true};
  for (const bool lhs : kTruthValues) {
    for (const bool rhs : kTruthValues) {
      XlaBuilder b(TestName());
      And(ConstantR0<bool>(&b, lhs), ConstantR0<bool>(&b, rhs));
      const bool expected = lhs && rhs;
      ComputeAndCompareR0<bool>(&b, expected, {});
    }
  }
}
485
// Bitwise And on S32 scalars, including a negative operand.
XLA_TEST_F(ScalarComputationsTest, AndS32) {
  const int32_t kLhsValues[] = {0, 8};
  const int32_t kRhsValues[] = {1, -16};
  for (const int32_t lhs : kLhsValues) {
    for (const int32_t rhs : kRhsValues) {
      XlaBuilder b(TestName());
      And(ConstantR0<int32_t>(&b, lhs), ConstantR0<int32_t>(&b, rhs));
      const int32_t expected = lhs & rhs;
      ComputeAndCompareR0<int32_t>(&b, expected, {});
    }
  }
}
496
// Bitwise And on U32 scalars.
XLA_TEST_F(ScalarComputationsTest, AndU32) {
  const uint32_t kLhsValues[] = {0, 8};
  const uint32_t kRhsValues[] = {1, 16};
  for (const uint32_t lhs : kLhsValues) {
    for (const uint32_t rhs : kRhsValues) {
      XlaBuilder b(TestName());
      And(ConstantR0<uint32_t>(&b, lhs), ConstantR0<uint32_t>(&b, rhs));
      const uint32_t expected = lhs & rhs;
      ComputeAndCompareR0<uint32_t>(&b, expected, {});
    }
  }
}
507
// Exhaustive truth table for Or on boolean scalars.
XLA_TEST_F(ScalarComputationsTest, OrBool) {
  const bool kTruthValues[] = {false, true};
  for (const bool lhs : kTruthValues) {
    for (const bool rhs : kTruthValues) {
      XlaBuilder b(TestName());
      Or(ConstantR0<bool>(&b, lhs), ConstantR0<bool>(&b, rhs));
      const bool expected = lhs || rhs;
      ComputeAndCompareR0<bool>(&b, expected, {});
    }
  }
}
518
// Bitwise Or on S32 scalars, including a negative operand.
XLA_TEST_F(ScalarComputationsTest, OrS32) {
  const int32_t kLhsValues[] = {0, 8};
  const int32_t kRhsValues[] = {1, -16};
  for (const int32_t lhs : kLhsValues) {
    for (const int32_t rhs : kRhsValues) {
      XlaBuilder b(TestName());
      Or(ConstantR0<int32_t>(&b, lhs), ConstantR0<int32_t>(&b, rhs));
      const int32_t expected = lhs | rhs;
      ComputeAndCompareR0<int32_t>(&b, expected, {});
    }
  }
}
529
// Bitwise Or on U32 scalars.
XLA_TEST_F(ScalarComputationsTest, OrU32) {
  const uint32_t kLhsValues[] = {0, 8};
  const uint32_t kRhsValues[] = {1, 16};
  for (const uint32_t lhs : kLhsValues) {
    for (const uint32_t rhs : kRhsValues) {
      XlaBuilder b(TestName());
      Or(ConstantR0<uint32_t>(&b, lhs), ConstantR0<uint32_t>(&b, rhs));
      const uint32_t expected = lhs | rhs;
      ComputeAndCompareR0<uint32_t>(&b, expected, {});
    }
  }
}
540
// Logical Not on both boolean values.
XLA_TEST_F(ScalarComputationsTest, NotBool) {
  const bool kTruthValues[] = {false, true};
  for (const bool value : kTruthValues) {
    XlaBuilder b(TestName());
    Not(ConstantR0<bool>(&b, value));
    const bool expected = !value;
    ComputeAndCompareR0<bool>(&b, expected, {});
  }
}
549
// Not on S32 is bitwise complement.
XLA_TEST_F(ScalarComputationsTest, NotS32) {
  const int32_t kValues[] = {-1, 0, 1};
  for (const int32_t value : kValues) {
    XlaBuilder b(TestName());
    Not(ConstantR0<int32_t>(&b, value));
    const int32_t expected = ~value;
    ComputeAndCompareR0<int32_t>(&b, expected, {});
  }
}
558
// Not on U32 is bitwise complement.
XLA_TEST_F(ScalarComputationsTest, NotU32) {
  const uint32_t kValues[] = {0, 1, 2};
  for (const uint32_t value : kValues) {
    XlaBuilder b(TestName());
    Not(ConstantR0<uint32_t>(&b, value));
    const uint32_t expected = ~value;
    ComputeAndCompareR0<uint32_t>(&b, expected, {});
  }
}
567
// Select with a true predicate yields the on-true operand.
XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) {
  XlaBuilder b(TestName());
  const XlaOp predicate = ConstantR0<bool>(&b, true);
  const XlaOp on_true = ConstantR0<float>(&b, 123.0f);
  const XlaOp on_false = ConstantR0<float>(&b, 42.0f);
  Select(predicate, on_true, on_false);
  ComputeAndCompareR0<float>(&b, 123.0f, {}, error_spec_);
}
576
// Select with a false predicate yields the on-false operand.
XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) {
  XlaBuilder b(TestName());
  const XlaOp predicate = ConstantR0<bool>(&b, false);
  const XlaOp on_true = ConstantR0<float>(&b, 123.0f);
  const XlaOp on_false = ConstantR0<float>(&b, 42.0f);
  Select(predicate, on_true, on_false);
  ComputeAndCompareR0<float>(&b, 42.0f, {}, error_spec_);
}
585
586 // This test is an explicit version of what is happening in the following
587 // templatized comparison tests.
XLA_TEST_F(ScalarComputationsTest,CompareGtScalar)588 XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) {
589 XlaBuilder builder(TestName());
590 Gt(ConstantR0<float>(&builder, 2.0f), ConstantR0<float>(&builder, 1.0f));
591
592 ComputeAndCompareR0<bool>(&builder, true, {});
593 }
594
595 // S32 comparisons.
// 2 == 1 is false.
XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) {
  TestCompare<int32_t>(/*lhs=*/2, /*rhs=*/1, /*expected=*/false, &Eq);
}
// 3 == 3 is true.
XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) {
  TestCompare<int32_t>(/*lhs=*/3, /*rhs=*/3, /*expected=*/true, &Eq);
}
602
// 2 != 1 is true.
XLA_TEST_F(ScalarComputationsTest, CompareNeS32) {
  TestCompare<int32_t>(/*lhs=*/2, /*rhs=*/1, /*expected=*/true, &Ne);
}
606
// 2 >= 1 is true.
XLA_TEST_F(ScalarComputationsTest, CompareGeS32) {
  TestCompare<int32_t>(/*lhs=*/2, /*rhs=*/1, /*expected=*/true, &Ge);
}
610
// 1 > 5 is false.
XLA_TEST_F(ScalarComputationsTest, CompareGtS32) {
  TestCompare<int32_t>(/*lhs=*/1, /*rhs=*/5, /*expected=*/false, &Gt);
}
614
// 2 <= 1 is false.
XLA_TEST_F(ScalarComputationsTest, CompareLeS32) {
  TestCompare<int32_t>(/*lhs=*/2, /*rhs=*/1, /*expected=*/false, &Le);
}
618
// S32 Lt on an ordinary pair and on the extreme values of the type.
XLA_TEST_F(ScalarComputationsTest, CompareLtS32) {
  constexpr int32_t kMin = std::numeric_limits<int32_t>::min();
  constexpr int32_t kMax = std::numeric_limits<int32_t>::max();
  TestCompare<int32_t>(/*lhs=*/9, /*rhs=*/7, /*expected=*/false, &Lt);
  TestCompare<int32_t>(kMin, kMax, /*expected=*/true, &Lt);
}
624
625 // U32 comparisons.
// 2 == 1 is false for U32.
XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) {
  TestCompare<uint32_t>(/*lhs=*/2, /*rhs=*/1, /*expected=*/false, &Eq);
}
629
// 2 != 1 is true for U32.
XLA_TEST_F(ScalarComputationsTest, CompareNeU32) {
  TestCompare<uint32_t>(/*lhs=*/2, /*rhs=*/1, /*expected=*/true, &Ne);
}
633
// 2 >= 1 is true for U32.
XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) {
  TestCompare<uint32_t>(/*lhs=*/2, /*rhs=*/1, /*expected=*/true, &Ge);
}
637
// 3 >= 3 is true for U32.
XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) {
  TestCompare<uint32_t>(/*lhs=*/3, /*rhs=*/3, /*expected=*/true, &Ge);
}
641
// U32 Gt for the less-than, equal, and greater-than orderings.
XLA_TEST_F(ScalarComputationsTest, CompareGtU32) {
  TestCompare<uint32_t>(/*lhs=*/1, /*rhs=*/5, /*expected=*/false, &Gt);
  TestCompare<uint32_t>(/*lhs=*/5, /*rhs=*/5, /*expected=*/false, &Gt);
  TestCompare<uint32_t>(/*lhs=*/5, /*rhs=*/1, /*expected=*/true, &Gt);
}
647
// 2 <= 1 is false for U32.
XLA_TEST_F(ScalarComputationsTest, CompareLeU32) {
  TestCompare<uint32_t>(/*lhs=*/2, /*rhs=*/1, /*expected=*/false, &Le);
}
651
// U32 Lt on an ordinary pair and on 0 vs. the type's maximum. The second case
// would be false if the values were compared as signed.
XLA_TEST_F(ScalarComputationsTest, CompareLtU32) {
  constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
  TestCompare<uint32_t>(/*lhs=*/9, /*rhs=*/7, /*expected=*/false, &Lt);
  TestCompare<uint32_t>(/*lhs=*/0, kMax, /*expected=*/true, &Lt);
}
656
657 // F32 comparisons.
// 2.0 == 1.3 is false for F32.
XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) {
  TestCompare<float>(/*lhs=*/2.0, /*rhs=*/1.3, /*expected=*/false, &Eq);
}
661
// 2.0 != 1.3 is true for F32.
XLA_TEST_F(ScalarComputationsTest, CompareNeF32) {
  TestCompare<float>(/*lhs=*/2.0, /*rhs=*/1.3, /*expected=*/true, &Ne);
}
665
// 2.0 >= 1.9 is true for F32.
XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) {
  TestCompare<float>(/*lhs=*/2.0, /*rhs=*/1.9, /*expected=*/true, &Ge);
}
// 3.5 >= 3.5 is true for F32.
XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) {
  TestCompare<float>(/*lhs=*/3.5, /*rhs=*/3.5, /*expected=*/true, &Ge);
}
672
// 1.0 > 5.2 is false for F32.
XLA_TEST_F(ScalarComputationsTest, CompareGtF32) {
  TestCompare<float>(/*lhs=*/1.0, /*rhs=*/5.2, /*expected=*/false, &Gt);
}
676
// 2.0 <= 1.2 is false for F32.
XLA_TEST_F(ScalarComputationsTest, CompareLeF32) {
  TestCompare<float>(/*lhs=*/2.0, /*rhs=*/1.2, /*expected=*/false, &Le);
}
680
// 9.0 < 7.2 is false for F32.
XLA_TEST_F(ScalarComputationsTest, CompareLtF32) {
  TestCompare<float>(/*lhs=*/9.0, /*rhs=*/7.2, /*expected=*/false, &Lt);
}
684
685 // F32 comparisons with exceptional values. The test names encode the
686 // left/right operands at the end, and use Minf and Mzero for -inf and -0.0.
// -inf < -0.0 is true.
XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) {
  TestCompare<float>(/*lhs=*/-INFINITY, /*rhs=*/-0.0, /*expected=*/true, &Lt);
}
// IEEE 754 treats -0.0 and 0.0 as equal, so -0.0 < 0.0 is false.
XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) {
  TestCompare<float>(/*lhs=*/-0.0, /*rhs=*/0.0, /*expected=*/false, &Lt);
}
// 0.0 < +inf is true.
XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) {
  TestCompare<float>(/*lhs=*/0.0, /*rhs=*/INFINITY, /*expected=*/true, &Lt);
}
697
// -inf >= -0.0 is false.
XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) {
  TestCompare<float>(/*lhs=*/-INFINITY, /*rhs=*/-0.0, /*expected=*/false, &Ge);
}
// IEEE 754 treats -0.0 and 0.0 as equal, so -0.0 >= 0.0 is true.
XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) {
  TestCompare<float>(/*lhs=*/-0.0, /*rhs=*/0.0, /*expected=*/true, &Ge);
}
// 0.0 >= +inf is false.
XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) {
  TestCompare<float>(/*lhs=*/0.0, /*rhs=*/INFINITY, /*expected=*/false, &Ge);
}
708
// exp(2) is approximately 7.3890562.
XLA_TEST_F(ScalarComputationsTest, ExpScalar) {
  XlaBuilder b(TestName());
  const XlaOp x = ConstantR0<float>(&b, 2.0f);
  Exp(x);
  ComputeAndCompareR0<float>(&b, 7.3890562, {}, error_spec_);
}
715
// Natural log: ln(2) is approximately 0.6931471.
XLA_TEST_F(ScalarComputationsTest, LogScalar) {
  // Name the builder after the test, as every other test in this file does,
  // rather than the hard-coded "log".
  XlaBuilder builder(TestName());
  Log(ConstantR0<float>(&builder, 2.0f));

  ComputeAndCompareR0<float>(&builder, 0.6931471, {}, error_spec_);
}
722
// tanh(2) is approximately 0.96402758 in F32.
XLA_TEST_F(ScalarComputationsTest, TanhScalar) {
  XlaBuilder b(TestName());
  const XlaOp x = ConstantR0<float>(&b, 2.0f);
  Tanh(x);
  ComputeAndCompareR0<float>(&b, 0.96402758, {}, error_spec_);
}
729
// tanh(2) in F64, checked against the same reference and tolerance as the
// F32 variant.
XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) {
  XlaBuilder b(TestName());
  const XlaOp x = ConstantR0<double>(&b, 2.0);
  Tanh(x);
  ComputeAndCompareR0<double>(&b, 0.96402758, {}, error_spec_);
}
736
// 2^3 in F32.
XLA_TEST_F(ScalarComputationsTest, PowScalar) {
  XlaBuilder b(TestName());
  const XlaOp base = ConstantR0<float>(&b, 2.0f);
  const XlaOp exponent = ConstantR0<float>(&b, 3.0f);
  Pow(base, exponent);
  ComputeAndCompareR0<float>(&b, 8.0, {}, error_spec_);
}
743
// Cube root of 2 in F32. No expected value is given here; ComputeAndCompare
// presumably checks against a reference backend's result (verify in
// ClientLibraryTestBase).
XLA_TEST_F(ScalarComputationsTest, CbrtScalar) {
  XlaBuilder b(TestName());
  const XlaOp x = ConstantR0<float>(&b, 2.0f);
  Cbrt(x);
  ComputeAndCompare(&b, {}, error_spec_);
}
750
// An S32 operand above the upper bound clamps to the upper bound.
XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) {
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<int32_t>(&b, -1);
  const XlaOp operand = ConstantR0<int32_t>(&b, 5);
  const XlaOp upper = ConstantR0<int32_t>(&b, 3);
  Clamp(lower, operand, upper);
  ComputeAndCompareR0<int32_t>(&b, 3, {});
}
759
// An S32 operand within the bounds is returned unchanged.
XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) {
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<int32_t>(&b, -1);
  const XlaOp operand = ConstantR0<int32_t>(&b, 2);
  const XlaOp upper = ConstantR0<int32_t>(&b, 3);
  Clamp(lower, operand, upper);
  ComputeAndCompareR0<int32_t>(&b, 2, {});
}
768
// An S32 operand below the lower bound clamps to the lower bound.
XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) {
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<int32_t>(&b, -1);
  const XlaOp operand = ConstantR0<int32_t>(&b, -5);
  const XlaOp upper = ConstantR0<int32_t>(&b, 3);
  Clamp(lower, operand, upper);
  ComputeAndCompareR0<int32_t>(&b, -1, {});
}
777
// A U32 operand above the upper bound clamps to the upper bound.
XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) {
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<uint32_t>(&b, 1);
  const XlaOp operand = ConstantR0<uint32_t>(&b, 5);
  const XlaOp upper = ConstantR0<uint32_t>(&b, 3);
  Clamp(lower, operand, upper);
  ComputeAndCompareR0<uint32_t>(&b, 3, {});
}
786
// A U32 operand within the bounds is returned unchanged.
XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) {
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<uint32_t>(&b, 1);
  const XlaOp operand = ConstantR0<uint32_t>(&b, 2);
  const XlaOp upper = ConstantR0<uint32_t>(&b, 3);
  Clamp(lower, operand, upper);
  ComputeAndCompareR0<uint32_t>(&b, 2, {});
}
795
// A U32 operand below the lower bound clamps to the lower bound.
XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) {
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<uint32_t>(&b, 1);
  const XlaOp operand = ConstantR0<uint32_t>(&b, 0);
  const XlaOp upper = ConstantR0<uint32_t>(&b, 3);
  Clamp(lower, operand, upper);
  ComputeAndCompareR0<uint32_t>(&b, 1, {});
}
804
XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) {
  // A float operand above the upper bound is expected to be lowered to it.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<float>(&b, 2.0f);
  const XlaOp operand = ConstantR0<float>(&b, 5.0f);
  const XlaOp upper = ConstantR0<float>(&b, 3.0f);
  Clamp(lower, operand, upper);

  const float expected = 3.0f;
  ComputeAndCompareR0<float>(&b, expected, {}, error_spec_);
}
813
XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) {
  // A float operand strictly inside [lower, upper] passes through as-is.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<float>(&b, 2.0f);
  const XlaOp operand = ConstantR0<float>(&b, 2.5f);
  const XlaOp upper = ConstantR0<float>(&b, 3.0f);
  Clamp(lower, operand, upper);

  const float expected = 2.5f;
  ComputeAndCompareR0<float>(&b, expected, {}, error_spec_);
}
822
XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) {
  // A negative float operand below the lower bound is raised to the bound.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<float>(&b, 2.0f);
  const XlaOp operand = ConstantR0<float>(&b, -5.0f);
  const XlaOp upper = ConstantR0<float>(&b, 3.0f);
  Clamp(lower, operand, upper);

  const float expected = 2.0f;
  ComputeAndCompareR0<float>(&b, expected, {}, error_spec_);
}
831
XLA_TEST_F(ScalarComputationsTest, MinS32Above) {
  // min(10, 3): the right-hand operand is the smaller one.
  const int32_t lhs = 10;
  const int32_t rhs = 3;
  const int32_t expected = rhs;
  TestMinMax<int32_t>(lhs, rhs, expected, &Min);
}
835
XLA_TEST_F(ScalarComputationsTest, MinS32Below) {
  // min(-100, 3): the negative left-hand operand is the smaller one.
  const int32_t lhs = -100;
  const int32_t rhs = 3;
  const int32_t expected = lhs;
  TestMinMax<int32_t>(lhs, rhs, expected, &Min);
}
839
XLA_TEST_F(ScalarComputationsTest, MaxS32Above) {
  // max(10, 3): the left-hand operand is the larger one.
  const int32_t lhs = 10;
  const int32_t rhs = 3;
  const int32_t expected = lhs;
  TestMinMax<int32_t>(lhs, rhs, expected, &Max);
}
843
XLA_TEST_F(ScalarComputationsTest, MaxS32Below) {
  // max(-100, 3): the right-hand operand is the larger one.
  const int32_t lhs = -100;
  const int32_t rhs = 3;
  const int32_t expected = rhs;
  TestMinMax<int32_t>(lhs, rhs, expected, &Max);
}
847
XLA_TEST_F(ScalarComputationsTest, MinU32Above) {
  // INT32_MAX orders the same under signed and unsigned comparison, so it
  // cannot detect a backend that compares u32 operands as signed.
  const uint32_t large = std::numeric_limits<int32_t>::max();
  TestMinMax<uint32_t>(large, 3, 3, &Min);

  // UINT32_MAX has the top bit set; read as int32 it would be -1 and a signed
  // min would wrongly pick it over 3.
  const uint32_t largest = std::numeric_limits<uint32_t>::max();
  TestMinMax<uint32_t>(largest, 3, 3, &Min);
}
852
XLA_TEST_F(ScalarComputationsTest, MinU32Below) {
  // min(0, 5): zero, the smallest u32 value, wins.
  const uint32_t lhs = 0;
  const uint32_t rhs = 5;
  const uint32_t expected = lhs;
  TestMinMax<uint32_t>(lhs, rhs, expected, &Min);
}
856
XLA_TEST_F(ScalarComputationsTest, MaxU32Above) {
  // INT32_MAX orders the same under signed and unsigned comparison, so it
  // cannot detect a backend that compares u32 operands as signed.
  const uint32_t large = std::numeric_limits<int32_t>::max();
  TestMinMax<uint32_t>(large, 3, large, &Max);

  // UINT32_MAX has the top bit set; read as int32 it would be -1 and a signed
  // max would wrongly pick 3 over it.
  const uint32_t largest = std::numeric_limits<uint32_t>::max();
  TestMinMax<uint32_t>(largest, 3, largest, &Max);
}
861
XLA_TEST_F(ScalarComputationsTest, MaxU32Below) {
  // max(0, 5): the right-hand operand is the larger one.
  const uint32_t lhs = 0;
  const uint32_t rhs = 5;
  const uint32_t expected = rhs;
  TestMinMax<uint32_t>(lhs, rhs, expected, &Max);
}
865
XLA_TEST_F(ScalarComputationsTest, MinF32Above) {
  // min(10.1, 3.1): the right-hand operand is the smaller one.
  const float lhs = 10.1f;
  const float rhs = 3.1f;
  const float expected = rhs;
  TestMinMax<float>(lhs, rhs, expected, &Min);
}
869
XLA_TEST_F(ScalarComputationsTest, MinF32Below) {
  // min(-100.1, 3.1): the negative left-hand operand is the smaller one.
  const float lhs = -100.1f;
  const float rhs = 3.1f;
  const float expected = lhs;
  TestMinMax<float>(lhs, rhs, expected, &Min);
}
873
XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) {
  // Fast-math is switched off before checking NaN behavior (presumably
  // because NaN semantics may be relaxed under it; confirm).
  SetFastMathDisabled(true);

  // A NaN on either side of Min is expected to produce NaN.
  const float nan = NAN;
  TestMinMax<float>(nan, 3.1f, nan, &Min);
  TestMinMax<float>(-3.1f, nan, nan, &Min);
}
879
XLA_TEST_F(ScalarComputationsTest, MaxF32Above) {
  // max(10.1, 3.1): the left-hand operand is the larger one.
  const float lhs = 10.1f;
  const float rhs = 3.1f;
  const float expected = lhs;
  TestMinMax<float>(lhs, rhs, expected, &Max);
}
883
XLA_TEST_F(ScalarComputationsTest, MaxF32Below) {
  // max(-100.1, 3.1): the right-hand operand is the larger one.
  const float lhs = -100.1f;
  const float rhs = 3.1f;
  const float expected = rhs;
  TestMinMax<float>(lhs, rhs, expected, &Max);
}
887
XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) {
  // Fast-math is switched off before checking NaN behavior (presumably
  // because NaN semantics may be relaxed under it; confirm).
  SetFastMathDisabled(true);

  // A NaN on either side of Max is expected to produce NaN.
  const float nan = NAN;
  TestMinMax<float>(nan, 3.1f, nan, &Max);
  TestMinMax<float>(-3.1f, nan, nan, &Max);
}
893
XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) {
  // Builds (1 * (3 - 1) * (7 + 0) - 4) / 20 step by step; the trailing
  // comments give the value each intermediate op should hold.
  XlaBuilder b(TestName());
  const XlaOp difference =
      Sub(ConstantR0<float>(&b, 3), ConstantR0<float>(&b, 1));  // 2
  const XlaOp sum =
      Add(ConstantR0<float>(&b, 7), ConstantR0<float>(&b, 0));  // 7
  const XlaOp product =
      Mul(ConstantR0<float>(&b, 1), Mul(difference, sum));  // 14
  const XlaOp numerator = Sub(product, ConstantR0<float>(&b, 4));  // 10
  // The division is constructed last, so it is the final op in the builder.
  Div(numerator, ConstantR0<float>(&b, 20));  // 0.5

  ComputeAndCompareR0<float>(&b, 0.5, {}, error_spec_);
}
905
XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) {
  // Builds 1 * (3 - 1) * (7 + 0) - 4 step by step; the trailing comments give
  // the value each intermediate op should hold.
  XlaBuilder b(TestName());
  const XlaOp difference =
      Sub(ConstantR0<int32_t>(&b, 3), ConstantR0<int32_t>(&b, 1));  // 2
  const XlaOp sum =
      Add(ConstantR0<int32_t>(&b, 7), ConstantR0<int32_t>(&b, 0));  // 7
  const XlaOp product =
      Mul(ConstantR0<int32_t>(&b, 1), Mul(difference, sum));  // 14
  // The final subtraction is constructed last, so it is the final op.
  Sub(product, ConstantR0<int32_t>(&b, 4));  // 10

  ComputeAndCompareR0<int32_t>(&b, 10, {});
}
916
XLA_TEST_F(ScalarComputationsTest, RoundScalar) {
  // 1.4 is below the .5 midpoint, so rounding is expected to yield 1.0.
  XlaBuilder b(TestName());
  const XlaOp input = ConstantR0<float>(&b, 1.4f);
  Round(input);

  const float expected = 1.0f;
  ComputeAndCompareR0<float>(&b, expected, {}, error_spec_);
}
923
924 } // namespace
925 } // namespace xla
926