1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <array>
17 #include <cstdint>
18 #include <limits>
19 #include <memory>
20 #include <vector>
21
22 #include "absl/algorithm/container.h"
23 #include "absl/base/casts.h"
24 #include "tensorflow/compiler/xla/client/local_client.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
28 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
29 #include "tensorflow/compiler/xla/tests/test_macros.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/lib/math/math_util.h"
32 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
33 #include "tensorflow/core/platform/test.h"
34 #include "tensorflow/core/platform/types.h"
35
36 namespace xla {
37 namespace {
38
39 class ConvertTest : public ClientLibraryTestBase {
40 public:
ConvertTest(se::Platform * platform=nullptr)41 explicit ConvertTest(se::Platform* platform = nullptr)
42 : ClientLibraryTestBase(platform) {
43 mutable_debug_options()->add_xla_disable_hlo_passes("algsimp");
44 mutable_debug_options()->add_xla_disable_hlo_passes("inline");
45 }
46 };
47
// S32 -> S32 is an identity conversion: {42, 64} must come back unchanged.
TEST_F(ConvertTest, ConvertR1S32ToR1S32) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<int32>(&builder, {42, 64}), S32);

  const std::vector<int32> expected{42, 64};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
56
// Non-negative S32 values converted to U32 keep their numeric value.
TEST_F(ConvertTest, ConvertR1S32ToR1U32) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<int32>(&builder, {42, 64}), U32);

  const std::vector<uint32> expected{42, 64};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
65
// S32 -> PRED: any nonzero value (positive or negative) becomes true and zero
// becomes false.
TEST_F(ConvertTest, ConvertR1S32ToR1PRED) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<int32>(&builder, {42, 0, -64}), PRED);

  // std::array rather than std::vector, because std::vector<bool> is a packed
  // proxy container that cannot be viewed as contiguous bools.
  const std::array<bool, 3> expected{{true, false, true}};
  ComputeAndCompareR1<bool>(&builder, expected, {});
}
74
// U32 -> U32 is an identity conversion: {42, 64} must come back unchanged.
TEST_F(ConvertTest, ConvertR1U32ToR1U32) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<uint32>(&builder, {42, 64}), U32);

  const std::vector<uint32> expected{42, 64};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
83
// Small U32 values converted to S32 keep their numeric value.
TEST_F(ConvertTest, ConvertR1U32ToR1S32) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<uint32>(&builder, {42, 64}), S32);

  const std::vector<int32> expected{42, 64};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
92
// U32 -> PRED: nonzero values become true and zero becomes false.
TEST_F(ConvertTest, ConvertR1U32ToR1PRED) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<uint32>(&builder, {42, 0, 64}), PRED);

  const std::array<bool, 3> expected{{true, false, true}};
  ComputeAndCompareR1<bool>(&builder, expected, {});
}
101
// F32 -> F32 is an identity conversion: {42, 64} must come back unchanged.
TEST_F(ConvertTest, ConvertR1F32ToR1F32) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<float>(&builder, {42.0f, 64.0f}), F32);

  const std::vector<float> expected{42.0f, 64.0f};
  ComputeAndCompareR1<float>(&builder, expected, {});
}
110
// F32 -> PRED: nonzero values become true and 0.0f becomes false.
TEST_F(ConvertTest, ConvertR1F32ToR1PRED) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<float>(&builder, {42.0f, 0.0f, 64.0f}), PRED);

  const std::array<bool, 3> expected{{true, false, true}};
  ComputeAndCompareR1<bool>(&builder, expected, {});
}
119
// Small S32 integers convert exactly to F32.
TEST_F(ConvertTest, ConvertR1S32ToR1F32) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<int32>(&builder, {42, 64}), F32);

  const std::vector<float> expected{42.0f, 64.0f};
  ComputeAndCompareR1<float>(&builder, expected, {});
}
128
// PRED -> S32: true maps to 1 and false maps to 0.
TEST_F(ConvertTest, ConvertR1PREDToR1S32) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<bool>(&builder, {true, false, true}), S32);

  const std::vector<int32> expected{1, 0, 1};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
