• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <array>
17 #include <cstdint>
18 #include <limits>
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/base/casts.h"
24 #include "tensorflow/compiler/xla/client/local_client.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
28 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
29 #include "tensorflow/compiler/xla/tests/test_macros.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/lib/math/math_util.h"
32 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
33 #include "tensorflow/core/platform/test.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace xla {
37 namespace {
38 
39 class ConvertTest : public ClientLibraryTestBase {
40  public:
ConvertTest(se::Platform * platform=nullptr)41   explicit ConvertTest(se::Platform* platform = nullptr)
42       : ClientLibraryTestBase(platform) {
43     mutable_debug_options()->add_xla_disable_hlo_passes("algsimp");
44     mutable_debug_options()->add_xla_disable_hlo_passes("inline");
45   }
46 };
47 
// S32 -> S32: converting to the same element type must leave values unchanged.
TEST_F(ConvertTest, ConvertR1S32ToR1S32) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<int32>(&b, {42, 64}), S32);

  const std::vector<int32> want{42, 64};
  ComputeAndCompareR1<int32>(&b, want, {});
}
56 
// S32 -> U32 for non-negative inputs: values are expected to be preserved.
TEST_F(ConvertTest, ConvertR1S32ToR1U32) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<int32>(&b, {42, 64}), U32);

  const std::vector<uint32> want{42, 64};
  ComputeAndCompareR1<uint32>(&b, want, {});
}
65 
// S32 -> PRED: zero maps to false; non-zero values (positive or negative) map
// to true.
TEST_F(ConvertTest, ConvertR1S32ToR1PRED) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<int32>(&b, {42, 0, -64}), PRED);

  // std::array rather than std::vector<bool>, which has no contiguous bools.
  const std::array<bool, 3> want = {true, false, true};
  ComputeAndCompareR1<bool>(&b, want, {});
}
74 
// U32 -> U32: converting to the same element type must leave values unchanged.
TEST_F(ConvertTest, ConvertR1U32ToR1U32) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<uint32>(&b, {42, 64}), U32);

  const std::vector<uint32> want{42, 64};
  ComputeAndCompareR1<uint32>(&b, want, {});
}
83 
// U32 -> S32 for small inputs: values are expected to be preserved.
TEST_F(ConvertTest, ConvertR1U32ToR1S32) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<uint32>(&b, {42, 64}), S32);

  const std::vector<int32> want{42, 64};
  ComputeAndCompareR1<int32>(&b, want, {});
}
92 
// U32 -> PRED: zero maps to false, non-zero maps to true.
TEST_F(ConvertTest, ConvertR1U32ToR1PRED) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<uint32>(&b, {42, 0, 64}), PRED);

  const std::array<bool, 3> want = {true, false, true};
  ComputeAndCompareR1<bool>(&b, want, {});
}
101 
// F32 -> F32: converting to the same element type must leave values unchanged.
TEST_F(ConvertTest, ConvertR1F32ToR1F32) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<float>(&b, {42.0f, 64.0f}), F32);

  const std::vector<float> want{42.0f, 64.0f};
  ComputeAndCompareR1<float>(&b, want, {});
}
110 
// F32 -> PRED: 0.0f maps to false, non-zero maps to true.
TEST_F(ConvertTest, ConvertR1F32ToR1PRED) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<float>(&b, {42.0f, 0.0f, 64.0f}), PRED);

  const std::array<bool, 3> want = {true, false, true};
  ComputeAndCompareR1<bool>(&b, want, {});
}
119 
// S32 -> F32 for small integers that are exactly representable as float.
TEST_F(ConvertTest, ConvertR1S32ToR1F32) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<int32>(&b, {42, 64}), F32);

  const std::vector<float> want{42.0f, 64.0f};
  ComputeAndCompareR1<float>(&b, want, {});
}
128 
// PRED -> S32: true becomes 1 and false becomes 0.
TEST_F(ConvertTest, ConvertR1PREDToR1S32) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<bool>(&b, {true, false, true}), S32);

  const std::vector<int32> want{1, 0, 1};
  ComputeAndCompareR1<int32>(&b, want, {});
}
137 
// PRED -> U32: true becomes 1 and false becomes 0.
TEST_F(ConvertTest, ConvertR1PREDToR1U32) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<bool>(&b, {true, false, true}), U32);

  const std::vector<uint32> want{1, 0, 1};
  ComputeAndCompareR1<uint32>(&b, want, {});
}
146 
// PRED -> F32: true becomes 1.0 and false becomes 0.0.
TEST_F(ConvertTest, ConvertR1PREDToR1F32) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<bool>(&b, {true, false, true}), F32);

  const std::vector<float> want{1.0f, 0.0f, 1.0f};
  ComputeAndCompareR1<float>(&b, want, {});
}
155 
// Zero-element S32 -> F32: the result is expected to be empty as well.
XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<int32>(&b, {}), F32);

  const std::vector<float> want;
  ComputeAndCompareR1<float>(&b, want, {});
}
164 
// F32 -> S32 with fractional inputs: the expectation drops the fractional
// part (42.6 -> 42, 64.4 -> 64).
TEST_F(ConvertTest, ConvertR1F32ToR1S32) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<float>(&b, {42.6, 64.4}), S32);

  const std::vector<int32> want{42, 64};
  ComputeAndCompareR1<int32>(&b, want, {});
}
173 
// S64 -> F32 on a parameter (not a constant) covering sign boundaries, 32-bit
// boundaries and values near float rounding boundaries. The expected value
// for each input is whatever the host's static_cast<float> produces.
XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) {
  XlaBuilder builder(TestName());
  const std::vector<int64> arg{
      -9223371216516022272,
      -2,
      -1,
      -0x7FFFFFFF,
      // The LL suffix matters: a bare `-0x80000000` negates an unsigned int
      // literal and yields +2147483648, not -2^31.
      -0x80000000LL,
      0,
      1,
      2,
      1073742145,
      1073742656,
      0x7FFFFFFF,
      0x80000000,
      826720496944058148,
      4296062029846194332,
      0x0007FB72E4000000LL,
      0x0007FB72E4000001LL,
      0x0007FB72E6000000LL,
      0x0007FB72E7000000LL,
      0x0007FB72E7FFFFFFLL,
      0x0007FB72E8000000LL,
      0x0007FB72E8000001LL,
      0x0007FB72EA000000LL,
      0x0007FB72EB000000LL,
      0x0007FB72EBFFFFFFLL,
      0x0007FB72EC000000LL,
      0x7FFFFF0000000000LL,
      0x7FFFFF8000000000LL,
      0x7FFFFFFFFFFFFF00,
      static_cast<int64>(0xFFFFFFFFFFFFFFFF),
      static_cast<int64>(0x0000f234e67e0001LL),
      static_cast<int64>(0x8000000000000000),
      static_cast<int64>(0x8000000000000000LL),
      static_cast<int64>(0x8000000000000001LL),
      static_cast<int64>(0x8000008000000000LL),
      static_cast<int64>(0x8000010000000000LL),
  };
  Literal arg_literal = LiteralUtil::CreateR1<int64>(arg);
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(arg_param, F32);

  // Range-for avoids comparing a signed index against size().
  std::vector<float> expected;
  expected.reserve(arg.size());
  for (const int64 value : arg) {
    expected.push_back(static_cast<float>(value));
  }
  ComputeAndCompareR1<float>(&builder, expected, {arg_data.get()});
}
226 
// U32 -> F32 on a parameter, including values at and above 2^31 and the
// maximum U32. The expected value for each input is the host's
// static_cast<float>.
XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) {
  XlaBuilder builder(TestName());
  const std::vector<uint32> arg{0,          1,          0x1000,     0x7fffffff,
                                0x80000000, 0x80000001, 0x80000002, 0x80000003,
                                0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF};
  Literal arg_literal = LiteralUtil::CreateR1<uint32>(arg);
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(arg_param, F32);

  // Range-for avoids comparing a signed index against size().
  std::vector<float> expected;
  expected.reserve(arg.size());
  for (const uint32 value : arg) {
    expected.push_back(static_cast<float>(value));
  }
  ComputeAndCompareR1<float>(&builder, expected, {arg_data.get()});
}
245 
// F32 -> U32 on a parameter, with inputs up to 4294967040.0f. The expected
// value for each input is the host's static_cast<uint32>.
XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
  XlaBuilder builder(TestName());
  const std::vector<float> arg{0.0f,        1.0f,          16777216.0f,
                               16777218.0f, 2147483647.0f, 4294967040.0f};
  Literal arg_literal = LiteralUtil::CreateR1<float>(arg);
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(arg_param, U32);

  // Range-for avoids comparing a signed index against size().
  std::vector<uint32> expected;
  expected.reserve(arg.size());
  for (const float value : arg) {
    expected.push_back(static_cast<uint32>(value));
  }
  ComputeAndCompareR1<uint32>(&builder, expected, {arg_data.get()});
}
263 
// U32 -> S64 on a parameter, including inputs with the top bit set. The
// expected value for each input is the host's static_cast<int64>.
XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
  XlaBuilder builder(TestName());
  const std::vector<uint32> arg{0,          1,          0x1000,
                                0x7fffffff, 0x80000082, 0xFFFFFFFF};
  Literal arg_literal = LiteralUtil::CreateR1<uint32>(arg);
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(arg_param, S64);

  // Range-for avoids comparing a signed index against size().
  std::vector<int64> expected;
  expected.reserve(arg.size());
  for (const uint32 value : arg) {
    expected.push_back(static_cast<int64>(value));
  }
  ComputeAndCompareR1<int64>(&builder, expected, {arg_data.get()});
}
280 
// S32 -> S64 on a parameter, with positive and negative inputs. The expected
// value for each input is the host's static_cast<int64>.
XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) {
  XlaBuilder builder(TestName());
  const std::vector<int32> arg{0, 1, 0x1000, -1, -0x1000};
  Literal arg_literal = LiteralUtil::CreateR1<int32>(arg);
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(arg_param, S64);

  // Range-for avoids comparing a signed index against size().
  std::vector<int64> expected;
  expected.reserve(arg.size());
  for (const int32 value : arg) {
    expected.push_back(static_cast<int64>(value));
  }
  ComputeAndCompareR1<int64>(&builder, expected, {arg_data.get()});
}
297 
// F32 -> S64 on a parameter: fractional values on both sides of zero plus
// large magnitudes just inside the int64 range. The expected value for each
// input is the host's static_cast<int64>.
XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) {
  XlaBuilder builder(TestName());
  // Test cases from compiler_rt library.
  const std::vector<float> arg{0.0f,
                               0.5f,
                               0.99f,
                               1.0f,
                               1.5f,
                               1.99f,
                               2.0f,
                               2.01f,
                               2147483648.f,
                               -0.5f,
                               -0.99f,
                               -1.0f,
                               -1.5f,
                               -1.99f,
                               -2.0f,
                               -2.01f,
                               9223371487098961920.f,
                               9223370937343148032.f,
                               -9223371487098961920.f,
                               -9223370937343148032.f};
  Literal arg_literal = LiteralUtil::CreateR1<float>(arg);
  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
  std::unique_ptr<GlobalData> arg_data =
      client_->TransferToServer(arg_literal).ConsumeValueOrDie();

  ConvertElementType(arg_param, S64);

  // Range-for avoids comparing a signed index against size().
  std::vector<int64> expected;
  expected.reserve(arg.size());
  for (const float value : arg) {
    expected.push_back(static_cast<int64>(value));
  }
  ComputeAndCompareR1<int64>(&builder, expected, {arg_data.get()});
}
334 
// U8 -> F32: small values are expected to be preserved.
XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<uint8_t>(&b, {32, 64}), F32);

  const std::vector<float> want{32.0f, 64.0f};
  ComputeAndCompareR1<float>(&b, want, {});
}
343 
// U8 -> S32: small values are expected to be preserved.
XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<uint8_t>(&b, {32, 64}), S32);

  const std::vector<int32_t> want{32, 64};
  ComputeAndCompareR1<int32_t>(&b, want, {});
}
352 
// U8 -> U32: small values are expected to be preserved.
XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<uint8_t>(&b, {32, 64}), U32);

  const std::vector<uint32_t> want{32, 64};
  ComputeAndCompareR1<uint32_t>(&b, want, {});
}
361 
// F32 -> F64 on values that are exact in both types.
XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<float>(&b, {32.0f, 64.0f}), F64);

  const std::vector<double> want{32.0, 64.0};
  ComputeAndCompareR1<double>(&b, want, {});
}
370 
// F64 -> F32 on values that are exact in both types.
XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) {
  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<double>(&b, {32.0, 64.0}), F32);

  const std::vector<float> want{32.0f, 64.0f};
  ComputeAndCompareR1<float>(&b, want, {});
}
379 
// S32 -> F32 at the minimum and maximum int32 values. Expectations come from
// the host's static_cast<float>, compared with an ErrorSpec of 0.0001.
TEST_F(ConvertTest, ConvertS32Extremes) {
  const int32 lowest = std::numeric_limits<int32>::min();
  const int32 highest = std::numeric_limits<int32>::max();

  XlaBuilder b(TestName());
  ConvertElementType(ConstantR1<int32>(&b, {lowest, highest}), F32);

  const std::vector<float> want{static_cast<float>(lowest),
                                static_cast<float>(highest)};
  ComputeAndCompareR1<float>(&b, want, {}, ErrorSpec(0.0001));
}
391 
// Runs an F32 -> S32 convert inside a Map: a scalar sub-computation does the
// conversion and is mapped over dimension 0 of an F32 constant.
TEST_F(ConvertTest, ConvertMapToS32) {
  XlaBuilder builder(TestName());

  // Scalar computation: in (F32) -> convert -> S32.
  auto convert_builder = builder.CreateSubBuilder("convert");
  auto scalar_in = Parameter(convert_builder.get(), 0,
                             ShapeUtil::MakeShape(F32, {}), "in");
  ConvertElementType(scalar_in, S32);

  auto operand = ConstantR1<float>(&builder, {42.0f, 64.0f});
  Map(&builder, {operand}, convert_builder->BuildAndNoteError(), {0});

  const std::vector<int32> want{42, 64};
  ComputeAndCompareR1<int32>(&builder, want, {});
}
403 
// Runs an S32 -> F32 convert inside a Map: a scalar sub-computation does the
// conversion and is mapped over dimension 0 of an S32 constant.
TEST_F(ConvertTest, ConvertMapToF32) {
  XlaBuilder builder(TestName());

  // Scalar computation: in (S32) -> convert -> F32.
  auto convert_builder = builder.CreateSubBuilder("convert");
  auto scalar_in = Parameter(convert_builder.get(), 0,
                             ShapeUtil::MakeShape(S32, {}), "in");
  ConvertElementType(scalar_in, F32);

  auto operand = ConstantR1<int32>(&builder, {42, 64});
  Map(&builder, {operand}, convert_builder->BuildAndNoteError(), {0});

  const std::vector<float> want{42.0f, 64.0f};
  ComputeAndCompareR1<float>(&builder, want, {}, ErrorSpec(0.0001));
}
415 
416 // Regression test for b/31758660. When ReshapeMover transforms
417 //   input -> reshape -> convert
418 // to
419 //   input -> convert -> reshape
420 // the new convert should have the same element type as the old convert.
TEST_F(ConvertTest,ConvertReshape)421 TEST_F(ConvertTest, ConvertReshape) {
422   XlaBuilder builder(TestName());
423   auto input = ConstantR1<int32>(&builder, {42});
424   auto reshape = Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{});
425   ConvertElementType(reshape, F32);
426 
427   ComputeAndCompareR0<float>(&builder, 42.0f, {}, ErrorSpec(0.0001));
428 }
429 
GetInterestingF16ConversionTestCases()430 std::vector<float> GetInterestingF16ConversionTestCases() {
431   float infinity = std::numeric_limits<float>::infinity();
432   float half_min_positive_normal = absl::bit_cast<float, uint32>(0x38800000);
433   float half_max_subnormal = absl::bit_cast<float, uint32>(0x387fc000);
434   float half_min_positive_subnormal = absl::bit_cast<float, uint32>(0x33800000);
435   float half_max = 65504.0f;
436 
437   std::vector<float> test_cases(
438       {-infinity, -(half_max * 2 + 1), -half_max, -42.0f, -1.0f,
439        -half_min_positive_subnormal, -half_max_subnormal,
440        -half_min_positive_normal, -0.0f, 0.0f, half_min_positive_subnormal,
441        half_max_subnormal, half_min_positive_normal, 1.0f, 42.0f, half_max,
442        (half_max * 2 + 1), infinity});
443   return test_cases;
444 }
445 
// F16 -> F32 on the boundary cases from GetInterestingF16ConversionTestCases.
// Each case is first narrowed to half on the host; the expected output is
// that half widened back to float on the host.
XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) {
  std::vector<half> input;
  std::vector<float> expected_output;
  for (const float f : GetInterestingF16ConversionTestCases()) {
    const Eigen::half h(f);
    input.push_back(h);
    expected_output.push_back(static_cast<float>(h));
  }

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> input_handle,
      client_->TransferToServer(LiteralUtil::CreateR1<half>(input)));

  XlaBuilder builder(TestName());
  const Shape param_shape =
      ShapeUtil::MakeShape(F16, {static_cast<int64>(input.size())});
  ConvertElementType(Parameter(&builder, 0, param_shape, "param"), F32);

  ComputeAndCompareR1<float>(&builder, expected_output, {input_handle.get()});
}
468 
// F32 -> F16 on the boundary cases from GetInterestingF16ConversionTestCases.
// The expected output is each input narrowed via the Eigen::half constructor
// on the host.
XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) {
  const std::vector<float> input = GetInterestingF16ConversionTestCases();
  std::vector<half> expected_output;
  for (const float f : input) {
    expected_output.push_back(Eigen::half(f));
  }

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GlobalData> input_handle,
      client_->TransferToServer(LiteralUtil::CreateR1<float>(input)));

  XlaBuilder builder(TestName());
  const Shape param_shape =
      ShapeUtil::MakeShape(F32, {static_cast<int64>(input.size())});
  ConvertElementType(Parameter(&builder, 0, param_shape, "param"), F16);

  ComputeAndCompareR1<half>(&builder, expected_output, {input_handle.get()});
}
488 
// C64 -> C64: a single complex value (42 + 64i) must come back unchanged,
// within an ErrorSpec of 0.0001.
XLA_TEST_F(ConvertTest, ConvertC64ToC64) {
  XlaBuilder b(TestName());
  const std::vector<complex64> values = {complex64(42.0f, 64.0f)};
  ConvertElementType(ConstantR1<complex64>(&b, values), C64);
  ComputeAndCompareR1<complex64>(&b, values, {}, ErrorSpec(0.0001));
}
495 
// S64 -> S64: converting to the same element type must leave values unchanged.
XLA_TEST_F(ConvertTest, ConvertS64S64) {
  XlaBuilder b(TestName());
  const std::vector<int64> values = {-42, 64};
  ConvertElementType(ConstantR1<int64>(&b, values), S64);
  ComputeAndCompareR1<int64>(&b, values, {});
}
502 
// U64 -> U64: converting to the same element type must leave values unchanged.
XLA_TEST_F(ConvertTest, ConvertU64U64) {
  XlaBuilder b(TestName());
  const std::vector<uint64> values = {42, 64};
  ConvertElementType(ConstantR1<uint64>(&b, values), U64);
  ComputeAndCompareR1<uint64>(&b, values, {});
}
509 
// U64 -> S64: UINT64_MAX is expected to come out as -1; small values are
// preserved.
XLA_TEST_F(ConvertTest, ConvertU64S64) {
  XlaBuilder b(TestName());
  const std::vector<uint64> inputs = {42, UINT64_MAX};
  ConvertElementType(ConstantR1<uint64>(&b, inputs), S64);

  const std::vector<int64> want = {42, -1};
  ComputeAndCompareR1<int64>(&b, want, {});
}
517 
// S64 -> U64: -1 is expected to come out as UINT64_MAX and INT64_MIN as 2^63;
// small positive values are preserved.
XLA_TEST_F(ConvertTest, ConvertS64U64) {
  XlaBuilder b(TestName());
  const std::vector<int64> inputs = {42, -1, INT64_MIN};
  ConvertElementType(ConstantR1<int64>(&b, inputs), U64);

  const uint64 two_pow_63 = tensorflow::MathUtil::IPow<uint64>(2, 63);
  const std::vector<uint64> want = {42, UINT64_MAX, two_pow_63};
  ComputeAndCompareR1<uint64>(&b, want, {});
}
526 
// Exhaustively checks BF16 -> F32 over all 2^16 bf16 bit patterns. The F32
// result is bitcast to U32 so the comparison is done on raw bits; for bf16
// pattern `bits` the expected F32 pattern is `bits << 16`.
XLA_TEST_F(ConvertTest, ConvertBF16F32) {
  XlaBuilder builder(TestName());

  constexpr uint32 kNumBf16Patterns = 1U << 16;
  std::vector<bfloat16> all_bfloats(kNumBf16Patterns);
  std::vector<uint32> expected(kNumBf16Patterns);
  // An unsigned index avoids the signed/unsigned comparison against size().
  for (uint32 bits = 0; bits < kNumBf16Patterns; ++bits) {
    // `bits` is always below 2^16 here, so the explicit cast is lossless.
    all_bfloats[bits].value = static_cast<uint16_t>(bits);
    expected[bits] = bits << 16;
  }

  // Exhaustively test all bf16 to f32 conversions.
  xla::XlaOp all_bfloats_bf16 = ConstantR1<bfloat16>(&builder, all_bfloats);
  xla::XlaOp all_bfloats_f32 = ConvertElementType(all_bfloats_bf16, F32);
  BitcastConvertType(all_bfloats_f32, U32);
  ComputeAndCompareR1<uint32>(&builder, expected, {});
}
546 
547 }  // namespace
548 }  // namespace xla
549