1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cmath>
17 #include <limits>
18 #include <memory>
19
20 #include "absl/strings/str_cat.h"
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/xla/client/global_data.h"
23 #include "tensorflow/compiler/xla/client/local_client.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/client/xla_computation.h"
26 #include "tensorflow/compiler/xla/literal.h"
27 #include "tensorflow/compiler/xla/literal_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/statusor.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
32 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
33 #include "tensorflow/compiler/xla/tests/test_macros.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/platform/test.h"
36 #include "tensorflow/core/platform/types.h"
37
38 namespace xla {
39 namespace {
40
41 class ScalarComputationsTest : public ClientLibraryTestBase {
42 public:
43 ErrorSpec error_spec_{0.0001};
44
45 protected:
46 // A template for building and running a binary comparison test.
47 template <typename NativeT>
TestCompare(NativeT lhs,NativeT rhs,bool expected,const std::function<XlaOp (const XlaOp &,const XlaOp &,absl::Span<const int64>)> & op)48 void TestCompare(NativeT lhs, NativeT rhs, bool expected,
49 const std::function<XlaOp(const XlaOp&, const XlaOp&,
50 absl::Span<const int64>)>& op) {
51 XlaBuilder builder(TestName());
52 XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
53 XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
54 op(lhs_op, rhs_op, {});
55 ComputeAndCompareR0<bool>(&builder, expected, {});
56 }
57
58 template <typename NativeT>
TestMinMax(NativeT lhs,NativeT rhs,NativeT expected,const std::function<XlaOp (const XlaOp &,const XlaOp &,absl::Span<const int64>)> & op)59 void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected,
60 const std::function<XlaOp(const XlaOp&, const XlaOp&,
61 absl::Span<const int64>)>& op) {
62 XlaBuilder builder(TestName());
63 XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
64 XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
65 op(lhs_op, rhs_op, {});
66 ComputeAndCompareR0<NativeT>(&builder, expected, {});
67 }
68 };
69
XLA_TEST_F(ScalarComputationsTest, ReturnScalarF32) {
  // A computation that only builds a constant should hand that constant back.
  XlaBuilder b(TestName());
  const float kValue = 2.1f;
  ConstantR0<float>(&b, kValue);
  ComputeAndCompareR0<float>(&b, kValue, {}, error_spec_);
}
76
XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) {
  // Neg(2.1) == -2.1 within the suite's float tolerance.
  XlaBuilder b(TestName());
  const XlaOp operand = ConstantR0<float>(&b, 2.1f);
  Neg(operand);
  ComputeAndCompareR0<float>(&b, -2.1f, {}, error_spec_);
}
83
XLA_TEST_F(ScalarComputationsTest, NegateScalarS32) {
  // Integer negation, compared exactly.
  XlaBuilder b(TestName());
  const XlaOp operand = ConstantR0<int32>(&b, 2);
  Neg(operand);
  ComputeAndCompareR0<int32>(&b, -2, {});
}
90
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF32) {
  // 2.1 + 5.5 == 7.6 within the suite's float tolerance.
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<float>(&b, 2.1f);
  const XlaOp rhs = ConstantR0<float>(&b, 5.5f);
  Add(lhs, rhs);
  ComputeAndCompareR0<float>(&b, 7.6f, {}, error_spec_);
}
97
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS32) {
  // Signed 32-bit addition, compared exactly.
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<int32>(&b, 2);
  const XlaOp rhs = ConstantR0<int32>(&b, 5);
  Add(lhs, rhs);
  ComputeAndCompareR0<int32>(&b, 7, {});
}
104
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU32) {
  // Unsigned 32-bit addition, compared exactly.
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<uint32>(&b, 35);
  const XlaOp rhs = ConstantR0<uint32>(&b, 57);
  Add(lhs, rhs);
  ComputeAndCompareR0<uint32>(&b, 92, {});
}
111
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU8) {
  // Unsigned 8-bit addition; 35 + 57 stays within the U8 range.
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<uint8>(&b, 35);
  const XlaOp rhs = ConstantR0<uint8>(&b, 57);
  Add(lhs, rhs);
  ComputeAndCompareR0<uint8>(&b, 92, {});
}
118
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) {
  // Both operands have bit 63 set, so the sum wraps modulo 2^64; the
  // expectation is computed with the same (well-defined) unsigned host wrap.
  XlaBuilder b(TestName());
  const uint64 lhs_value = static_cast<uint64>(1) << 63;
  const uint64 rhs_value = lhs_value + 1;
  const XlaOp lhs = ConstantR0<uint64>(&b, lhs_value);
  const XlaOp rhs = ConstantR0<uint64>(&b, rhs_value);
  Add(lhs, rhs);
  ComputeAndCompareR0<uint64>(&b, lhs_value + rhs_value, {});
}
127
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) {
  // 2^62 + (2^62 - 1) == 2^63 - 1: the largest int64 sum reachable without
  // signed overflow on the host side.
  XlaBuilder b(TestName());
  const int64_t lhs_value = static_cast<int64>(1) << 62;
  const int64_t rhs_value = lhs_value - 1;
  const XlaOp lhs = ConstantR0<int64>(&b, lhs_value);
  const XlaOp rhs = ConstantR0<int64>(&b, rhs_value);
  Add(lhs, rhs);
  ComputeAndCompareR0<int64>(&b, lhs_value + rhs_value, {});
}
136
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) {
  // 0.25 + 3.5 is exactly representable, so no tolerance is passed.
  XlaBuilder b(TestName());
  const XlaOp lhs = ConstantR0<double>(&b, 0.25);
  const XlaOp rhs = ConstantR0<double>(&b, 3.5);
  Add(lhs, rhs);
  ComputeAndCompareR0<double>(&b, 3.75, {});
}
143
XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) {
  // 2.1 - 5.5 == -3.4 within the suite's float tolerance.
  XlaBuilder b(TestName());
  const XlaOp minuend = ConstantR0<float>(&b, 2.1f);
  const XlaOp subtrahend = ConstantR0<float>(&b, 5.5f);
  Sub(minuend, subtrahend);
  ComputeAndCompareR0<float>(&b, -3.4f, {}, error_spec_);
}
150
XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
  // Signed subtraction producing a negative result.
  XlaBuilder b(TestName());
  const XlaOp minuend = ConstantR0<int32>(&b, 2);
  const XlaOp subtrahend = ConstantR0<int32>(&b, 5);
  Sub(minuend, subtrahend);
  ComputeAndCompareR0<int32>(&b, -3, {});
}
157
XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
  // Converts an S64 parameter holding 3 * 2^35 (wider than 32 bits) to F32
  // and compares with the host's static_cast of the same value.
  XlaBuilder b(TestName());
  const XlaOp input = Parameter(&b, 0, ShapeUtil::MakeShape(S64, {}), "a");
  ConvertElementType(input, F32);

  const int64_t kValue = 3LL << 35;
  Literal input_literal = LiteralUtil::CreateR0<int64>(kValue);
  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input_literal).ConsumeValueOrDie();

  const float expected = static_cast<float>(kValue);
  ComputeAndCompareR0<float>(&b, expected, {input_data.get()});
}
170
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
  // (2.1 * 5.5) * 0.5 == 5.775 within the suite's float tolerance.
  XlaBuilder b(TestName());
  const XlaOp first = ConstantR0<float>(&b, 2.1f);
  const XlaOp second = ConstantR0<float>(&b, 5.5f);
  const XlaOp partial = Mul(first, second);
  Mul(partial, ConstantR0<float>(&b, 0.5f));
  ComputeAndCompareR0<float>(&b, 5.775f, {}, error_spec_);
}
178
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF64) {
  // pi * e * gamma (Euler-Mascheroni constant), checked with a tolerance near
  // double precision.
  const double kPi = 3.1415926535897932;
  const double kE = 2.7182818284590452;
  const double kEulerGamma = 0.5772156649015328;

  XlaBuilder b(TestName());
  const XlaOp pi_times_e =
      Mul(ConstantR0<double>(&b, kPi), ConstantR0<double>(&b, kE));
  Mul(pi_times_e, ConstantR0<double>(&b, kEulerGamma));

  ComputeAndCompareR0<double>(&b, 4.929268367422896, {}, ErrorSpec{3.6e-15});
}
188
XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsS32) {
  // Boundary and bit-pattern values; every ordered pair is multiplied, which
  // includes products that overflow int32 and must wrap.
  const std::vector<int32> data = {0,
                                   1,
                                   -1,
                                   1234,
                                   0x1a243514,
                                   std::numeric_limits<int32>::max(),
                                   std::numeric_limits<int32>::min()};

  for (int32_t x : data) {
    for (int32_t y : data) {
      // Identify the failing pair; otherwise a failure in this 49-iteration
      // loop gives no hint which operands produced it.
      SCOPED_TRACE(absl::StrCat("x=", x, " y=", y));
      XlaBuilder builder(TestName());
      Mul(ConstantR0<int32>(&builder, x), ConstantR0<int32>(&builder, y));

      // Signed integer overflow is undefined behavior in C++. Convert the input
      // integers to unsigned, perform the multiplication unsigned, and convert
      // back.
      const int32_t expected = static_cast<int32_t>(static_cast<uint32>(x) *
                                                    static_cast<uint32>(y));

      ComputeAndCompareR0<int32>(&builder, expected, {});
    }
  }
}
212
XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) {
  // Every ordered pair of these values is multiplied; many products wrap
  // modulo 2^32, matching unsigned host arithmetic.
  const std::vector<uint32> data = {0,          1,          0xDEADBEEF, 1234,
                                    0x1a243514, 0xFFFFFFFF, 0x80808080};

  for (uint32 x : data) {
    for (uint32 y : data) {
      // Identify the failing pair inside the 49-iteration loop.
      SCOPED_TRACE(absl::StrCat("x=", x, " y=", y));
      XlaBuilder builder(TestName());
      Mul(ConstantR0<uint32>(&builder, x), ConstantR0<uint32>(&builder, y));

      const uint32 expected = x * y;
      ComputeAndCompareR0<uint32>(&builder, expected, {});
    }
  }
}
227
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
  // (2 * 5) * 1 == 10.
  XlaBuilder b(TestName());
  const XlaOp first = ConstantR0<int32>(&b, 2);
  const XlaOp second = ConstantR0<int32>(&b, 5);
  const XlaOp partial = Mul(first, second);
  Mul(partial, ConstantR0<int32>(&b, 1));
  ComputeAndCompareR0<int32>(&b, 10, {});
}
235
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
  // Same product as MulThreeScalarsF32, but the three operands are supplied
  // as transferred parameters instead of constants.
  XlaBuilder builder(TestName());
  Literal a_literal = LiteralUtil::CreateR0<float>(2.1f);
  Literal b_literal = LiteralUtil::CreateR0<float>(5.5f);
  Literal c_literal = LiteralUtil::CreateR0<float>(0.5f);

  // Uploads one literal, crashing the test on transfer failure.
  auto upload = [this](const Literal& literal) {
    return client_->TransferToServer(literal).ConsumeValueOrDie();
  };
  std::unique_ptr<GlobalData> a_data = upload(a_literal);
  std::unique_ptr<GlobalData> b_data = upload(b_literal);
  std::unique_ptr<GlobalData> c_data = upload(c_literal);

  const XlaOp p0 = Parameter(&builder, 0, a_literal.shape(), "a");
  const XlaOp p1 = Parameter(&builder, 1, b_literal.shape(), "b");
  const XlaOp p2 = Parameter(&builder, 2, c_literal.shape(), "c");
  const XlaOp p0_times_p1 = Mul(p0, p1);
  Mul(p0_times_p1, p2);

  ComputeAndCompareR0<float>(&builder, 5.775f,
                             {a_data.get(), b_data.get(), c_data.get()},
                             error_spec_);
}
258
XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) {
  // 5.0 / 2.5 == 2.0.
  XlaBuilder b(TestName());
  const XlaOp numerator = ConstantR0<float>(&b, 5.0f);
  const XlaOp denominator = ConstantR0<float>(&b, 2.5f);
  Div(numerator, denominator);
  ComputeAndCompareR0<float>(&b, 2.0f, {}, error_spec_);
}
265
XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) {
  // Dividend smaller than divisor: the remainder is the dividend itself.
  XlaBuilder b(TestName());
  const XlaOp dividend = ConstantR0<float>(&b, 2.5f);
  const XlaOp divisor = ConstantR0<float>(&b, 5.0f);
  Rem(dividend, divisor);
  ComputeAndCompareR0<float>(&b, 2.5f, {}, error_spec_);
}
272
// One signed 32-bit division case for DivS32Test: Div(dividend, divisor) is
// expected to yield `quotient` and Rem(dividend, divisor) to yield
// `remainder`. Field order matters: cases are written with aggregate
// initialization as {dividend, divisor, quotient, remainder}.
struct DivS32Params {
  int32 dividend;
  int32 divisor;
  int32 quotient;   // Expected result of Div(dividend, divisor).
  int32 remainder;  // Expected result of Rem(dividend, divisor).
};
279
PrintTo(const DivS32Params & p,std::ostream * os)280 void PrintTo(const DivS32Params& p, std::ostream* os) {
281 *os << "{" << p.dividend << ", " << p.divisor << ", " << p.quotient << ", "
282 << p.remainder << "}";
283 }
284
285 class DivS32Test : public ClientLibraryTestBase,
286 public ::testing::WithParamInterface<DivS32Params> {};
287
XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) {
  // Div over two constants must match the table's expected quotient.
  const DivS32Params& params = GetParam();
  XlaBuilder b(TestName());
  const XlaOp dividend = ConstantR0<int32>(&b, params.dividend);
  const XlaOp divisor = ConstantR0<int32>(&b, params.divisor);
  Div(dividend, divisor);
  ComputeAndCompareR0<int32>(&b, params.quotient, {});
}
296
XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) {
  // Rem over two constants must match the table's expected remainder.
  const DivS32Params& params = GetParam();
  XlaBuilder b(TestName());
  const XlaOp dividend = ConstantR0<int32>(&b, params.dividend);
  const XlaOp divisor = ConstantR0<int32>(&b, params.divisor);
  Rem(dividend, divisor);
  ComputeAndCompareR0<int32>(&b, params.remainder, {});
}
305
XLA_TEST_P(DivS32Test, DivideTwoScalarsNonConstS32) {
  // Like DivideTwoScalarsS32, but both operands are runtime parameters rather
  // than constants.
  const DivS32Params p = GetParam();
  XlaBuilder builder(TestName());
  XlaOp dividend;
  XlaOp divisor;
  // CreateR0Parameter writes the parameter op through its last argument and
  // returns the uploaded data handle.
  auto dividend_data =
      CreateR0Parameter<int32>(p.dividend, 0, "dividend", &builder, &dividend);
  auto divisor_data =
      CreateR0Parameter<int32>(p.divisor, 1, "divisor", &builder, &divisor);
  Div(dividend, divisor);

  ComputeAndCompareR0<int32>(&builder, p.quotient,
                             {dividend_data.get(), divisor_data.get()});
}
320
XLA_TEST_P(DivS32Test, RemainderTwoScalarsNonConstDivisorS32) {
  // Like RemainderTwoScalarsS32, but both operands are runtime parameters
  // rather than constants.
  const DivS32Params p = GetParam();
  XlaBuilder builder(TestName());
  XlaOp dividend;
  XlaOp divisor;
  // CreateR0Parameter writes the parameter op through its last argument and
  // returns the uploaded data handle.
  auto dividend_data =
      CreateR0Parameter<int32>(p.dividend, 0, "dividend", &builder, &dividend);
  auto divisor_data =
      CreateR0Parameter<int32>(p.divisor, 1, "divisor", &builder, &divisor);
  Rem(dividend, divisor);

  ComputeAndCompareR0<int32>(&builder, p.remainder,
                             {dividend_data.get(), divisor_data.get()});
}
335
// Expected values follow truncating division: the quotient rounds toward zero
// and the remainder takes the sign of the dividend (e.g. -5 / 2 -> -2 and
// -5 % 2 -> -1). Division by zero and INT32_MIN / -1 are not covered here.
INSTANTIATE_TEST_CASE_P(
    DivS32Test_Instantiation, DivS32Test,
    ::testing::Values(
        // Positive divisors.
        DivS32Params{5, 2, 2, 1},     //
        DivS32Params{-5, 2, -2, -1},  //
        DivS32Params{17, 3, 5, 2},    //
        DivS32Params{-17, 3, -5, -2},  //
        // Negative divisors.
        DivS32Params{5, -2, -2, 1},    //
        DivS32Params{-5, -2, 2, -1},   //
        DivS32Params{17, -3, -5, 2},   //
        DivS32Params{-17, -3, 5, -2},  //
        // Large positive divisors.
        DivS32Params{INT32_MIN, 7919, -271181, -1309},             //
        DivS32Params{INT32_MIN, INT32_MAX, -1, -1},                //
        DivS32Params{INT32_MIN + 1, INT32_MAX, -1, 0},             //
        DivS32Params{INT32_MIN + 2, INT32_MAX, 0, INT32_MIN + 2},  //
        DivS32Params{INT32_MIN, 0x40000000, -2, 0},                //
        DivS32Params{INT32_MIN + 1, 0x40000000, -1, -0x3fffffff},  //
        // Large negative divisors.
        DivS32Params{INT32_MIN, INT32_MIN, 1, 0},                  //
        DivS32Params{INT32_MIN, INT32_MIN + 1, 1, -1},             //
        DivS32Params{INT32_MIN + 1, INT32_MIN, 0, INT32_MIN + 1},  //
        DivS32Params{INT32_MAX, INT32_MIN, 0, INT32_MAX},          //
        DivS32Params{INT32_MAX, INT32_MIN + 1, -1, 0},             //
        DivS32Params{INT32_MIN, -0x40000000, 2, 0},                //
        DivS32Params{INT32_MIN + 1, -0x40000000, 1, -0x3fffffff}));