137
// PRED -> U32: true maps to 1 and false maps to 0.
TEST_F(ConvertTest, ConvertR1PREDToR1U32) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<bool>(&builder, {true, false, true}), U32);

  const std::vector<uint32> expected{1, 0, 1};
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
146
// PRED -> F32: true maps to 1.0 and false maps to 0.0.
TEST_F(ConvertTest, ConvertR1PREDToR1F32) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<bool>(&builder, {true, false, true}), F32);

  const std::vector<float> expected{1., 0., 1.};
  ComputeAndCompareR1<float>(&builder, expected, {});
}
155
// Converting a zero-element S32 vector to F32 yields a zero-element result.
XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<int32>(&builder, {}), F32);

  const std::vector<float> expected;
  ComputeAndCompareR1<float>(&builder, expected, {});
}
164
// F32 -> S32 drops the fractional part: 42.6 -> 42 and 64.4 -> 64.
TEST_F(ConvertTest, ConvertR1F32ToR1S32) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<float>(&builder, {42.6, 64.4}), S32);

  const std::vector<int32> expected{42, 64};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
173
// Converts a parameter of S64 values to F32. Each result is compared against
// the host's static_cast<float> of the same value.
//
// The inputs cover:
//   * small values;
//   * the int32 boundaries;
//   * values around F32 round-to-nearest ties (e.g. the 0x0007FB72Ex... group,
//     whose F32 ulp is 2^27);
//   * the int64 extremes.
XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) {
  XlaBuilder builder(TestName());
  std::vector<int64> arg{
      -9223371216516022272,
      -2,
      -1,
      -0x7FFFFFFF,
      // The LL suffix is required. A bare 0x80000000 is an `unsigned int`
      // literal, so `-0x80000000` wraps to +2147483648. That would silently
      // duplicate the 0x80000000 case below and never test INT32_MIN.
      -0x80000000LL,
      0,
      1,
      2,
      1073742145,
      1073742656,
      0x7FFFFFFF,
      0x80000000,
      826720496944058148,
      4296062029846194332,
      0x0007FB72E4000000LL,
      0x0007FB72E4000001LL,
      0x0007FB72E6000000LL,
      0x0007FB72E7000000LL,
      0x0007FB72E7FFFFFFLL,
      0x0007FB72E8000000LL,
      0x0007FB72E8000001LL,
      0x0007FB72EA000000LL,
      0x0007FB72EB000000LL,
      0x0007FB72EBFFFFFFLL,
      0x0007FB72EC000000LL,
      0x7FFFFF0000000000LL,
      0x7FFFFF8000000000LL,
      0x7FFFFFFFFFFFFF00,
      static_cast<int64>(0xFFFFFFFFFFFFFFFF),
      static_cast<int64>(0x0000f234e67e0001LL),
      static_cast<int64>(0x8000000000000000),
      static_cast<int64>(0x8000000000000000LL),
      static_cast<int64>(0x8000000000000001LL),
      static_cast<int64>(0x8000008000000000LL),
      static_cast<int64>(0x8000010000000000LL),
  };
  Literal arg_literal = LiteralUtil::CreateR1<int64>({arg});
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(arg_param, F32);

  // Expected values come from the host's own int64 -> float conversion.
  // A range-for avoids the signed/unsigned comparison of an int64 index
  // against size().
  std::vector<float> expected;
  expected.reserve(arg.size());
  for (int64 value : arg) {
    expected.push_back(static_cast<float>(value));
  }
  ComputeAndCompareR1<float>(&builder, expected, {arg_data.get()});
}
226
// Converts a parameter of U32 values to F32 and compares the results with
// the host's static_cast<float>.
//
// The inputs include values just above 2^31 (around the F32 rounding tie at
// 0x80000080) and the U32 maximum.
XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) {
  XlaBuilder builder(TestName());
  std::vector<uint32> arg{0,          1,          0x1000,     0x7fffffff,
                          0x80000000, 0x80000001, 0x80000002, 0x80000003,
                          0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF};
  Literal arg_literal = LiteralUtil::CreateR1<uint32>(arg);
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(Parameter(&builder, 0, arg_literal.shape(), "arg_param"),
                     F32);

  std::vector<float> expected;
  expected.reserve(arg.size());
  for (uint32 value : arg) {
    expected.push_back(static_cast<float>(value));
  }
  ComputeAndCompareR1<float>(&builder, expected, {arg_data.get()});
}
245
// Converts a parameter of non-negative F32 values to U32 and compares the
// results with the host's static_cast<uint32>.
//
// The inputs include 2^24 (the largest F32 below which every integer is
// exact), a value above it, and 4294967040 (0xFFFFFF00).
XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
  XlaBuilder builder(TestName());
  std::vector<float> arg{0.0f,        1.0f,          16777216.0f,
                         16777218.0f, 2147483647.0f, 4294967040.0f};
  Literal arg_literal = LiteralUtil::CreateR1<float>(arg);
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(Parameter(&builder, 0, arg_literal.shape(), "arg_param"),
                     U32);

  std::vector<uint32> expected;
  expected.reserve(arg.size());
  for (float value : arg) {
    expected.push_back(static_cast<uint32>(value));
  }
  ComputeAndCompareR1<uint32>(&builder, expected, {arg_data.get()});
}
263
// Widens a parameter of U32 values to S64.
//
// Values with the top bit set (e.g. 0xFFFFFFFF) must zero-extend, not
// sign-extend, which is what the host static_cast<int64> of a uint32 does.
XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
  XlaBuilder builder(TestName());
  std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF};
  Literal arg_literal = LiteralUtil::CreateR1<uint32>(arg);
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(Parameter(&builder, 0, arg_literal.shape(), "arg_param"),
                     S64);

  std::vector<int64> expected;
  expected.reserve(arg.size());
  for (uint32 value : arg) {
    expected.push_back(static_cast<int64>(value));
  }
  ComputeAndCompareR1<int64>(&builder, expected, {arg_data.get()});
}
280
// Widens a parameter of S32 values, including negatives, to S64.
//
// Negative inputs must sign-extend, matching the host static_cast<int64>.
XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) {
  XlaBuilder builder(TestName());
  std::vector<int32> arg{0, 1, 0x1000, -1, -0x1000};
  Literal arg_literal = LiteralUtil::CreateR1<int32>(arg);
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(Parameter(&builder, 0, arg_literal.shape(), "arg_param"),
                     S64);

  std::vector<int64> expected;
  expected.reserve(arg.size());
  for (int32 value : arg) {
    expected.push_back(static_cast<int64>(value));
  }
  ComputeAndCompareR1<int64>(&builder, expected, {arg_data.get()});
}
297
// Converts a parameter of F32 values to S64 and compares the results with the
// host's static_cast<int64>.
//
// The inputs include fractional values of both signs, 2^31, and magnitudes
// close to the int64 limits.
XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) {
  XlaBuilder builder(TestName());
  // Test cases from compiler_rt library.
  std::vector<float> arg{0.0f,
                         0.5f,
                         0.99f,
                         1.0f,
                         1.5f,
                         1.99f,
                         2.0f,
                         2.01f,
                         2147483648.f,
                         -0.5f,
                         -0.99f,
                         -1.0f,
                         -1.5f,
                         -1.99f,
                         -2.0f,
                         -2.01f,
                         9223371487098961920.f,
                         9223370937343148032.f,
                         -9223371487098961920.f,
                         -9223370937343148032.f};
  Literal arg_literal = LiteralUtil::CreateR1<float>(arg);
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(Parameter(&builder, 0, arg_literal.shape(), "arg_param"),
                     S64);

  std::vector<int64> expected;
  expected.reserve(arg.size());
  for (float value : arg) {
    expected.push_back(static_cast<int64>(value));
  }
  ComputeAndCompareR1<int64>(&builder, expected, {arg_data.get()});
}
334
// U8 -> F32: small unsigned bytes convert exactly.
XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<uint8_t>(&builder, {32, 64}), F32);

  const std::vector<float> expected{32.0, 64.0};
  ComputeAndCompareR1<float>(&builder, expected, {});
}
343
// U8 -> S32 widening preserves the values.
XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<uint8_t>(&builder, {32, 64}), S32);

  const std::vector<int32_t> expected{32, 64};
  ComputeAndCompareR1<int32_t>(&builder, expected, {});
}
352
// U8 -> U32 widening preserves the values.
XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<uint8_t>(&builder, {32, 64}), U32);

  const std::vector<uint32_t> expected{32, 64};
  ComputeAndCompareR1<uint32_t>(&builder, expected, {});
}
361
// F32 -> F64 widening of exactly representable values.
XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<float>(&builder, {32.0f, 64.0f}), F64);

  const std::vector<double> expected{32.0, 64.0};
  ComputeAndCompareR1<double>(&builder, expected, {});
}
370
// F64 -> F32 narrowing of values that are exact in both types.
XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) {
  XlaBuilder builder(TestName());
  ConvertElementType(ConstantR1<double>(&builder, {32.0, 64.0}), F32);

  const std::vector<float> expected{32.0f, 64.0f};
  ComputeAndCompareR1<float>(&builder, expected, {});
}
379
// Converts INT32_MIN and INT32_MAX to F32. The results are compared, within
// ErrorSpec(0.0001), against the host's own float conversion of those limits.
TEST_F(ConvertTest, ConvertS32Extremes) {
  using Limits = std::numeric_limits<int32>;
  XlaBuilder builder(TestName());
  ConvertElementType(
      ConstantR1<int32>(&builder, {Limits::min(), Limits::max()}), F32);

  const std::vector<float> expected{static_cast<float>(Limits::min()),
                                    static_cast<float>(Limits::max())};
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
391
// Applies a scalar F32 -> S32 convert computation elementwise with Map.
TEST_F(ConvertTest, ConvertMapToS32) {
  XlaBuilder builder(TestName());

  // Sub-computation: (f32[] in) -> s32[].
  auto to_s32 = builder.CreateSubBuilder("convert");
  ConvertElementType(
      Parameter(to_s32.get(), 0, ShapeUtil::MakeShape(F32, {}), "in"), S32);

  auto input = ConstantR1<float>(&builder, {42.0f, 64.0f});
  Map(&builder, {input}, to_s32->BuildAndNoteError(), {0});

  const std::vector<int32> expected{42, 64};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}
403
// Applies a scalar S32 -> F32 convert computation elementwise with Map.
TEST_F(ConvertTest, ConvertMapToF32) {
  XlaBuilder builder(TestName());

  // Sub-computation: (s32[] in) -> f32[].
  auto to_f32 = builder.CreateSubBuilder("convert");
  ConvertElementType(
      Parameter(to_f32.get(), 0, ShapeUtil::MakeShape(S32, {}), "in"), F32);

  auto input = ConstantR1<int32>(&builder, {42, 64});
  Map(&builder, {input}, to_f32->BuildAndNoteError(), {0});

  const std::vector<float> expected{42.0f, 64.0f};
  ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
415
416 // Regression test for b/31758660. When ReshapeMover transforms
417 // input -> reshape -> convert
418 // to
419 // input -> convert -> reshape
420 // the new convert should have the same element type as the old convert.
TEST_F(ConvertTest,ConvertReshape)421 TEST_F(ConvertTest, ConvertReshape) {
422 XlaBuilder builder(TestName());
423 auto input = ConstantR1<int32>(&builder, {42});
424 auto reshape = Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{});
425 ConvertElementType(reshape, F32);
426
427 ComputeAndCompareR0<float>(&builder, 42.0f, {}, ErrorSpec(0.0001));
428 }
429
GetInterestingF16ConversionTestCases()430 std::vector<float> GetInterestingF16ConversionTestCases() {
431 float infinity = std::numeric_limits<float>::infinity();
432 float half_min_positive_normal = absl::bit_cast<float, uint32>(0x38800000);
433 float half_max_subnormal = absl::bit_cast<float, uint32>(0x387fc000);
434 float half_min_positive_subnormal = absl::bit_cast<float, uint32>(0x33800000);
435 float half_max = 65504.0f;
436
437 std::vector<float> test_cases(
438 {-infinity, -(half_max * 2 + 1), -half_max, -42.0f, -1.0f,
439 -half_min_positive_subnormal, -half_max_subnormal,
440 -half_min_positive_normal, -0.0f, 0.0f, half_min_positive_subnormal,
441 half_max_subnormal, half_min_positive_normal, 1.0f, 42.0f, half_max,
442 (half_max * 2 + 1), infinity});
443 return test_cases;
444 }
445
// F16 -> F32.
//
// The interesting F32 values are first rounded to F16 on the host to form the
// inputs. Each expected output is that F16 value widened back to F32 on the
// host.
XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) {
  std::vector<half> input;
  std::vector<float> expected_output;
  for (float f : GetInterestingF16ConversionTestCases()) {
    const Eigen::half h(f);
    input.push_back(h);
    expected_output.push_back(static_cast<float>(h));
  }

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> input_handle,
      client_->TransferToServer(LiteralUtil::CreateR1<half>(input)));

  XlaBuilder builder(TestName());
  const Shape input_shape =
      ShapeUtil::MakeShape(F16, {static_cast<int64>(input.size())});
  ConvertElementType(Parameter(&builder, 0, input_shape, "param"), F32);

  ComputeAndCompareR1<float>(&builder, expected_output, {input_handle.get()});
}
468
// F32 -> F16: the device result must match the host's Eigen::half rounding
// of each interesting F32 value.
XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) {
  const std::vector<float> input = GetInterestingF16ConversionTestCases();
  std::vector<half> expected_output;
  expected_output.reserve(input.size());
  for (float f : input) {
    expected_output.push_back(Eigen::half(f));
  }

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> input_handle,
      client_->TransferToServer(LiteralUtil::CreateR1<float>(input)));

  XlaBuilder builder(TestName());
  const Shape input_shape =
      ShapeUtil::MakeShape(F32, {static_cast<int64>(input.size())});
  ConvertElementType(Parameter(&builder, 0, input_shape, "param"), F16);

  ComputeAndCompareR1<half>(&builder, expected_output, {input_handle.get()});
}
488
// C64 -> C64 identity conversion of the single complex value 42 + 64i.
XLA_TEST_F(ConvertTest, ConvertC64ToC64) {
  XlaBuilder builder(TestName());
  // The inner braces construct one complex64, so this vector has one element.
  const std::vector<complex64> values = {{42.0f, 64.0f}};
  ConvertElementType(ConstantR1<complex64>(&builder, values), C64);
  ComputeAndCompareR1<complex64>(&builder, values, {}, ErrorSpec(0.0001));
}
495
// S64 -> S64 identity conversion: {-42, 64} must come back unchanged.
XLA_TEST_F(ConvertTest, ConvertS64S64) {
  XlaBuilder builder(TestName());
  const std::vector<int64> values = {-42, 64};
  ConvertElementType(ConstantR1<int64>(&builder, values), S64);
  ComputeAndCompareR1<int64>(&builder, values, {});
}
502
// U64 -> U64 identity conversion: {42, 64} must come back unchanged.
XLA_TEST_F(ConvertTest, ConvertU64U64) {
  XlaBuilder builder(TestName());
  const std::vector<uint64> values = {42, 64};
  ConvertElementType(ConstantR1<uint64>(&builder, values), U64);
  ComputeAndCompareR1<uint64>(&builder, values, {});
}
509
// U64 -> S64 reinterprets in two's complement: UINT64_MAX becomes -1.
XLA_TEST_F(ConvertTest, ConvertU64S64) {
  XlaBuilder builder(TestName());
  const std::vector<uint64> unsigned_x = {42, UINT64_MAX};
  ConvertElementType(ConstantR1<uint64>(&builder, unsigned_x), S64);

  const std::vector<int64> signed_x = {42, -1};
  ComputeAndCompareR1<int64>(&builder, signed_x, {});
}
517
// S64 -> U64 reinterprets in two's complement: -1 becomes UINT64_MAX and
// INT64_MIN becomes 2^63.
XLA_TEST_F(ConvertTest, ConvertS64U64) {
  XlaBuilder builder(TestName());
  const std::vector<int64> signed_x = {42, -1, INT64_MIN};
  ConvertElementType(ConstantR1<int64>(&builder, signed_x), U64);

  const uint64 two_to_the_63 = tensorflow::MathUtil::IPow<uint64>(2, 63);
  const std::vector<uint64> unsigned_x = {42, UINT64_MAX, two_to_the_63};
  ComputeAndCompareR1<uint64>(&builder, unsigned_x, {});
}
526
// Exhaustively checks BF16 -> F32 for all 2^16 bit patterns.
//
// Each converted F32 is bitcast to U32 and must equal the BF16 bit pattern
// multiplied by 2^16, i.e. placed in the upper 16 bits.
XLA_TEST_F(ConvertTest, ConvertBF16F32) {
  XlaBuilder builder(TestName());

  constexpr int kNumBfloats = 1 << 16;
  std::vector<bfloat16> all_bfloats(kNumBfloats);
  std::vector<uint32> expected(kNumBfloats);
  for (int bits = 0; bits < kNumBfloats; ++bits) {
    all_bfloats[bits].value = bits;
    expected[bits] = (1U << 16) * bits;
  }

  xla::XlaOp as_bf16 = ConstantR1<bfloat16>(&builder, all_bfloats);
  xla::XlaOp as_f32 = ConvertElementType(as_bf16, F32);
  BitcastConvertType(as_f32, U32);
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
546
547 } // namespace
548 } // namespace xla
549