364
XLA_TEST_F(ScalarComputationsTest, DivU32s) {
  // clang-format off
  // Some interesting values to test.
  std::vector<uint32> vals = {
    0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX};
  // clang-format on

  // Build dividend / divisor once over two U32 parameters and re-execute it
  // for every (dividend, divisor) pair below.
  XlaComputation div_computation;
  {
    XlaBuilder builder(TestName());

    XlaOp dividend =
        Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {}), "dividend");
    XlaOp divisor =
        Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "divisor");
    Div(dividend, divisor);
    TF_ASSERT_OK_AND_ASSIGN(div_computation, builder.Build());
  }

  for (uint32 divisor : vals) {
    // Skip zero: the host-side reference `dividend / divisor` would be UB.
    if (divisor == 0) {
      continue;
    }
    // The divisor does not change in the inner loop; upload it once.
    auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
    TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
                            client_->TransferToServer(divisor_literal));
    for (uint32 dividend : vals) {
      // Identify the failing pair among the ~90 combinations.
      SCOPED_TRACE(absl::StrCat(dividend, " / ", divisor));
      auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
      TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
                              client_->TransferToServer(dividend_literal));
      // Report an execution error as a test failure instead of crashing.
      TF_ASSERT_OK_AND_ASSIGN(
          auto actual_literal,
          client_->ExecuteAndTransfer(div_computation,
                                      {dividend_data.get(), divisor_data.get()},
                                      &execution_options_));
      auto expected_literal =
          LiteralUtil::CreateR0<uint32>(dividend / divisor);
      EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal));
    }
  }
}
406
XLA_TEST_F(ScalarComputationsTest, RemU32s) {
  // clang-format off
  // Some interesting values to test.
  std::vector<uint32> vals = {
    0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX};
  // clang-format on

  // Build dividend % divisor once over two U32 parameters and re-execute it
  // for every (dividend, divisor) pair below.
  XlaComputation rem_computation;
  {
    XlaBuilder builder(TestName());

    XlaOp dividend =
        Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {}), "dividend");
    XlaOp divisor =
        Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "divisor");
    Rem(dividend, divisor);
    TF_ASSERT_OK_AND_ASSIGN(rem_computation, builder.Build());
  }

  for (uint32 divisor : vals) {
    // Skip zero: the host-side reference `dividend % divisor` would be UB.
    if (divisor == 0) {
      continue;
    }
    // The divisor does not change in the inner loop; upload it once.
    auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
    TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
                            client_->TransferToServer(divisor_literal));
    for (uint32 dividend : vals) {
      // Identify the failing pair among the ~90 combinations.
      SCOPED_TRACE(absl::StrCat(dividend, " % ", divisor));
      auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
      TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
                              client_->TransferToServer(dividend_literal));
      // Report an execution error as a test failure instead of crashing.
      TF_ASSERT_OK_AND_ASSIGN(
          auto actual_literal,
          client_->ExecuteAndTransfer(rem_computation,
                                      {dividend_data.get(), divisor_data.get()},
                                      &execution_options_));
      auto expected_literal =
          LiteralUtil::CreateR0<uint32>(dividend % divisor);
      EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal));
    }
  }
}
448
XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
  // Runtime dividend, constant divisor: 87919 % 80000 == 7919.
  XlaBuilder b(TestName());
  const XlaOp dividend =
      Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "x");
  Rem(dividend, ConstantR0<int32>(&b, 80000));

  Literal dividend_literal = LiteralUtil::CreateR0<int32>(87919);
  TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
                          client_->TransferToServer(dividend_literal));
  ComputeAndCompareR0<int32>(&b, 7919, {dividend_data.get()});
}
458
XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) {
  // Checks 0xFFFFFFFE / 2 == 0x7FFFFFFF. A backend that mistakenly treated
  // U32 as S32 would compute -2 / 2 == -1, i.e. 0xFFFFFFFF.
  XlaBuilder b(TestName());
  const XlaOp numerator = ConstantR0<uint32>(&b, 0xFFFFFFFE);
  const XlaOp denominator = ConstantR0<uint32>(&b, 2);
  Div(numerator, denominator);
  ComputeAndCompareR0<uint32>(&b, 0x7FFFFFFF, {});
}
468
XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) {
  // 11 % 3 == 2.
  XlaBuilder b(TestName());
  const XlaOp dividend = ConstantR0<uint32>(&b, 11);
  const XlaOp divisor = ConstantR0<uint32>(&b, 3);
  Rem(dividend, divisor);
  ComputeAndCompareR0<uint32>(&b, 2, {});
}
475
XLA_TEST_F(ScalarComputationsTest, AndBool) {
  // Full truth table for logical And.
  for (bool lhs : {false, true}) {
    for (bool rhs : {false, true}) {
      XlaBuilder b(TestName());
      And(ConstantR0<bool>(&b, lhs), ConstantR0<bool>(&b, rhs));
      const bool expected = lhs && rhs;
      ComputeAndCompareR0<bool>(&b, expected, {});
    }
  }
}
486
XLA_TEST_F(ScalarComputationsTest, AndS32) {
  // Bitwise And on signed values, including a negative mask.
  for (int32_t lhs : {0, 8}) {
    for (int32_t rhs : {1, -16}) {
      XlaBuilder b(TestName());
      And(ConstantR0<int32>(&b, lhs), ConstantR0<int32>(&b, rhs));
      const int32_t expected = lhs & rhs;
      ComputeAndCompareR0<int32>(&b, expected, {});
    }
  }
}
497
XLA_TEST_F(ScalarComputationsTest, AndU32) {
  // Bitwise And on unsigned values.
  for (uint32 lhs : {0, 8}) {
    for (uint32 rhs : {1, 16}) {
      XlaBuilder b(TestName());
      And(ConstantR0<uint32>(&b, lhs), ConstantR0<uint32>(&b, rhs));
      const uint32 expected = lhs & rhs;
      ComputeAndCompareR0<uint32>(&b, expected, {});
    }
  }
}
508
XLA_TEST_F(ScalarComputationsTest, OrBool) {
  // Full truth table for logical Or.
  for (bool lhs : {false, true}) {
    for (bool rhs : {false, true}) {
      XlaBuilder b(TestName());
      Or(ConstantR0<bool>(&b, lhs), ConstantR0<bool>(&b, rhs));
      const bool expected = lhs || rhs;
      ComputeAndCompareR0<bool>(&b, expected, {});
    }
  }
}
519
XLA_TEST_F(ScalarComputationsTest, OrS32) {
  // Bitwise Or on signed values, including a negative operand.
  for (int32_t lhs : {0, 8}) {
    for (int32_t rhs : {1, -16}) {
      XlaBuilder b(TestName());
      Or(ConstantR0<int32>(&b, lhs), ConstantR0<int32>(&b, rhs));
      const int32_t expected = lhs | rhs;
      ComputeAndCompareR0<int32>(&b, expected, {});
    }
  }
}
530
XLA_TEST_F(ScalarComputationsTest, OrU32) {
  // Bitwise Or on unsigned values.
  for (uint32 lhs : {0, 8}) {
    for (uint32 rhs : {1, 16}) {
      XlaBuilder b(TestName());
      Or(ConstantR0<uint32>(&b, lhs), ConstantR0<uint32>(&b, rhs));
      const uint32 expected = lhs | rhs;
      ComputeAndCompareR0<uint32>(&b, expected, {});
    }
  }
}
541
XLA_TEST_F(ScalarComputationsTest, NotBool) {
  // Not on a PRED scalar is logical negation.
  for (bool value : {false, true}) {
    XlaBuilder b(TestName());
    Not(ConstantR0<bool>(&b, value));
    const bool expected = !value;
    ComputeAndCompareR0<bool>(&b, expected, {});
  }
}
550
XLA_TEST_F(ScalarComputationsTest, NotS32) {
  // Not on an integer scalar is bitwise complement.
  for (int32_t value : {-1, 0, 1}) {
    XlaBuilder b(TestName());
    Not(ConstantR0<int32>(&b, value));
    const int32_t expected = ~value;
    ComputeAndCompareR0<int32>(&b, expected, {});
  }
}
559
XLA_TEST_F(ScalarComputationsTest, NotU32) {
  // Bitwise complement on unsigned scalars.
  for (uint32 value : {0, 1, 2}) {
    XlaBuilder b(TestName());
    Not(ConstantR0<uint32>(&b, value));
    const uint32 expected = ~value;
    ComputeAndCompareR0<uint32>(&b, expected, {});
  }
}
568
XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) {
  XlaBuilder b(TestName());
  const XlaOp predicate = ConstantR0<bool>(&b, true);
  const XlaOp on_true = ConstantR0<float>(&b, 123.0f);
  const XlaOp on_false = ConstantR0<float>(&b, 42.0f);
  // A true predicate selects the on_true operand.
  Select(predicate, on_true, on_false);
  ComputeAndCompareR0<float>(&b, 123.0f, {}, error_spec_);
}
577
XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) {
  XlaBuilder b(TestName());
  const XlaOp predicate = ConstantR0<bool>(&b, false);
  const XlaOp on_true = ConstantR0<float>(&b, 123.0f);
  const XlaOp on_false = ConstantR0<float>(&b, 42.0f);
  // A false predicate selects the on_false operand.
  Select(predicate, on_true, on_false);
  ComputeAndCompareR0<float>(&b, 42.0f, {}, error_spec_);
}
586
587 // This test is an explicit version of what is happening in the following
588 // templatized comparison tests.
XLA_TEST_F(ScalarComputationsTest,CompareGtScalar)589 XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) {
590 XlaBuilder builder(TestName());
591 Gt(ConstantR0<float>(&builder, 2.0f), ConstantR0<float>(&builder, 1.0f));
592
593 ComputeAndCompareR0<bool>(&builder, true, {});
594 }
595
596 // S32 comparisons.
XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) {
  // 2 == 1 is false.
  TestCompare<int32>(/*lhs=*/2, /*rhs=*/1, /*expected=*/false, &Eq);
}
XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) {
  // 3 == 3 is true.
  TestCompare<int32>(/*lhs=*/3, /*rhs=*/3, /*expected=*/true, &Eq);
}
603
XLA_TEST_F(ScalarComputationsTest, CompareNeS32) {
  // 2 != 1 is true.
  TestCompare<int32>(/*lhs=*/2, /*rhs=*/1, /*expected=*/true, &Ne);
}
607
XLA_TEST_F(ScalarComputationsTest, CompareGeS32) {
  // 2 >= 1 is true.
  TestCompare<int32>(/*lhs=*/2, /*rhs=*/1, /*expected=*/true, &Ge);
}
611
XLA_TEST_F(ScalarComputationsTest, CompareGtS32) {
  // 1 > 5 is false.
  TestCompare<int32>(/*lhs=*/1, /*rhs=*/5, /*expected=*/false, &Gt);
}
615
XLA_TEST_F(ScalarComputationsTest, CompareLeS32) {
  // 2 <= 1 is false.
  TestCompare<int32>(/*lhs=*/2, /*rhs=*/1, /*expected=*/false, &Le);
}
619
XLA_TEST_F(ScalarComputationsTest, CompareLtS32) {
  // 9 < 7 is false; INT32_MIN < INT32_MAX must be true (signed ordering).
  TestCompare<int32>(/*lhs=*/9, /*rhs=*/7, /*expected=*/false, &Lt);
  const int32 lowest = std::numeric_limits<int32>::min();
  const int32 highest = std::numeric_limits<int32>::max();
  TestCompare<int32>(lowest, highest, /*expected=*/true, &Lt);
}
625
626 // U32 comparisons.
XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) {
  // 2u == 1u is false.
  TestCompare<uint32>(/*lhs=*/2, /*rhs=*/1, /*expected=*/false, &Eq);
}
630
XLA_TEST_F(ScalarComputationsTest, CompareNeU32) {
  // 2u != 1u is true.
  TestCompare<uint32>(/*lhs=*/2, /*rhs=*/1, /*expected=*/true, &Ne);
}
634
XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) {
  // 2u >= 1u is true.
  TestCompare<uint32>(/*lhs=*/2, /*rhs=*/1, /*expected=*/true, &Ge);
}
638
XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) {
  // 3u >= 3u is true.
  TestCompare<uint32>(/*lhs=*/3, /*rhs=*/3, /*expected=*/true, &Ge);
}
642
XLA_TEST_F(ScalarComputationsTest, CompareGtU32) {
  // Less-than, equal and greater-than operand orderings for Gt.
  TestCompare<uint32>(/*lhs=*/1, /*rhs=*/5, /*expected=*/false, &Gt);
  TestCompare<uint32>(/*lhs=*/5, /*rhs=*/5, /*expected=*/false, &Gt);
  TestCompare<uint32>(/*lhs=*/5, /*rhs=*/1, /*expected=*/true, &Gt);
}
648
XLA_TEST_F(ScalarComputationsTest, CompareLeU32) {
  // 2u <= 1u is false.
  TestCompare<uint32>(/*lhs=*/2, /*rhs=*/1, /*expected=*/false, &Le);
}
652
XLA_TEST_F(ScalarComputationsTest, CompareLtU32) {
  // 9u < 7u is false; 0 < UINT32_MAX must be true (unsigned ordering).
  TestCompare<uint32>(/*lhs=*/9, /*rhs=*/7, /*expected=*/false, &Lt);
  const uint32 highest = std::numeric_limits<uint32>::max();
  TestCompare<uint32>(/*lhs=*/0, highest, /*expected=*/true, &Lt);
}
657
658 // F32 comparisons.
XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) {
  // 2.0 == 1.3 is false.
  TestCompare<float>(/*lhs=*/2.0, /*rhs=*/1.3, /*expected=*/false, &Eq);
}
662
XLA_TEST_F(ScalarComputationsTest, CompareNeF32) {
  // 2.0 != 1.3 is true.
  TestCompare<float>(/*lhs=*/2.0, /*rhs=*/1.3, /*expected=*/true, &Ne);
}
666
XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) {
  // 2.0 >= 1.9 is true.
  TestCompare<float>(/*lhs=*/2.0, /*rhs=*/1.9, /*expected=*/true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) {
  // 3.5 >= 3.5 is true.
  TestCompare<float>(/*lhs=*/3.5, /*rhs=*/3.5, /*expected=*/true, &Ge);
}
673
XLA_TEST_F(ScalarComputationsTest, CompareGtF32) {
  // 1.0 > 5.2 is false.
  TestCompare<float>(/*lhs=*/1.0, /*rhs=*/5.2, /*expected=*/false, &Gt);
}
677
XLA_TEST_F(ScalarComputationsTest, CompareLeF32) {
  // 2.0 <= 1.2 is false.
  TestCompare<float>(/*lhs=*/2.0, /*rhs=*/1.2, /*expected=*/false, &Le);
}
681
XLA_TEST_F(ScalarComputationsTest, CompareLtF32) {
  // 9.0 < 7.2 is false.
  TestCompare<float>(/*lhs=*/9.0, /*rhs=*/7.2, /*expected=*/false, &Lt);
}
685
686 // F32 comparisons with exceptional values. The test names encode the
687 // left/right operands at the end, and use Minf and Mzero for -inf and -0.0.
XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) {
  // -inf < -0.0 is true.
  TestCompare<float>(/*lhs=*/-INFINITY, /*rhs=*/-0.0, /*expected=*/true, &Lt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) {
  // IEEE 754 comparisons treat -0.0 and 0.0 as equal, so -0.0 < 0.0 is false.
  TestCompare<float>(/*lhs=*/-0.0, /*rhs=*/0.0, /*expected=*/false, &Lt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) {
  // 0.0 < +inf is true.
  TestCompare<float>(/*lhs=*/0.0, /*rhs=*/INFINITY, /*expected=*/true, &Lt);
}
698
XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) {
  // -inf >= -0.0 is false.
  TestCompare<float>(/*lhs=*/-INFINITY, /*rhs=*/-0.0, /*expected=*/false, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) {
  // IEEE 754 comparisons treat -0.0 and 0.0 as equal, so -0.0 >= 0.0 is true.
  TestCompare<float>(/*lhs=*/-0.0, /*rhs=*/0.0, /*expected=*/true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) {
  // 0.0 >= +inf is false.
  TestCompare<float>(/*lhs=*/0.0, /*rhs=*/INFINITY, /*expected=*/false, &Ge);
}
709
XLA_TEST_F(ScalarComputationsTest, ExpScalar) {
  // e^2 ~= 7.3890562.
  XlaBuilder b(TestName());
  const XlaOp exponent = ConstantR0<float>(&b, 2.0f);
  Exp(exponent);
  ComputeAndCompareR0<float>(&b, 7.3890562, {}, error_spec_);
}
716
XLA_TEST_F(ScalarComputationsTest, LogScalar) {
  // ln(2) ~= 0.6931471. The builder is named after the test, as everywhere
  // else in this file.
  XlaBuilder builder(TestName());
  Log(ConstantR0<float>(&builder, 2.0f));

  ComputeAndCompareR0<float>(&builder, 0.6931471, {}, error_spec_);
}
723
XLA_TEST_F(ScalarComputationsTest, TanhScalar) {
  // tanh(2) ~= 0.96402758 in F32.
  XlaBuilder b(TestName());
  const XlaOp operand = ConstantR0<float>(&b, 2.0f);
  Tanh(operand);
  ComputeAndCompareR0<float>(&b, 0.96402758, {}, error_spec_);
}
730
XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) {
  // tanh(2) in F64, checked with the same tolerance as the F32 variant.
  XlaBuilder b(TestName());
  const XlaOp operand = ConstantR0<double>(&b, 2.0);
  Tanh(operand);
  ComputeAndCompareR0<double>(&b, 0.96402758, {}, error_spec_);
}
737
XLA_TEST_F(ScalarComputationsTest, PowScalar) {
  // 2^3 == 8.
  XlaBuilder b(TestName());
  const XlaOp base = ConstantR0<float>(&b, 2.0f);
  const XlaOp exponent = ConstantR0<float>(&b, 3.0f);
  Pow(base, exponent);
  ComputeAndCompareR0<float>(&b, 8.0, {}, error_spec_);
}
744
XLA_TEST_F(ScalarComputationsTest, CbrtScalar) {
  // No explicit expected value is supplied here; ComputeAndCompare presumably
  // checks against a reference backend within error_spec_ (confirm in
  // ClientLibraryTestBase).
  XlaBuilder b(TestName());
  const XlaOp operand = ConstantR0<float>(&b, 2.0f);
  Cbrt(operand);
  ComputeAndCompare(&b, {}, error_spec_);
}
751
XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) {
  // An operand above the upper bound is pulled down to it.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<int32>(&b, -1);
  const XlaOp operand = ConstantR0<int32>(&b, 5);
  const XlaOp upper = ConstantR0<int32>(&b, 3);
  Clamp(lower, operand, upper);
  ComputeAndCompareR0<int32>(&b, 3, {});
}
760
XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) {
  // An operand inside [lower, upper] passes through unchanged.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<int32>(&b, -1);
  const XlaOp operand = ConstantR0<int32>(&b, 2);
  const XlaOp upper = ConstantR0<int32>(&b, 3);
  Clamp(lower, operand, upper);
  ComputeAndCompareR0<int32>(&b, 2, {});
}
769
XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) {
  // An operand below the lower bound is raised to it.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<int32>(&b, -1);
  const XlaOp operand = ConstantR0<int32>(&b, -5);
  const XlaOp upper = ConstantR0<int32>(&b, 3);
  Clamp(lower, operand, upper);
  ComputeAndCompareR0<int32>(&b, -1, {});
}
778
XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) {
  // An operand above the upper bound is pulled down to it.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<uint32>(&b, 1);
  const XlaOp operand = ConstantR0<uint32>(&b, 5);
  const XlaOp upper = ConstantR0<uint32>(&b, 3);
  Clamp(lower, operand, upper);
  ComputeAndCompareR0<uint32>(&b, 3, {});
}
787
XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) {
  // An operand inside [lower, upper] passes through unchanged.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<uint32>(&b, 1);
  const XlaOp operand = ConstantR0<uint32>(&b, 2);
  const XlaOp upper = ConstantR0<uint32>(&b, 3);
  Clamp(lower, operand, upper);
  ComputeAndCompareR0<uint32>(&b, 2, {});
}
796
XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) {
  // An operand below the lower bound is raised to it.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<uint32>(&b, 1);
  const XlaOp operand = ConstantR0<uint32>(&b, 0);
  const XlaOp upper = ConstantR0<uint32>(&b, 3);
  Clamp(lower, operand, upper);
  ComputeAndCompareR0<uint32>(&b, 1, {});
}
805
XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) {
  // An operand above the upper bound is pulled down to it.
  XlaBuilder b(TestName());
  const XlaOp lower = ConstantR0<float>(&b, 2.0f);
  const XlaOp operand = ConstantR0<float>(&b, 5.0f);
  const XlaOp upper = ConstantR0<float>(&b, 3.0f);
  Clamp(lower, operand, upper);
  ComputeAndCompareR0<float>(&b, 3.0, {}, error_spec_);
}
814
XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) {
  // 2.5 is inside [2.0, 3.0] and is returned as-is.
  XlaBuilder builder(TestName());
  auto lower_bound = ConstantR0<float>(&builder, 2.0f);
  auto operand = ConstantR0<float>(&builder, 2.5f);
  auto upper_bound = ConstantR0<float>(&builder, 3.0f);
  Clamp(lower_bound, operand, upper_bound);

  ComputeAndCompareR0<float>(&builder, 2.5f, {}, error_spec_);
}
823
XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) {
  // -5.0 is below the [2.0, 3.0] range and is clamped up to 2.0.
  XlaBuilder builder(TestName());
  auto lower_bound = ConstantR0<float>(&builder, 2.0f);
  auto operand = ConstantR0<float>(&builder, -5.0f);
  auto upper_bound = ConstantR0<float>(&builder, 3.0f);
  Clamp(lower_bound, operand, upper_bound);

  ComputeAndCompareR0<float>(&builder, 2.0f, {}, error_spec_);
}
832
XLA_TEST_F(ScalarComputationsTest, MinS32Above) {
  // Min(10, 3): the left operand is the larger one, so the right is expected.
  const int32 lhs = 10;
  const int32 rhs = 3;
  TestMinMax<int32>(lhs, rhs, rhs, &Min);
}
836
XLA_TEST_F(ScalarComputationsTest, MinS32Below) {
  // Min(-100, 3): the negative left operand is the smaller one.
  const int32 lhs = -100;
  const int32 rhs = 3;
  TestMinMax<int32>(lhs, rhs, lhs, &Min);
}
840
XLA_TEST_F(ScalarComputationsTest, MaxS32Above) {
  // Max(10, 3): the left operand is the larger one.
  const int32 lhs = 10;
  const int32 rhs = 3;
  TestMinMax<int32>(lhs, rhs, lhs, &Max);
}
844
XLA_TEST_F(ScalarComputationsTest, MaxS32Below) {
  // Max(-100, 3): the right operand wins over the negative left operand.
  const int32 lhs = -100;
  const int32 rhs = 3;
  TestMinMax<int32>(lhs, rhs, rhs, &Max);
}
848
XLA_TEST_F(ScalarComputationsTest, MinU32Above) {
  // Min of a large unsigned value and 3 must be 3.
  //
  // int32 max (0x7FFFFFFF) orders the same way under signed and unsigned
  // comparison, so on its own it cannot tell whether the backend really uses
  // an unsigned compare for u32. uint32 max (0xFFFFFFFF) has its top bit set
  // and would be -1 under a signed compare, which would wrongly make it the
  // minimum. Both values are checked.
  const uint32 large = std::numeric_limits<int32>::max();
  TestMinMax<uint32>(large, 3, 3, &Min);

  const uint32 largest = std::numeric_limits<uint32>::max();
  TestMinMax<uint32>(largest, 3, 3, &Min);
}
853
XLA_TEST_F(ScalarComputationsTest, MinU32Below) {
  // Min(0, 5): zero, the smallest uint32, is expected.
  const uint32 lhs = 0;
  const uint32 rhs = 5;
  TestMinMax<uint32>(lhs, rhs, lhs, &Min);
}
857
XLA_TEST_F(ScalarComputationsTest, MaxU32Above) {
  // Max of a large unsigned value and 3 must be the large value.
  //
  // int32 max (0x7FFFFFFF) orders the same way under signed and unsigned
  // comparison, so on its own it cannot tell whether the backend really uses
  // an unsigned compare for u32. uint32 max (0xFFFFFFFF) has its top bit set
  // and would be -1 under a signed compare, which would wrongly make 3 the
  // maximum. Both values are checked.
  const uint32 large = std::numeric_limits<int32>::max();
  TestMinMax<uint32>(large, 3, large, &Max);

  const uint32 largest = std::numeric_limits<uint32>::max();
  TestMinMax<uint32>(largest, 3, largest, &Max);
}
862
XLA_TEST_F(ScalarComputationsTest, MaxU32Below) {
  // Max(0, 5): the right operand is the larger one.
  const uint32 lhs = 0;
  const uint32 rhs = 5;
  TestMinMax<uint32>(lhs, rhs, rhs, &Max);
}
866
XLA_TEST_F(ScalarComputationsTest, MinF32Above) {
  // Min(10.1, 3.1): the right operand is the smaller one.
  const float lhs = 10.1f;
  const float rhs = 3.1f;
  TestMinMax<float>(lhs, rhs, rhs, &Min);
}
870
XLA_TEST_F(ScalarComputationsTest, MinF32Below) {
  // Min(-100.1, 3.1): the negative left operand is the smaller one.
  const float lhs = -100.1f;
  const float rhs = 3.1f;
  TestMinMax<float>(lhs, rhs, lhs, &Min);
}
874
XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) {
  // A NaN in either operand position must make Min return NaN. Fast-math is
  // disabled first; presumably because fast-math may let the backend assume
  // NaN-free inputs (TODO confirm).
  SetFastMathDisabled(true);
  const float nan_value = std::numeric_limits<float>::quiet_NaN();
  TestMinMax<float>(nan_value, 3.1f, nan_value, &Min);
  TestMinMax<float>(-3.1f, nan_value, nan_value, &Min);
}
880
XLA_TEST_F(ScalarComputationsTest, MaxF32Above) {
  // Max(10.1, 3.1): the left operand is the larger one.
  const float lhs = 10.1f;
  const float rhs = 3.1f;
  TestMinMax<float>(lhs, rhs, lhs, &Max);
}
884
XLA_TEST_F(ScalarComputationsTest, MaxF32Below) {
  // Max(-100.1, 3.1): the right operand is the larger one.
  const float lhs = -100.1f;
  const float rhs = 3.1f;
  TestMinMax<float>(lhs, rhs, rhs, &Max);
}
888
XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) {
  // A NaN in either operand position must make Max return NaN. Fast-math is
  // disabled first; presumably because fast-math may let the backend assume
  // NaN-free inputs (TODO confirm).
  SetFastMathDisabled(true);
  const float nan_value = std::numeric_limits<float>::quiet_NaN();
  TestMinMax<float>(nan_value, 3.1f, nan_value, &Max);
  TestMinMax<float>(-3.1f, nan_value, nan_value, &Max);
}
894
XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) {
  // Builds (1 * ((3 - 1) * (7 + 0)) - 4) / 20 = (14 - 4) / 20 = 0.5, one
  // sub-expression at a time. The final Div is the last op added.
  XlaBuilder b(TestName());
  auto difference = Sub(ConstantR0<float>(&b, 3), ConstantR0<float>(&b, 1));
  auto sum = Add(ConstantR0<float>(&b, 7), ConstantR0<float>(&b, 0));
  auto product = Mul(ConstantR0<float>(&b, 1), Mul(difference, sum));
  auto numerator = Sub(product, ConstantR0<float>(&b, 4));
  Div(numerator, ConstantR0<float>(&b, 20));

  ComputeAndCompareR0<float>(&b, 0.5f, {}, error_spec_);
}
906
XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) {
  // Builds 1 * ((3 - 1) * (7 + 0)) - 4 = 14 - 4 = 10 in int32, one
  // sub-expression at a time. The final Sub is the last op added.
  XlaBuilder b(TestName());
  auto three_minus_one =
      Sub(ConstantR0<int32>(&b, 3), ConstantR0<int32>(&b, 1));
  auto seven_plus_zero =
      Add(ConstantR0<int32>(&b, 7), ConstantR0<int32>(&b, 0));
  auto scaled =
      Mul(ConstantR0<int32>(&b, 1), Mul(three_minus_one, seven_plus_zero));
  Sub(scaled, ConstantR0<int32>(&b, 4));

  ComputeAndCompareR0<int32>(&b, 10, {});
}
917
XLA_TEST_F(ScalarComputationsTest, RoundScalar) {
  // 1.4 is below the 1.5 midpoint, so rounding to nearest yields 1.0.
  const float input = 1.4f;
  const float expected = 1.0f;

  XlaBuilder builder(TestName());
  Round(ConstantR0<float>(&builder, input));

  ComputeAndCompareR0<float>(&builder, expected, {}, error_spec_);
}
924
925 } // namespace
926 } // namespace xla
927