• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/literal.h"
17 
18 #include <vector>
19 
20 #include "absl/base/casts.h"
21 #include "absl/memory/memory.h"
22 #include "absl/strings/match.h"
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/compiler/tf2xla/shape_util.h"
25 #include "tensorflow/compiler/xla/array3d.h"
26 #include "tensorflow/compiler/xla/array4d.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/literal_util.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/test.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace xla {
37 namespace {
38 
39 using ::testing::ElementsAre;
40 using ::testing::HasSubstr;
41 
42 class LiteralUtilTest : public ::testing::Test {
43  protected:
LiteralUtilTest()44   LiteralUtilTest() {
45     Array4D<float> arr4d({
46         // clang-format off
47       {  // i0=0
48           {  // i1=0
49               {1, 2, 3},  // i2=0
50               {4, 5, 6},  // i2=1
51               {7, 8, 9},  // i2=2
52           },
53           {  // i1=1
54               {11, 12, 13},
55               {14, 15, 16},
56               {17, 18, 19},
57           },
58       },
59       {  // i0=1
60           {  // i1=0
61               {101, 102, 103},
62               {104, 105, 106},
63               {107, 108, 109},
64           },
65           {  // i1=1
66               {201, 202, 203},  // i2=0
67               {204, 205, 206},  // i2=1
68               {207, 208, 209},  // i2=2
69           },
70       },
71         // clang-format on
72     });
73 
74     layout_r2_dim0major_ = LayoutUtil::MakeLayout({1, 0});
75     layout_r2_dim0minor_ = LayoutUtil::MakeLayout({0, 1});
76     layout_r3_dim0major_ = LayoutUtil::MakeLayout({2, 1, 0});
77     layout_r3_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2});
78     layout_r4_dim0major_ = LayoutUtil::MakeLayout({3, 2, 1, 0});
79     layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3});
80 
81     literal_r4_2x2x3x3_dim0major_ =
82         LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
83                                                           layout_r4_dim0major_);
84     literal_r4_2x2x3x3_dim0minor_ =
85         LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
86                                                           layout_r4_dim0minor_);
87   }
88 
89   Layout layout_r2_dim0major_;
90   Layout layout_r2_dim0minor_;
91   Layout layout_r3_dim0major_;
92   Layout layout_r3_dim0minor_;
93   Layout layout_r4_dim0major_;
94   Layout layout_r4_dim0minor_;
95   Literal literal_r4_2x2x3x3_dim0major_;
96   Literal literal_r4_2x2x3x3_dim0minor_;
97 };
98 
// Checks the textual rendering of scalar (R0) literals of every element type.
TEST_F(LiteralUtilTest, LiteralScalarToString) {
  auto true_lit = LiteralUtil::CreateR0<bool>(true);
  EXPECT_EQ("pred[] true", true_lit.ToString());

  auto false_lit = LiteralUtil::CreateR0<bool>(false);
  EXPECT_EQ("pred[] false", false_lit.ToString());

  auto u32_lit = LiteralUtil::CreateR0<uint32>(42);
  EXPECT_EQ("u32[] 42", u32_lit.ToString());

  auto s32_lit = LiteralUtil::CreateR0<int32>(-999);
  EXPECT_EQ("s32[] -999", s32_lit.ToString());

  auto f32_lit = LiteralUtil::CreateR0<float>(3.14f);
  EXPECT_EQ("f32[] 3.14", f32_lit.ToString());

  auto f16_lit = LiteralUtil::CreateR0<half>(static_cast<half>(0.5f));
  EXPECT_EQ("f16[] 0.5", f16_lit.ToString());

  auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
  EXPECT_EQ("c64[] (3.14, 2.78)", c64_lit.ToString());

  // complex128 has double-precision components, so build it from double
  // literals; float literals (3.14f) would be rounded to single precision
  // before being widened.
  auto c128_lit = LiteralUtil::CreateR0<complex128>({3.14, 2.78});
  EXPECT_EQ("c128[] (3.14, 2.78)", c128_lit.ToString());

  auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
  EXPECT_EQ("bf16[] 0.5", bf16_lit.ToString());

  // 3.14 is not exactly representable in bfloat16; the stored value prints as
  // 3.14062.
  auto bf16_lit_truncated =
      LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
  EXPECT_EQ("bf16[] 3.14062", bf16_lit_truncated.ToString());

  // Likewise, 9.001 is stored as 9 in bfloat16.
  auto bf16_lit_truncated2 =
      LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
  EXPECT_EQ("bf16[] 9", bf16_lit_truncated2.ToString());
}
136 
// PRED vectors render their elements as 1/0 rather than true/false.
TEST_F(LiteralUtilTest, LiteralVectorToString) {
  const auto bools = LiteralUtil::CreateR1<bool>({true, false, true});
  const string rendered = bools.ToString();
  EXPECT_EQ("pred[3] {1, 0, 1}", rendered);
}
141 
// A 3x2 S32 matrix prints one row per line inside braces.
TEST_F(LiteralUtilTest, R2ToString) {
  const auto matrix = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}});
  const string rendered = matrix.ToString();
  const string expected = R"(s32[3,2] {
  { 1, 2 },
  { 3, 4 },
  { 5, 6 }
})";
  EXPECT_EQ(expected, rendered);
}
151 
// A 3x2x1 S32 array prints each outer slice as its own braced block.
TEST_F(LiteralUtilTest, R3ToString) {
  const auto cube = LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
  const string rendered = cube.ToString();
  const string expected = R"(s32[3,2,1] {
{
  {1},
  {2}
},
{
  {3},
  {4}
},
{
  {5},
  {6}
}
})";
  EXPECT_EQ(expected, rendered);
}
171 
// Rank-6 rendering: every level above the last two is annotated with its
// index as an /*iN=k*/ comment. CreateFromDimensions yields all zeros here.
TEST_F(LiteralUtilTest, R6ToString) {
  const auto rank6 = LiteralUtil::CreateFromDimensions(S32, {2, 2, 1, 1, 1, 2});
  const string rendered = rank6.ToString();
  const string expected = R"(s32[2,2,1,1,1,2] {
{ /*i0=0*/
{ /*i1=0*/
{ /*i2=0*/
{ /*i3=0*/
  { 0, 0 }
}
}
},
{ /*i1=1*/
{ /*i2=0*/
{ /*i3=0*/
  { 0, 0 }
}
}
}
},
{ /*i0=1*/
{ /*i1=0*/
{ /*i2=0*/
{ /*i3=0*/
  { 0, 0 }
}
}
},
{ /*i1=1*/
{ /*i2=0*/
{ /*i3=0*/
  { 0, 0 }
}
}
}
}
})";
  EXPECT_EQ(expected, rendered);
}
211 
// A tuple renders as a parenthesized, comma-separated list of its elements,
// each element starting on its own line.
TEST_F(LiteralUtilTest, TupleToString) {
  auto element0 = LiteralUtil::CreateR0<float>(1.0);
  auto element1 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto pair = LiteralUtil::MakeTuple({&element0, &element1});

  const string rendered = pair.ToString();
  const string expected = R"((
f32[] 1,
f32[2,2] {
  { 1, 2 },
  { 3, 4 }
}
))";
  EXPECT_EQ(expected, rendered);
}
225 
// Building an R3 literal from an Array3D keeps its dimensions and contents.
TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
  // clang-format off
  Array3D<float> array_3d({
    {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
    {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}},
  });
  // clang-format on

  const auto r3 = LiteralUtil::CreateR3FromArray3D(array_3d);
  EXPECT_THAT(r3.shape().dimensions(), ElementsAre(2, 3, 2));

  const string expected = R"(f32[2,3,2] {
{
  { 1, 2 },
  { 3, 4 },
  { 5, 6 }
},
{
  { 7, 8 },
  { 9, 10 },
  { 11, 12 }
}
})";
  const string rendered = r3.ToString();
  EXPECT_EQ(expected, rendered);
}
255 
// CreateSparse orders the index rows lexicographically (permuting the values
// to match), and that ordering survives a proto round trip.
TEST_F(LiteralUtilTest, CreateSparse) {
  std::vector<int64> dimensions = {8, 8, 8};
  // clang-format off
  Array2D<int64> indices = {
      {3, 4, 5},
      {1, 2, 3},
      {2, 3, 4},
      {3, 5, 6},
  };
  // clang-format on
  std::vector<int64> values = {7, 8, 9, 10};
  auto sparse = LiteralUtil::CreateSparse<int64>(
      dimensions, SparseIndexArray(indices.n1() + 3, indices), values);

  // clang-format off
  Array2D<int64> expected_indices = {
      {1, 2, 3},
      {2, 3, 4},
      {3, 4, 5},
      {3, 5, 6},
  };
  // clang-format on
  std::vector<int64> expected_values = {8, 9, 7, 10};

  // Flat views of the expectations, reused for both checks below.
  const absl::Span<const int64> want_indices(expected_indices.data(),
                                             expected_indices.num_elements());
  const absl::Span<const int64> want_values(expected_values);

  EXPECT_EQ(sparse.sparse_indices()->data(), want_indices);
  EXPECT_EQ(sparse.data<int64>(), want_values);

  // Round-trip through the proto form and check the same invariants again.
  TF_ASSERT_OK_AND_ASSIGN(Literal round_tripped,
                          Literal::CreateFromProto(sparse.ToProto()));
  EXPECT_EQ(round_tripped.sparse_indices()->data(), want_indices);
  EXPECT_EQ(round_tripped.data<int64>(), want_values);
}
291 
// CreateR4Projected replicates a 2-D array into the leading p x z dimensions;
// with p=1, z=2 the 3x2 input yields an F32[1,2,3,2] literal.
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
  // clang-format off
  auto projected = LiteralUtil::CreateR4Projected<float>(
      {{1, 2}, {1001, 1002}, {2001, 2002}},
      /*projection_p=*/1, /*projection_z=*/2);
  // clang-format on
  EXPECT_THAT(projected.shape().dimensions(), ElementsAre(1, 2, 3, 2));

  const string expected = R"(f32[1,2,3,2] {
{ /*i0=0*/
{ /*i1=0*/
  { 1, 2 },
  { 1001, 1002 },
  { 2001, 2002 }
},
{ /*i1=1*/
  { 1, 2 },
  { 1001, 1002 },
  { 2001, 2002 }
}
}
})";
  const string rendered = projected.ToString();
  EXPECT_EQ(expected, rendered);
}
318 
// Renders the fixture's dim0-major F32[2,2,3,3] literal.
TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
  const Literal& subject = literal_r4_2x2x3x3_dim0major_;
  EXPECT_THAT(subject.shape().dimensions(), ElementsAre(2, 2, 3, 3));

  const string expected = R"(f32[2,2,3,3] {
{ /*i0=0*/
{ /*i1=0*/
  { 1, 2, 3 },
  { 4, 5, 6 },
  { 7, 8, 9 }
},
{ /*i1=1*/
  { 11, 12, 13 },
  { 14, 15, 16 },
  { 17, 18, 19 }
}
},
{ /*i0=1*/
{ /*i1=0*/
  { 101, 102, 103 },
  { 104, 105, 106 },
  { 107, 108, 109 }
},
{ /*i1=1*/
  { 201, 202, 203 },
  { 204, 205, 206 },
  { 207, 208, 209 }
}
}
})";
  const string rendered = subject.ToString();
  EXPECT_EQ(expected, rendered);
}
351 
// EachCellAsString reports every cell's multi-index together with its
// rendered value; the expectation below pins the visitation order.
TEST_F(LiteralUtilTest, EachCellR2F32) {
  using Elem = std::tuple<int64, int64, string>;

  // clang-format off
  auto matrix = LiteralUtil::CreateR2<float>({
    {3.1f, 4.2f},
    {9.3f, 12.4f},
  });
  // clang-format on

  std::vector<Elem> visited;
  matrix.EachCellAsString(
      [&visited](absl::Span<const int64> index, const string& text) {
        visited.emplace_back(index[0], index[1], text);
      });

  const std::vector<Elem> expected = {
      Elem(0, 0, "3.1"), Elem(0, 1, "4.2"),
      Elem(1, 0, "9.3"), Elem(1, 1, "12.4")};
  EXPECT_EQ(expected, visited);
}
370 
// Scalars compare equal only when both the value and the element type match.
TEST_F(LiteralUtilTest, ScalarEquality) {
  auto forty_two = LiteralUtil::CreateR0<float>(42.0);
  auto forty_two_again = LiteralUtil::CreateR0<float>(42.0);
  auto one_twenty_three = LiteralUtil::CreateR0<float>(123.0);
  auto forty_two_f64 = LiteralUtil::CreateR0<double>(42.0);

  // Reflexive, and equal to an independently built copy.
  EXPECT_EQ(forty_two, forty_two);
  EXPECT_EQ(forty_two, forty_two_again);

  // A different value, or the same value as F64, is not equal.
  EXPECT_NE(forty_two, one_twenty_three);
  EXPECT_NE(forty_two, forty_two_f64);
}
385 
// Array equality requires identical shape and contents; the nil (empty
// tuple) literal is equal only to itself here.
TEST_F(LiteralUtilTest, NonScalarEquality) {
  auto m = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto m_copy = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto m_other = LiteralUtil::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}});
  auto flat = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
  auto one = LiteralUtil::CreateR0<float>(1.0);
  Literal nil(ShapeUtil::MakeNil());

  EXPECT_EQ(m, m);
  EXPECT_EQ(m, m_copy);

  // Same shape with different values, and same values with a different shape,
  // are both unequal.
  EXPECT_NE(m, m_other);
  EXPECT_NE(m, flat);
  EXPECT_NE(m, one);
  EXPECT_NE(m, nil);

  EXPECT_EQ(nil, nil);
}
404 
// Any two tokens compare equal; tuple equality respects element order.
TEST_F(LiteralUtilTest, TokenEquality) {
  auto first_token = LiteralUtil::CreateToken();
  auto second_token = LiteralUtil::CreateToken();
  auto one = LiteralUtil::CreateR0<float>(1.0);

  EXPECT_EQ(first_token, second_token);
  EXPECT_NE(first_token, one);

  auto lhs_single = LiteralUtil::MakeTuple({&first_token});
  auto rhs_single = LiteralUtil::MakeTuple({&first_token});
  EXPECT_EQ(lhs_single, rhs_single);

  // Swapping one token for another keeps the tuples equal...
  auto token_then_one = LiteralUtil::MakeTuple({&first_token, &one});
  auto other_token_then_one = LiteralUtil::MakeTuple({&second_token, &one});
  EXPECT_EQ(token_then_one, other_token_then_one);

  // ...but reordering the elements does not.
  auto one_then_token = LiteralUtil::MakeTuple({&one, &second_token});
  EXPECT_NE(token_then_one, one_then_token);
}
420 
// Equality compares logical contents: the same 2x2 values stored column-major
// ({0, 1}) and row-major ({1, 0}) must be equal.
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
  // Writes the logical matrix {{1, 2}, {3, 4}} into *target.
  const auto fill = [](Literal* target) {
    target->Set<float>({0, 0}, 1.0);
    target->Set<float>({0, 1}, 2.0);
    target->Set<float>({1, 0}, 3.0);
    target->Set<float>({1, 1}, 4.0);
  };

  Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
  fill(&colmajor);

  Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
  fill(&rowmajor);

  EXPECT_EQ(rowmajor, colmajor);
}
437 
// Tuples are equal when their elements are pairwise equal, in order.
TEST_F(LiteralUtilTest, TupleEquality) {
  auto one = LiteralUtil::CreateR0<float>(1.0);
  auto m = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto base = LiteralUtil::MakeTuple({&one, &m});

  // Mixes a shared element (m) with an equal-valued clone of the other.
  auto one_clone = LiteralUtil::CreateR0<float>(1.0);
  auto equivalent = LiteralUtil::MakeTuple({&one_clone, &m});
  EXPECT_EQ(base, equivalent);

  // Same elements, opposite order.
  auto swapped = LiteralUtil::MakeTuple({&m, &one});
  EXPECT_NE(base, swapped);

  // First element carries a different value.
  auto forty_two = LiteralUtil::CreateR0<float>(42.0);
  auto changed = LiteralUtil::MakeTuple({&forty_two, &m});
  EXPECT_NE(base, changed);
}
459 
// Equality of complex64 (C64) vectors: element-wise and order-sensitive.
TEST_F(LiteralUtilTest, C64Equality) {
  auto vector = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});

  // An independently constructed vector with the same elements is equal.
  auto vector_clone =
      LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
  EXPECT_EQ(vector, vector_clone);

  // The same elements in the opposite order are not equal.
  auto vector_reversed =
      LiteralUtil::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
  EXPECT_NE(vector, vector_reversed);
}
474 
// Equality of complex128 (C128) vectors: element-wise and order-sensitive.
TEST_F(LiteralUtilTest, C128Equality) {
  auto vector = LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});

  // An independently constructed vector with the same elements is equal.
  auto vector_clone =
      LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
  EXPECT_EQ(vector, vector_clone);

  // The same elements in the opposite order are not equal.
  auto vector_reversed =
      LiteralUtil::CreateR1<complex128>({{3.0, 4.0}, {1.0, 2.0}});
  EXPECT_NE(vector, vector_reversed);
}
489 
// IsAll is defined only for arrays: a tuple reports false even when every
// leaf element holds the queried value.
TEST_F(LiteralUtilTest, IsAllTuple) {
  auto element1 = LiteralUtil::CreateR0<float>(0.0);
  auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
  // Previously the tuple was built from element1 twice, leaving element2
  // unused; mix a scalar and a matrix so the tuple really is heterogeneous.
  auto tuple = LiteralUtil::MakeTuple({&element1, &element2});

  // Tuples should always return false for IsAll.
  EXPECT_FALSE(tuple.IsAll(0));
  EXPECT_FALSE(tuple.IsAll(1));
}
499 
500 // Verifies that CreateFromShape works for tuples.
TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
  // Reference tuple of zeros with mixed element types (F32 scalar, S32 matrix).
  auto zero_scalar = LiteralUtil::CreateR0<float>(0.0);
  auto zero_matrix = LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}});
  auto zeros = LiteralUtil::MakeTuple({&zero_scalar, &zero_matrix});

  // Creating a literal from the tuple's shape must reproduce those zeros.
  auto from_shape = Literal::CreateFromShape(zeros.shape());
  EXPECT_EQ(zeros, from_shape);
}
509 
// IsAll(v) is true when every element of an array literal equals the integer
// v, compared without reinterpreting signedness.
TEST_F(LiteralUtilTest, IsAll) {
  // PRED: only 0 (false) and 1 (true) can ever match.
  EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false).IsAll(0));
  EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true).IsAll(1));
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(1));
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(2));
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(2));
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(-1));

  // We shouldn't reinterpret int8_min as an unsigned type and then decide that
  // it is equal to 255.
  auto int8_min = std::numeric_limits<int8>::min();
  EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255).IsAll(int8_min));

  EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0).IsAll(42));
  EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001).IsAll(42));

  EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100}).IsAll(100));
  EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001}).IsAll(100));

  EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}}).IsAll(8));

  half h8(8.0f);
  half h9(9.0f);
  EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}}).IsAll(8));

  bfloat16 b8(8.0f);
  bfloat16 b9(9.0f);

  EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}}).IsAll(8));
  EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}}).IsAll(8));

  // 9.001 is not representable in bfloat16; it is stored as 9.0.
  bfloat16 b91(9.001f);
  bfloat16 b90(9.00f);
  // Pass an integer: IsAll takes an integral value (see the int8_min and -1
  // cases), so a double argument would only be implicitly converted.
  EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}}).IsAll(9));

  complex64 c8_9 = {8, 9};
  EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAll(8));

  // An all-ones uint64 must not be mistaken for -1.
  auto uint64_max = std::numeric_limits<uint64>::max();
  EXPECT_FALSE(LiteralUtil::CreateR2<uint64>(
                   {{uint64_max, uint64_max}, {uint64_max, uint64_max}})
                   .IsAll(-1));
}
560 
TEST_F(LiteralUtilTest, IsAllFloat) {
  // Literals that are not floating-point never satisfy IsAllFloat, even for a
  // value they numerically hold.
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllFloat(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllFloat(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllFloat(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllFloat(0));

  // F32.
  auto f32_zero = LiteralUtil::CreateR0<float>(0);
  auto f32_half = LiteralUtil::CreateR0<float>(.5);
  auto f32_neg_half = LiteralUtil::CreateR0<float>(-.5);
  EXPECT_TRUE(f32_zero.IsAllFloat(0));
  EXPECT_TRUE(f32_half.IsAllFloat(.5));
  EXPECT_TRUE(f32_neg_half.IsAllFloat(-.5));
  EXPECT_FALSE(f32_neg_half.IsAllFloat(-.49));
  // One off element is enough to fail; a uniform matrix passes.
  EXPECT_FALSE(
      LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
  EXPECT_TRUE(LiteralUtil::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}})
                  .IsAllFloat(.5));

  // F64.
  auto f64_zero = LiteralUtil::CreateR0<double>(0);
  auto f64_half = LiteralUtil::CreateR0<double>(.5);
  auto f64_neg_half = LiteralUtil::CreateR0<double>(-.5);
  EXPECT_TRUE(f64_zero.IsAllFloat(0));
  EXPECT_TRUE(f64_half.IsAllFloat(.5));
  EXPECT_TRUE(f64_neg_half.IsAllFloat(-.5));
  EXPECT_FALSE(f64_neg_half.IsAllFloat(-.49));
  EXPECT_FALSE(
      LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
}
584 
TEST_F(LiteralUtilTest, IsAllComplex) {
  // Literals with a non-complex element type never satisfy IsAllComplex.
  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<float>(0).IsAllComplex(0));
  EXPECT_FALSE(LiteralUtil::CreateR0<double>(0).IsAllComplex(0));

  // C64: all elements must equal (8, 9); one mismatch in either position
  // is enough to fail.
  complex64 match = {8, 9};
  complex64 mismatch = {7, 9};
  auto all_match = LiteralUtil::CreateR2<complex64>({{match}, {match}});
  auto first_off = LiteralUtil::CreateR2<complex64>({{mismatch}, {match}});
  auto last_off = LiteralUtil::CreateR2<complex64>({{match}, {mismatch}});
  EXPECT_TRUE(all_match.IsAllComplex({8.0f, 9.0f}));
  EXPECT_FALSE(first_off.IsAllComplex({8.0f, 9.0f}));
  EXPECT_FALSE(last_off.IsAllComplex({8.0f, 9.0f}));
}
603 
TEST_F(LiteralUtilTest, IsAllFirst) {
  // IsAllFirst is true exactly when every element equals the first one,
  // across PRED, integer and complex element types.
  EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true}).IsAllFirst());
  EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2}).IsAllFirst());
  EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2}).IsAllFirst());
  EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2}).IsAllFirst());
  EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2}).IsAllFirst());

  complex64 c8_9 = {8, 9};
  complex64 c7_9 = {7, 9};
  EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAllFirst());
  EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}}).IsAllFirst());
}
621 
// IsZero inspects only the element at the given multi-index.
TEST_F(LiteralUtilTest, IsZero) {
  // Scalars use the empty index.
  auto f32_zero = LiteralUtil::CreateR0<float>(0.0f);
  auto f32_one = LiteralUtil::CreateR0<float>(1.0f);
  EXPECT_TRUE(f32_zero.IsZero({}));
  EXPECT_FALSE(f32_one.IsZero({}));

  auto grid = LiteralUtil::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}});
  EXPECT_FALSE(grid.IsZero({0, 1}));
  EXPECT_TRUE(grid.IsZero({0, 2}));
  EXPECT_TRUE(grid.IsZero({1, 1}));
  EXPECT_FALSE(grid.IsZero({1, 2}));

  auto c64_zero = LiteralUtil::CreateR0<complex64>(0.0f);
  auto c64_half = LiteralUtil::CreateR0<complex64>(0.5f);
  EXPECT_TRUE(c64_zero.IsZero({}));
  EXPECT_FALSE(c64_half.IsZero({}));
}
639 
640 template <typename T>
641 class LiteralUtilTestTemplated : public ::testing::Test {};
642 
643 using TestedTypes = ::testing::Types<float, int32, uint32, complex64>;
644 TYPED_TEST_SUITE(LiteralUtilTestTemplated, TestedTypes);
645 
// Relayout to either 2-D layout reports the requested layout and preserves
// the logical contents.
TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
  // 1/2 is fractional for floating-point and complex types (it is 0 for the
  // integer types), so the first element exercises a non-integral value.
  const TypeParam one_half = TypeParam(1) / TypeParam(2);
  auto original = LiteralUtil::CreateR2<TypeParam>({{one_half, 2}, {3, 4}});

  for (const Layout& target :
       {LayoutUtil::MakeLayout({0, 1}), LayoutUtil::MakeLayout({1, 0})}) {
    auto relaid = original.Relayout(target);
    EXPECT_TRUE(LayoutUtil::Equal(relaid.shape().layout(), target));
    EXPECT_EQ(original, relaid);
  }
}
661 
// Reshaping a scalar to the empty dimension list is the identity.
TEST_F(LiteralUtilTest, ReshapeR0) {
  auto scalar = LiteralUtil::CreateR0<float>(1.7f);
  auto reshaped = scalar.Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
  EXPECT_EQ(scalar, reshaped);
}
667 
// Reshaping a dim0-major F32[1x3x2x4] literal to [3, 4, 2] regroups the same
// row-major element sequence (10..33).
TEST_F(LiteralUtilTest, ReshapeR4) {
  // clang-format off
  // F32[1x3x2x4]
  auto original = LiteralUtil::CreateR4WithLayout<float>({{
     {{10, 11, 12, 13}, {14, 15, 16, 17}},
     {{18, 19, 20, 21}, {22, 23, 24, 25}},
     {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }}, layout_r4_dim0major_);
  // F32[3x4x2]
  auto expected = LiteralUtil::CreateR3WithLayout<float>({
    {{10, 11}, {12, 13}, {14, 15}, {16, 17}},
    {{18, 19}, {20, 21}, {22, 23}, {24, 25}},
    {{26, 27}, {28, 29}, {30, 31}, {32, 33}},
  }, layout_r3_dim0major_);
  // clang-format on
  auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();

  EXPECT_EQ(expected, reshape);
}
687 
// Same as ReshapeR4 but with a dim0-minor source layout: the physical layout
// of the input must not change the logical result of Reshape.
TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
  // clang-format off
  // F32[1x3x2x4]
  auto original = LiteralUtil::CreateR4WithLayout<float>({{
     {{10, 11, 12, 13}, {14, 15, 16, 17}},
     {{18, 19, 20, 21}, {22, 23, 24, 25}},
     {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }}, layout_r4_dim0minor_);
  // F32[3x4x2]
  auto expected = LiteralUtil::CreateR3WithLayout<float>({
    {{10, 11}, {12, 13}, {14, 15}, {16, 17}},
    {{18, 19}, {20, 21}, {22, 23}, {24, 25}},
    {{26, 27}, {28, 29}, {30, 31}, {32, 33}},
  }, layout_r3_dim0major_);
  // clang-format on
  auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();

  EXPECT_EQ(expected, reshape);
}
707 
// Transposing a scalar with the empty permutation is the identity.
TEST_F(LiteralUtilTest, TransposeR0) {
  auto scalar = LiteralUtil::CreateR0<float>(1.7f);
  auto transposed = scalar.Transpose(/*permutation=*/{});
  EXPECT_EQ(scalar, transposed);
}
713 
TEST_F(LiteralUtilTest, TransposeR4) {
  // clang-format off
  // F32[1x3x2x4]
  auto original = LiteralUtil::CreateR4<float>({{
     {{10, 11, 12, 13}, {14, 15, 16, 17}},
     {{18, 19, 20, 21}, {22, 23, 24, 25}},
     {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }});
  // clang-format on
  auto transposed = original.Transpose(/*permutation=*/{2, 3, 0, 1});

  // Every output cell (a, b, c, d) must hold the input value at (c, d, a, b).
  transposed.EachCell<float>(
      [&original](absl::Span<const int64> idx, float value) {
        const float source =
            original.Get<float>({idx[2], idx[3], idx[0], idx[1]});
        EXPECT_EQ(value, source);
      });
}
730 
// Relayout of an array must equal building it directly in the target layout,
// in both directions between dim0-minor and dim0-major.
TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
  auto minor_to_major =
      literal_r4_2x2x3x3_dim0minor_.Relayout(layout_r4_dim0major_);
  auto major_to_minor =
      literal_r4_2x2x3x3_dim0major_.Relayout(layout_r4_dim0minor_);

  EXPECT_EQ(literal_r4_2x2x3x3_dim0major_, minor_to_major);
  EXPECT_EQ(literal_r4_2x2x3x3_dim0minor_, major_to_minor);
}
742 
// Checks the physical (linear) element order of the logical 2x3 matrix
// {{1, 2, 3}, {4, 5, 6}} under both 2-D layouts, built directly or relaid.
TEST_F(LiteralUtilTest, TestR2LinearLayout) {
  // Column-major (dim0-minor): storage walks down each column.
  auto col_major = LiteralUtil::CreateR2WithLayout<int32>(
      {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
  EXPECT_EQ(col_major.element_count(), 6);
  EXPECT_THAT(col_major.data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));

  // Relaid to row-major: storage walks along each row.
  auto col_to_row = col_major.Relayout(layout_r2_dim0major_);
  EXPECT_THAT(col_to_row.data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));

  // Built row-major (dim0-major) directly.
  auto row_major = LiteralUtil::CreateR2WithLayout<int32>(
      {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
  EXPECT_EQ(row_major.element_count(), 6);
  EXPECT_THAT(row_major.data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));

  // Relaid back to column-major.
  auto row_to_col = row_major.Relayout(layout_r2_dim0minor_);
  EXPECT_THAT(row_to_col.data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
}
766 
// Checks the physical element order of a logical 2x2x3 array under the
// dim0-minor and dim0-major layouts, built directly or relaid.
TEST_F(LiteralUtilTest, TestR3LinearLayout) {
  // clang-format off
  Array3D<int> arr3d({
    {{1, 2, 3}, {4, 5, 6}},
    {{7, 8, 9}, {10, 11, 12}},
  });
  // clang-format on

  // Expected storage order for each layout.
  std::vector<int> col_major_order{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
  std::vector<int> row_major_order{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};

  // Built dim0-minor (column-major), then relaid to dim0-major.
  auto col_major = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
      arr3d, layout_r3_dim0minor_);
  EXPECT_EQ(col_major.element_count(), 12);
  EXPECT_THAT(col_major.data<int32>(),
              testing::ElementsAreArray(col_major_order));

  auto col_to_row = col_major.Relayout(layout_r3_dim0major_);
  EXPECT_THAT(col_to_row.data<int32>(),
              testing::ElementsAreArray(row_major_order));

  // Built dim0-major (row-major), then relaid to dim0-minor.
  auto row_major = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
      arr3d, layout_r3_dim0major_);
  EXPECT_EQ(row_major.element_count(), 12);
  EXPECT_THAT(row_major.data<int32>(),
              testing::ElementsAreArray(row_major_order));

  auto row_to_col = row_major.Relayout(layout_r3_dim0minor_);
  EXPECT_THAT(row_to_col.data<int32>(),
              testing::ElementsAreArray(col_major_order));
}
807 
// Slicing a scalar with empty bounds returns the scalar unchanged.
TEST_F(LiteralUtilTest, SliceR0S32) {
  auto scalar = LiteralUtil::CreateR0<int32>(1);
  auto sliced = scalar.Slice({}, {});
  EXPECT_EQ(scalar, sliced);
}
813 
// Slice bounds are [start, limit): {3}..{4} selects only index 3.
TEST_F(LiteralUtilTest, SliceR1F32) {
  auto expected = LiteralUtil::CreateR1<float>({4.0});
  auto source = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
  auto sliced = source.Slice({3}, {4});
  EXPECT_EQ(expected, sliced);
}
820 
TEST_F(LiteralUtilTest, SliceR2U32) {
  // Rows [0, 2) x columns [2, 4) of a 3x4 matrix: its top-right 2x2 corner.
  Literal matrix_3x4 = LiteralUtil::CreateR2<uint32>(
      {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
  Literal corner = matrix_3x4.Slice(/*start_indices=*/{0, 2},
                                    /*limit_indices=*/{2, 4});
  Literal expected = LiteralUtil::CreateR2<uint32>({{3, 4}, {7, 8}});
  EXPECT_EQ(expected, corner);
}
828 
TEST_F(LiteralUtilTest, SliceR3U32Full) {
  // A slice covering every dimension in full reproduces the input.
  Literal cube_2x3x2 = LiteralUtil::CreateR3<uint32>(
      {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
  Literal whole = cube_2x3x2.Slice(/*start_indices=*/{0, 0, 0},
                                   /*limit_indices=*/{2, 3, 2});
  EXPECT_EQ(cube_2x3x2, whole);
}
835 
TEST_F(LiteralUtilTest, PopulateR1S64) {
  // PopulateR1 on a one-element s64 vector matches CreateR1 of that value.
  Literal populated(ShapeUtil::MakeShape(S64, {1}));
  populated.PopulateR1<int64>({77});
  Literal want = LiteralUtil::CreateR1<int64>({77});
  EXPECT_EQ(populated, want);
}
842 
TEST_F(LiteralUtilTest, PopulateR1U64) {
  // PopulateR1 on a two-element u64 vector matches CreateR1 of those values.
  Literal populated(ShapeUtil::MakeShape(U64, {2}));
  populated.PopulateR1<uint64>({77, 88});
  Literal want = LiteralUtil::CreateR1<uint64>({77, 88});
  EXPECT_EQ(populated, want);
}
849 
TEST_F(LiteralUtilTest, PopulateR1C64) {
  // A single c64 element with real part 77 and imaginary part 88.
  const complex64 value(77, 88);
  Literal populated(ShapeUtil::MakeShape(C64, {1}));
  populated.PopulateR1<complex64>({value});
  Literal want = LiteralUtil::CreateR1<complex64>({value});
  EXPECT_EQ(populated, want);
}
856 
TEST_F(LiteralUtilTest, PopulateR1C128) {
  // A single c128 element with real part 77 and imaginary part 88.
  const complex128 value(77, 88);
  Literal populated(ShapeUtil::MakeShape(C128, {1}));
  populated.PopulateR1<complex128>({value});
  Literal want = LiteralUtil::CreateR1<complex128>({value});
  EXPECT_EQ(populated, want);
}
863 
TEST_F(LiteralUtilTest, PopulateR2C64) {
  // PopulateR2 on a 2x2 c64 matrix matches CreateR2 of the same rows.
  Literal populated(ShapeUtil::MakeShape(C64, {2, 2}));
  populated.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
  Literal want =
      LiteralUtil::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
  EXPECT_EQ(populated, want);
}
871 
TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
  // Filling a bf16 scalar sets its only element.
  const bfloat16 quarter(0.25f);
  Literal filled(ShapeUtil::MakeShape(BF16, {}));
  filled.PopulateWithValue<bfloat16>(quarter);
  Literal want = LiteralUtil::CreateR0<bfloat16>(quarter);
  EXPECT_EQ(filled, want);
}
879 
TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
  // Every element of a 3-element bf16 vector receives the fill value.
  const bfloat16 half_value(0.5f);
  Literal filled(ShapeUtil::MakeShape(BF16, {3}));
  filled.PopulateWithValue<bfloat16>(half_value);
  Literal want =
      LiteralUtil::CreateR1<bfloat16>({half_value, half_value, half_value});
  EXPECT_EQ(filled, want);
}
887 
TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
  // Every element of a 2x2 bf16 matrix receives the fill value.
  const bfloat16 two(2.0f);
  Literal filled(ShapeUtil::MakeShape(BF16, {2, 2}));
  filled.PopulateWithValue<bfloat16>(two);
  Literal want = LiteralUtil::CreateR2<bfloat16>({{two, two}, {two, two}});
  EXPECT_EQ(filled, want);
}
895 
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
  // Filling an f32 scalar sets its only element.
  const float fill = 2.5f;
  Literal filled(ShapeUtil::MakeShape(F32, {}));
  filled.PopulateWithValue<float>(fill);
  Literal want = LiteralUtil::CreateR0<float>(fill);
  EXPECT_EQ(filled, want);
}
902 
TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
  // A negative fill value is written to every element of an s64 vector.
  const int64 fill = -7;
  Literal filled(ShapeUtil::MakeShape(S64, {3}));
  filled.PopulateWithValue<int64>(fill);
  Literal want = LiteralUtil::CreateR1<int64>({fill, fill, fill});
  EXPECT_EQ(filled, want);
}
909 
TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
  // Every element of a 2x2 u64 matrix receives the fill value.
  const uint64 fill = 42;
  Literal filled(ShapeUtil::MakeShape(U64, {2, 2}));
  filled.PopulateWithValue<uint64>(fill);
  Literal want = LiteralUtil::CreateR2<uint64>({{fill, fill}, {fill, fill}});
  EXPECT_EQ(filled, want);
}
916 
TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
  // Every element of a 2x2 c64 matrix receives the complex value 4+2i.
  const complex64 fill(4, 2);
  Literal filled(ShapeUtil::MakeShape(C64, {2, 2}));
  filled.PopulateWithValue<complex64>(fill);
  Literal want =
      LiteralUtil::CreateR2<complex64>({{fill, fill}, {fill, fill}});
  EXPECT_EQ(filled, want);
}
924 
TEST_F(LiteralUtilTest, PopulateWithValueR2C128) {
  // Every element of a 2x2 c128 matrix receives the complex value 4+2i.
  const complex128 fill(4, 2);
  Literal filled(ShapeUtil::MakeShape(C128, {2, 2}));
  filled.PopulateWithValue<complex128>(fill);
  Literal want =
      LiteralUtil::CreateR2<complex128>({{fill, fill}, {fill, fill}});
  EXPECT_EQ(filled, want);
}
932 
TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
  // Filling an f16 scalar sets its only element.
  const half quarter(0.25f);
  Literal filled(ShapeUtil::MakeShape(F16, {}));
  filled.PopulateWithValue<half>(quarter);
  Literal want = LiteralUtil::CreateR0<half>(quarter);
  EXPECT_EQ(filled, want);
}
940 
TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
  // Every element of a 3-element f16 vector receives the fill value.
  const half half_value(0.5f);
  Literal filled(ShapeUtil::MakeShape(F16, {3}));
  filled.PopulateWithValue<half>(half_value);
  Literal want =
      LiteralUtil::CreateR1<half>({half_value, half_value, half_value});
  EXPECT_EQ(filled, want);
}
948 
TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
  // Every element of a 2x2 f16 matrix receives the fill value.
  const half two(2.0f);
  Literal filled(ShapeUtil::MakeShape(F16, {2, 2}));
  filled.PopulateWithValue<half>(two);
  Literal want = LiteralUtil::CreateR2<half>({{two, two}, {two, two}});
  EXPECT_EQ(filled, want);
}
956 
TEST_F(LiteralUtilTest, ReplicateR2U32) {
  // Replicating a 3x4 matrix 3 times yields a 3x3x4 literal whose leading
  // dimension indexes identical copies of the matrix.
  Literal matrix = LiteralUtil::CreateR2<uint32>(
      {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
  Literal replicated = matrix.Replicate<uint32>(3);
  Literal want = LiteralUtil::CreateR3<uint32>(
      {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
       {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
       {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}});
  EXPECT_EQ(replicated, want);
}
967 
TEST_F(LiteralUtilTest, CopySliceFrom) {
  const int64 dimensions[] = {17, 15, 34, 21};
  // Minor-to-major orders to exercise; the copy must be layout-independent.
  const int64 layouts[][4] = {
      {3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}};
  // Iteration and slice geometry, identical for every layout.
  const int64 zero_base[] = {0, 0, 0, 0};
  const int64 step[] = {1, 1, 1, 1};
  const int64 src_base[] = {3, 1, 5, 7};
  const int64 dest_base[] = {6, 4, 12, 2};
  const int64 copy_size[] = {7, 8, 11, 9};

  for (const auto& minor_to_major : layouts) {
    const Shape shape = ShapeUtil::MakeShapeWithLayout(
        primitive_util::NativeToPrimitiveType<uint32>(), dimensions,
        minor_to_major);

    // Number the source elements 1, 2, 3, ... so a zero read back from the
    // destination can only mean "not copied".
    Literal source = Literal::CreateFromShape(shape);
    uint32 counter = 0;
    ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step,
                            [&](absl::Span<const int64> index) {
                              source.Set(index, ++counter);
                              return true;
                            });

    Literal destination = Literal::CreateFromShape(shape);
    TF_EXPECT_OK(
        destination.CopySliceFrom(source, src_base, dest_base, copy_size));

    // Walk the copied window by offset and compare each destination element
    // with the source element it should have come from.
    std::vector<int64> src_index(TF_ARRAYSIZE(dimensions), 0);
    std::vector<int64> dest_index(TF_ARRAYSIZE(dimensions), 0);
    bool matched = true;
    ShapeUtil::ForEachIndex(
        source.shape(), zero_base, copy_size, step,
        [&](absl::Span<const int64> offset) {
          std::transform(offset.begin(), offset.end(), src_base,
                         src_index.begin(), std::plus<int64>());
          std::transform(offset.begin(), offset.end(), dest_base,
                         dest_index.begin(), std::plus<int64>());
          const uint32 copied = destination.Get<uint32>(dest_index);
          matched = copied != 0 && copied == source.Get<uint32>(src_index);
          return matched;
        });
    EXPECT_TRUE(matched);
  }
}
1013 
TEST_F(LiteralUtilTest, CopyFromScalars) {
  // Whole-literal copy between two scalars.
  Literal dest = LiteralUtil::CreateR0<uint32>(0);
  Literal nine = LiteralUtil::CreateR0<uint32>(9);
  TF_EXPECT_OK(dest.CopyFrom(nine));
  EXPECT_EQ(dest, nine);

  // Single-element copies between a vector and a scalar: the scalar side
  // uses empty base/size index lists.
  Literal vec = LiteralUtil::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
  TF_EXPECT_OK(dest.CopySliceFrom(vec, /*src_base=*/{5}, /*dest_base=*/{},
                                  /*copy_size=*/{}));
  EXPECT_EQ(dest.Get<uint32>({}), 17);
  TF_EXPECT_OK(vec.CopySliceFrom(dest, /*src_base=*/{}, /*dest_base=*/{4},
                                 /*copy_size=*/{}));
  EXPECT_EQ(vec.Get<uint32>({4}), 17);
}
1026 
TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
  const Shape empty_r1_shape = ShapeUtil::MakeShape(F32, {0});
  // Reference values the scoped literals must still equal after the
  // zero-sized copies below.
  const Literal const_nine = LiteralUtil::CreateR1<float>({9});
  const Literal const_empty = Literal::CreateFromShape(empty_r1_shape);

  {
    // The source has a dimension with zero elements; `nine` is unchanged.
    const Literal empty_src = Literal::CreateFromShape(empty_r1_shape);
    Literal nine = LiteralUtil::CreateR1<float>({9});
    TF_EXPECT_OK(nine.CopySliceFrom(empty_src, /*src_base=*/{0},
                                    /*dest_base=*/{0}, /*copy_size=*/{0}));
    EXPECT_EQ(nine, const_nine);
  }

  {
    // Zero elements are copied into a destination with zero elements.
    Literal empty_dest = Literal::CreateFromShape(empty_r1_shape);
    const Literal nine_src = LiteralUtil::CreateR1<float>({9});
    TF_EXPECT_OK(empty_dest.CopySliceFrom(nine_src, /*src_base=*/{0},
                                          /*dest_base=*/{0},
                                          /*copy_size=*/{0}));
    EXPECT_EQ(empty_dest, const_empty);
  }
}
1050 
TEST_F(LiteralUtilTest, CopyFromNilShape) {
  Literal dest(ShapeUtil::MakeNil());
  Literal src(ShapeUtil::MakeNil());
  // Nil-shaped literals carry no data, so nothing is copied, but the call
  // is still required to succeed.
  TF_ASSERT_OK(dest.CopyFrom(src));
}
1057 
TEST_F(LiteralUtilTest, CopyFromArrays) {
  // Scalar case: overwrite 42 with 123 using explicit root shape indices.
  Literal dest_scalar = LiteralUtil::CreateR0<float>(42.0);
  Literal src_scalar = LiteralUtil::CreateR0<float>(123.0);
  EXPECT_NE(dest_scalar, src_scalar);
  TF_ASSERT_OK(dest_scalar.CopyFrom(src_scalar, /*dest_shape_index=*/{},
                                    /*src_shape_index=*/{}));
  EXPECT_EQ(dest_scalar, src_scalar);
  EXPECT_EQ(dest_scalar.Get<float>({}), 123.0f);

  // Matrix case: overwrite {{1,2},{3,4}} with {{5,6},{7,8}}.
  Literal dest_matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  Literal src_matrix = LiteralUtil::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
  EXPECT_NE(dest_matrix, src_matrix);
  EXPECT_EQ(dest_matrix.Get<float>({0, 0}), 1.0f);
  TF_ASSERT_OK(dest_matrix.CopyFrom(src_matrix, /*dest_shape_index=*/{},
                                    /*src_shape_index=*/{}));
  EXPECT_EQ(dest_matrix, src_matrix);
  EXPECT_EQ(dest_matrix.Get<float>({0, 0}), 5.0f);
}
1076 
TEST_F(LiteralUtilTest, CopyFromTuples) {
  // nested_tuple = (matrix, (s32 42, f64[2] {23, 44}, nil)).
  Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  Literal nil_literal(ShapeUtil::MakeNil());
  Literal int32_42 = LiteralUtil::CreateR0<int32>(42);
  Literal double_23_44 = LiteralUtil::CreateR1<double>({23.0, 44.0});
  Literal inner_tuple =
      LiteralUtil::MakeTuple({&int32_42, &double_23_44, &nil_literal});
  Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple});

  // A tuple with the same shape as nested_tuple's inner tuple but different
  // values.
  Literal int32_minus5 = LiteralUtil::CreateR0<int32>(-5);
  Literal double_2_4 = LiteralUtil::CreateR1<double>({2.0, 4.0});
  Literal replacement =
      LiteralUtil::MakeTuple({&int32_minus5, &double_2_4, &nil_literal});

  EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
  EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), 42);
  EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 23.0);
  EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 44.0);

  // Overwrite element {1} (the inner tuple) with the whole of `replacement`.
  TF_ASSERT_OK(nested_tuple.CopyFrom(replacement, /*dest_shape_index=*/{1},
                                     /*src_shape_index=*/{}));

  // Element {0} (the matrix) is untouched...
  EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));

  // ...while element {1} now holds the values from `replacement`.
  EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), -5);
  EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 2.0);
  EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 4.0);
}

TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
  Literal minus_two = LiteralUtil::CreateR0<int32>(-2);
  Literal four = LiteralUtil::CreateR0<int32>(4);
  Literal tuple = LiteralUtil::MakeTuple({&minus_two, &four});

  EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
  EXPECT_EQ(tuple.Get<int32>({}, {1}), 4);

  // Source and destination are the same literal: copy element {0} over
  // element {1}.
  TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
                              /*src_shape_index=*/{0}));

  EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
  EXPECT_EQ(tuple.Get<int32>({}, {1}), -2);
}
1125 
TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
  // Copying a vector into a matrix must fail with a shape-mismatch error.
  Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  Literal vec = LiteralUtil::CreateR1<float>({5.0, 7.0});
  const Status status = matrix.CopyFrom(vec);
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Destination subshape incompatible"));
}
1134 
TEST_F(LiteralUtilTest, F16) {
  // The raw bytes behind data<half>() must agree with the element values and
  // be stored in little-endian order.
  // TODO - revisit if the storage format ever becomes host-endianness
  // dependent.
  Literal zeros = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
  const char* zero_bytes =
      reinterpret_cast<const char*>(zeros.data<half>().data());
  for (int i = 0; i < 8; ++i) {
    EXPECT_EQ(zero_bytes[i], 0) << "byte " << i;
  }

  const half one(1.0f);
  const half two(2.0f);
  Literal m = LiteralUtil::CreateR2<half>({{one, two}, {two, one}});
  const char* bytes = reinterpret_cast<const char*>(m.data<half>().data());
  // Little-endian half encodings: 1.0 is 0x3C00 and 2.0 is 0x4000.
  const char expected_bytes[8] = {0, 0x3C, 0, 0x40, 0, 0x40, 0, 0x3C};
  for (int i = 0; i < 8; ++i) {
    EXPECT_EQ(bytes[i], expected_bytes[i]) << "byte " << i;
  }
}
1163 
TEST_F(LiteralUtilTest, Populate) {
  // Shapes (dimensions plus minor-to-major layout) to exercise, including a
  // scalar and shapes with zero elements.
  struct PopulateData {
    std::vector<int64> dimensions;
    std::vector<int64> layout;
  };
  const PopulateData populate_data[] = {
      {{}, {}},
      {{0}, {0}},
      {{16}, {0}},
      {{2, 0}, {1, 0}},
      {{4, 16}, {1, 0}},
      {{21, 12}, {0, 1}},
      {{6, 11, 17}, {2, 0, 1}},
      {{6, 11, 5, 17}, {3, 2, 0, 1}},
  };
  for (const PopulateData& data : populate_data) {
    const Shape shape = ShapeUtil::MakeShapeWithLayout(
        primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
        data.layout);
    Literal literal(shape);
    // Value for each index: its linear index plus 17, so that even an R0
    // literal is populated with something other than zero.
    auto generator = [&literal](absl::Span<const int64> index) -> uint32 {
      return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
                                                           index) +
             17;
    };
    TF_EXPECT_OK(literal.Populate<uint32>(generator));

    // Read every element back and compare it with the generator's value.
    const std::vector<int64> zero_base(data.dimensions.size(), 0);
    const std::vector<int64> step(data.dimensions.size(), 1);
    bool all_match = true;
    ShapeUtil::ForEachIndex(
        literal.shape(), zero_base, data.dimensions, step,
        [&](absl::Span<const int64> index) {
          const uint32 actual = literal.Get<uint32>(index);
          all_match = all_match && (actual == generator(index));
          return all_match;
        });
    EXPECT_TRUE(all_match);
  }
}
1205 
TEST_F(LiteralUtilTest, PopulateParallel) {
  // Same shapes as the Populate test, filled via PopulateParallel instead.
  struct PopulateData {
    std::vector<int64> dimensions;
    std::vector<int64> layout;
  };
  const PopulateData populate_data[] = {
      {{}, {}},
      {{0}, {0}},
      {{16}, {0}},
      {{2, 0}, {1, 0}},
      {{4, 16}, {1, 0}},
      {{21, 12}, {0, 1}},
      {{6, 11, 17}, {2, 0, 1}},
      {{6, 11, 5, 17}, {3, 2, 0, 1}},
  };
  for (const PopulateData& data : populate_data) {
    const Shape shape = ShapeUtil::MakeShapeWithLayout(
        primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
        data.layout);
    Literal literal(shape);
    // Value for each index: its linear index plus 17, so that even an R0
    // literal is populated with something other than zero.
    auto generator = [&literal](absl::Span<const int64> index) -> uint32 {
      return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
                                                           index) +
             17;
    };
    TF_EXPECT_OK(literal.PopulateParallel<uint32>(generator));

    // Read every element back and compare it with the generator's value.
    const std::vector<int64> zero_base(data.dimensions.size(), 0);
    const std::vector<int64> step(data.dimensions.size(), 1);
    bool all_match = true;
    ShapeUtil::ForEachIndex(
        literal.shape(), zero_base, data.dimensions, step,
        [&](absl::Span<const int64> index) {
          const uint32 actual = literal.Get<uint32>(index);
          all_match = all_match && (actual == generator(index));
          return all_match;
        });
    EXPECT_TRUE(all_match);
  }
}
1247 
TEST_F(LiteralUtilTest, ConvertR4) {
  // Widening an s8 R4 literal to u32 must preserve every value.
  // clang-format off
  Literal s8_input = LiteralUtil::CreateR4WithLayout<int8>({{
     {{10, 11, 12, 13}, {14, 15, 16, 17}},
     {{18, 19, 20, 21}, {22, 23, 24, 25}},
     {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }}, layout_r4_dim0major_);
  Literal u32_expected = LiteralUtil::CreateR4WithLayout<uint32>({{
     {{10, 11, 12, 13}, {14, 15, 16, 17}},
     {{18, 19, 20, 21}, {22, 23, 24, 25}},
     {{26, 27, 28, 29}, {30, 31, 32, 33}},
  }}, layout_r4_dim0major_);
  // clang-format on
  TF_ASSERT_OK_AND_ASSIGN(Literal widened, s8_input.Convert(U32));

  EXPECT_EQ(u32_expected, widened);
}
1265 
TEST_F(LiteralUtilTest,ConvertIfTypesMatch)1266 TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
1267   // clang-format off
1268   auto s8 = LiteralUtil::CreateR4WithLayout<int8>({{
1269     {{10, 0, 12, 0}, {0, 15, 0, 17}},
1270     {{0, 19, 0, 21}, {22, 0, 24, 0}},
1271     {{26, 0, 28, 0}, {0, 31, 0, 33}},
1272   }}, layout_r4_dim0major_);
1273   auto s16 = LiteralUtil::CreateR4WithLayout<int16>({{
1274     {{10, 0, 12, 0}, {0, 15, 0, 17}},
1275     {{0, 19, 0, 21}, {22, 0, 24, 0}},
1276     {{26, 0, 28, 0}, {0, 31, 0, 33}},
1277   }}, layout_r4_dim0major_);
1278   auto s32 = LiteralUtil::CreateR4WithLayout<int32>({{
1279     {{10, 0, 12, 0}, {0, 15, 0, 17}},
1280     {{0, 19, 0, 21}, {22, 0, 24, 0}},
1281     {{26, 0, 28, 0}, {0, 31, 0, 33}},
1282   }}, layout_r4_dim0major_);
1283   auto u16 = LiteralUtil::CreateR4WithLayout<uint16>({{
1284     {{10, 0, 12, 0}, {0, 15, 0, 17}},
1285     {{0, 19, 0, 21}, {22, 0, 24, 0}},
1286     {{26, 0, 28, 0}, {0, 31, 0, 33}},
1287   }}, layout_r4_dim0major_);
1288   auto u32 = LiteralUtil::CreateR4WithLayout<uint32>({{
1289     {{10, 0, 12, 0}, {0, 15, 0, 17}},
1290     {{0, 19, 0, 21}, {22, 0, 24, 0}},
1291     {{26, 0, 28, 0}, {0, 31, 0, 33}},
1292   }}, layout_r4_dim0major_);
1293   auto s64 = LiteralUtil::CreateR4WithLayout<int64>({{
1294     {{10, 0, 12, 0}, {0, 15, 0, 17}},
1295     {{0, 19, 0, 21}, {22, 0, 24, 0}},
1296     {{26, 0, 28, 0}, {0, 31, 0, 33}},
1297   }}, layout_r4_dim0major_);
1298   auto u64 = LiteralUtil::CreateR4WithLayout<uint64>({{
1299     {{10, 0, 12, 0}, {0, 15, 0, 17}},
1300     {{0, 19, 0, 21}, {22, 0, 24, 0}},
1301     {{26, 0, 28, 0}, {0, 31, 0, 33}},
1302   }}, layout_r4_dim0major_);
1303   auto pred = LiteralUtil::CreateR4WithLayout<bool>({{
1304     {{true, false, true, false}, {false, true, false, true}},
1305     {{false, true, false, true}, {true, false, true, false}},
1306     {{true, false, true, false}, {false, true, false, true}},
1307   }}, layout_r4_dim0major_);
1308   auto int32_pred = LiteralUtil::CreateR4WithLayout<int32>({{
1309     {{1, 0, 1, 0}, {0, 1, 0, 1}},
1310     {{0, 1, 0, 1}, {1, 0, 1, 0}},
1311     {{1, 0, 1, 0}, {0, 1, 0, 1}},
1312   }}, layout_r4_dim0major_);
1313   auto f16 = LiteralUtil::CreateR4WithLayout<half>({{
1314     {{half(10.0), half(0.0), half(12.0), half(0.0)},
1315      {half(0.0), half(15.0), half(0.0), half(17.0)}},
1316     {{half(0.0), half(19.0), half(0.0), half(21.0)},
1317      {half(22.0), half(0.0), half(24.0), half(0.0)}},
1318     {{half(26.0), half(0.0), half(28.0), half(0.0)},
1319      {half(0.0), half(31.0), half(0.0), half(33.0)}},
1320   }}, layout_r4_dim0major_);
1321   auto bf16 = LiteralUtil::CreateR4WithLayout<bfloat16>({{
1322     {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)},
1323      {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}},
1324     {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)},
1325      {bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}},
1326     {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)},
1327      {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}},
1328   }}, layout_r4_dim0major_);
1329   auto f32 = LiteralUtil::CreateR4WithLayout<float>({{
1330     {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
1331     {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
1332     {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
1333   }}, layout_r4_dim0major_);
1334   auto f64 = LiteralUtil::CreateR4WithLayout<double>({{
1335     {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
1336     {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
1337     {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
1338   }}, layout_r4_dim0major_);
1339   auto c64 = LiteralUtil::CreateR4WithLayout<complex64>({{
1340     {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
1341     {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
1342     {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
1343   }}, layout_r4_dim0major_);
1344   auto c128 = LiteralUtil::CreateR4WithLayout<complex128>({{
1345     {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
1346     {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
1347     {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
1348   }}, layout_r4_dim0major_);  // clang-format on
1349   Literal conv;
1350 
1351   conv = s8.Convert(U16).ConsumeValueOrDie();
1352   EXPECT_EQ(conv, u16);
1353 
1354   conv = s8.Convert(S16).ConsumeValueOrDie();
1355   EXPECT_EQ(conv, s16);
1356 
1357   conv = s8.Convert(U32).ConsumeValueOrDie();
1358   EXPECT_EQ(conv, u32);
1359 
1360   conv = s8.Convert(S32).ConsumeValueOrDie();
1361   EXPECT_EQ(conv, s32);
1362 
1363   conv = s8.Convert(U64).ConsumeValueOrDie();
1364   EXPECT_EQ(conv, u64);
1365 
1366   conv = s8.Convert(S64).ConsumeValueOrDie();
1367   EXPECT_EQ(conv, s64);
1368 
1369   conv = s8.Convert(PRED).ConsumeValueOrDie();
1370   EXPECT_EQ(conv, pred);
1371 
1372   conv = bf16.Convert(S32).ConsumeValueOrDie();
1373   EXPECT_EQ(conv, s32);
1374 
1375   conv = bf16.Convert(F32).ConsumeValueOrDie();
1376   EXPECT_EQ(conv, f32);
1377 
1378   conv = pred.Convert(S32).ConsumeValueOrDie();
1379   EXPECT_EQ(conv, int32_pred);
1380 
1381   conv = f32.Convert(S32).ConsumeValueOrDie();
1382   EXPECT_EQ(conv, s32);
1383 
1384   conv = f64.Convert(S32).ConsumeValueOrDie();
1385   EXPECT_EQ(conv, s32);
1386 
1387   conv = s32.Convert(F32).ConsumeValueOrDie();
1388   EXPECT_EQ(conv, f32);
1389 
1390   conv = f32.Convert(F16).ConsumeValueOrDie();
1391   EXPECT_EQ(conv, f16);
1392 
1393   conv = f64.Convert(F16).ConsumeValueOrDie();
1394   EXPECT_EQ(conv, f16);
1395 
1396   conv = s32.Convert(F16).ConsumeValueOrDie();
1397   EXPECT_EQ(conv, f16);
1398 
1399   conv = u32.Convert(F16).ConsumeValueOrDie();
1400   EXPECT_EQ(conv, f16);
1401 
1402   conv = s32.Convert(C64).ConsumeValueOrDie();
1403   EXPECT_EQ(conv, c64);
1404 
1405   conv = f16.Convert(C64).ConsumeValueOrDie();
1406   EXPECT_EQ(conv, c64);
1407 
1408   conv = s32.Convert(S16).ConsumeValueOrDie();
1409   EXPECT_EQ(conv, s16);
1410 
1411   conv = s32.Convert(U16).ConsumeValueOrDie();
1412   EXPECT_EQ(conv, u16);
1413 
1414   conv = s32.Convert(C128).ConsumeValueOrDie();
1415   EXPECT_EQ(conv, c128);
1416 
1417   conv = f16.Convert(C128).ConsumeValueOrDie();
1418   EXPECT_EQ(conv, c128);
1419 
1420   EXPECT_EQ(s32.Convert(TUPLE).status().code(),
1421             tensorflow::error::UNIMPLEMENTED);
1422   EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED);
1423   EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED);
1424   EXPECT_EQ(c128.Convert(F32).status().code(),
1425             tensorflow::error::UNIMPLEMENTED);
1426   EXPECT_EQ(c128.Convert(S32).status().code(),
1427             tensorflow::error::UNIMPLEMENTED);
1428 }
1429 
TEST_F(LiteralUtilTest, BitcastConvert) {
  // Reinterpreting u32 bit patterns as f32 must yield exactly the floats
  // those patterns encode, including the arbitrary pattern 0xbeef.
  auto original = LiteralUtil::CreateR1<uint32>(
      {absl::bit_cast<uint32>(2.5f), absl::bit_cast<uint32>(-42.25f),
       absl::bit_cast<uint32>(100.f), 0xbeef});
  auto expected = LiteralUtil::CreateR1<float>(
      {2.5f, -42.25f, 100.0f, absl::bit_cast<float>(0xbeef)});
  TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32));
  // Check the converted values, not only that the call succeeded.
  EXPECT_EQ(expected, converted);
}
1438 
TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
  // A 32-bit u32 cannot be bit-reinterpreted as a 64-bit f64.
  Literal u32_scalar = LiteralUtil::CreateR0<uint32>(1234);
  const Status status = u32_scalar.BitcastConvert(F64).status();
  EXPECT_NE(Status::OK(), status);
  EXPECT_THAT(status.error_message(), HasSubstr("bit widths are different"));
}
1446 
1447 // Sets the layout of the given ShapeProto to the default.
SetDefaultLayoutOnProto(ShapeProto * shape_proto)1448 void SetDefaultLayoutOnProto(ShapeProto* shape_proto) {
1449   CHECK(ShapeUtil::IsArrayPrimitiveType(shape_proto->element_type()));
1450   shape_proto->mutable_layout()->set_format(DENSE);
1451   auto* minor_to_major =
1452       shape_proto->mutable_layout()->mutable_minor_to_major();
1453   minor_to_major->Resize(shape_proto->dimensions_size(), 0);
1454   const int64 size = minor_to_major->size();
1455   for (int64 i = 0; i < size; ++i) {
1456     minor_to_major->Set(i, size - 1 - i);
1457   }
1458 }
1459 
TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
  // Decode PRED vectors of every length in [0, 25) from a LiteralProto. The
  // values alternate, with a phase that depends on the vector's length.
  LiteralProto proto;
  proto.mutable_shape()->set_element_type(PRED);
  for (int len = 0; len < 25; ++len) {
    ShapeProto* shape = proto.mutable_shape();
    shape->clear_dimensions();
    shape->add_dimensions(len);
    SetDefaultLayoutOnProto(shape);
    proto.clear_preds();
    for (int i = 0; i < len; ++i) {
      proto.add_preds((i % 2) == (len % 2));
    }

    TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(proto));
    auto values = literal.data<bool>();
    ASSERT_EQ(len, values.size());
    for (int i = 0; i < len; ++i) {
      EXPECT_EQ((i % 2) == (len % 2), values[i]);
    }
  }
}
1481 
1482 // Note that f16 is currently stored in a byte array in little endian byte order
TEST_F(LiteralUtilTest,ToProto_f16)1483 TEST_F(LiteralUtilTest, ToProto_f16) {
1484   half h1(1.0f);
1485   half h2(2.0f);
1486 
1487   auto m = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
1488   EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape()));
1489   EXPECT_EQ(4, m.data<half>().size());
1490 
1491   LiteralProto p = m.ToProto();
1492   EXPECT_EQ(4, ShapeUtil::ElementsIn(Shape(p.shape())));
1493   EXPECT_EQ(8, p.f16s().size());
1494   const char* d = p.f16s().data();
1495   EXPECT_EQ(d[0], 0);
1496   EXPECT_EQ(d[1], 0x3C);
1497   EXPECT_EQ(d[2], 0);
1498   EXPECT_EQ(d[3], 0x40);
1499   EXPECT_EQ(d[4], 0);
1500   EXPECT_EQ(d[5], 0x40);
1501   EXPECT_EQ(d[6], 0);
1502   EXPECT_EQ(d[7], 0x3C);
1503 }
1504 
1505 // Note that f16 is currently stored in a byte array in little endian byte order
TEST_F(LiteralUtilTest,CopyFromProto_f16)1506 TEST_F(LiteralUtilTest, CopyFromProto_f16) {
1507   half h1(1.0f);
1508   half h2(2.0f);
1509 
1510   const char half_vals[8] = {0x00, 0x3C, 0x00, 0x40, 0x00, 0x40, 0x00, 0x3C};
1511   LiteralProto p;
1512   p.mutable_shape()->set_element_type(F16);
1513   p.mutable_shape()->clear_dimensions();
1514   p.mutable_shape()->add_dimensions(4);
1515   SetDefaultLayoutOnProto(p.mutable_shape());
1516   p.clear_f16s();
1517   p.set_f16s(half_vals, 8);
1518   TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
1519   auto r = literal.data<half>();
1520   ASSERT_EQ(4, r.size());
1521   EXPECT_EQ(h1, r[0]);
1522   EXPECT_EQ(h2, r[1]);
1523   EXPECT_EQ(h2, r[2]);
1524   EXPECT_EQ(h1, r[3]);
1525 }
1526 
TEST_F(LiteralUtilTest, CopyFromProto_u16) {
  // Little-endian bytes of the u16 values [0xabcd, 0x1234, 0x1234, 0xabcd].
  const unsigned char u16_bytes[8] = {0xcd, 0xab, 0x34, 0x12,
                                      0x34, 0x12, 0xcd, 0xab};
  LiteralProto proto;
  ShapeProto* shape = proto.mutable_shape();
  shape->set_element_type(U16);
  shape->clear_dimensions();
  shape->add_dimensions(4);
  SetDefaultLayoutOnProto(shape);
  proto.clear_u16s();
  proto.set_u16s(u16_bytes, sizeof(u16_bytes));

  TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(proto));
  auto values = literal.data<uint16>();
  ASSERT_EQ(4, values.size());
  const uint16 abcd(0xabcd);
  const uint16 x1234(0x1234);
  EXPECT_EQ(abcd, values[0]);
  EXPECT_EQ(x1234, values[1]);
  EXPECT_EQ(x1234, values[2]);
  EXPECT_EQ(abcd, values[3]);
}
1548 
TEST_F(LiteralUtilTest, LiteralSliceTest) {
  Literal scalar = LiteralUtil::CreateR0<float>(1.0);
  Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  Literal tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
  Literal nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
  Literal nil(ShapeUtil::MakeNil());

  // A slice rooted at the empty shape index views the whole literal.
  for (const Literal* whole : {&scalar, &matrix, &tuple, &nested_tuple, &nil}) {
    EXPECT_EQ(LiteralSlice(*whole, {}), *whole);
  }

  // Slices rooted at tuple elements view those elements.
  EXPECT_EQ(LiteralSlice(tuple, {0}), scalar);
  EXPECT_EQ(LiteralSlice(tuple, {1}), matrix);

  EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple);
  EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar);
  EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix);
  EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar);
}
1570 
// A LiteralSlice views the literal it was built from: writing through the
// owning literal must be visible when reading through the slice.
TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
  auto scalar = LiteralUtil::CreateR0<float>(1.0);
  auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
  auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});

  const auto view = LiteralSlice(nested_tuple);

  // Reads the scalar leaf at shape index {0, 0} through both the owner and
  // the view, and expects both to report `expected`.
  auto expect_leaf = [&](float expected) {
    EXPECT_EQ(
        nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
        expected);
    EXPECT_EQ(view.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
              expected);
  };

  expect_leaf(1.0f);
  nested_tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
  expect_leaf(555.0f);
}
1591 
// Slicing a slice composes the view roots: {0} and then {1} lands on the
// matrix stored at nested_tuple's index {0, 1}.
TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
  auto scalar = LiteralUtil::CreateR0<float>(1.0);
  auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
  auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});

  const auto outer_view = LiteralSlice(nested_tuple);
  const auto inner_view = LiteralSlice(outer_view, /*view_root=*/{0});
  const auto leaf_view = LiteralSlice(inner_view, /*view_root=*/{1});

  const auto expected =
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  EXPECT_EQ(leaf_view, expected);
}
1604 
// A BorrowingLiteral built over one caller-owned buffer exposes that buffer's
// elements through the literal Get API.
TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
  std::vector<int64> values = {1, 2, 3};
  const Shape shape = ShapeUtil::MakeShape(S64, {3});

  BorrowingLiteral borrowed(reinterpret_cast<const char*>(values.data()),
                            shape);

  // Element i was stored as i + 1.
  for (int i = 0; i < 3; ++i) {
    EXPECT_EQ(borrowed.Get<int64>({i}), i + 1);
  }
}
1616 
// A tuple-shaped BorrowingLiteral takes one caller-owned buffer per tuple
// element, supplied in tuple order.
TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) {
  std::vector<int64> first = {1, 2, 3};
  const Shape first_shape = ShapeUtil::MakeShape(S64, {3});

  std::vector<int64> second = {100};
  const Shape second_shape = ShapeUtil::MakeShape(S64, {1});

  const std::vector<const char*> buffers = {
      reinterpret_cast<const char*>(first.data()),
      reinterpret_cast<const char*>(second.data()),
  };
  auto borrowed = BorrowingLiteral(
      buffers, ShapeUtil::MakeTupleShape({first_shape, second_shape}));

  // Index 0 of each tuple element.
  EXPECT_EQ(borrowed.Get<int64>(/*multi_index=*/{0}, /*shape_index=*/{0}), 1);
  EXPECT_EQ(borrowed.Get<int64>(/*multi_index=*/{0}, /*shape_index=*/{1}),
            100);

  // Remaining entries of the first element.
  EXPECT_EQ(borrowed.Get<int64>(/*multi_index=*/{1}, /*shape_index=*/{0}), 2);
  EXPECT_EQ(borrowed.Get<int64>(/*multi_index=*/{2}, /*shape_index=*/{0}), 3);
}
1643 
// Move-constructing a Literal carries over the source's shape and elements.
TEST_F(LiteralUtilTest, LiteralMove) {
  Literal source = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  Literal destination(std::move(source));

  const auto expected_shape = ShapeUtil::MakeShape(F32, {2, 2});
  EXPECT_TRUE(ShapeUtil::Equal(expected_shape, destination.shape()));

  // Element (row, col) of the source matrix holds 2 * row + col + 1.
  for (int row = 0; row < 2; ++row) {
    for (int col = 0; col < 2; ++col) {
      EXPECT_EQ(destination.Get<float>({row, col}), 2.0 * row + col + 1);
    }
  }
}
1655 
// DecomposeTuple moves each top-level tuple element into its own Literal and
// leaves the source as an empty tuple. Nested tuples are returned whole.
TEST_F(LiteralUtilTest, DecomposeTuple) {
  Literal nil_literal(ShapeUtil::MakeNil());
  Literal inner_elements[] = {
      LiteralUtil::CreateR0<int32>(42),
      LiteralUtil::CreateR1<double>({23.0, 44.0}),
  };
  Literal tuple_elements[] = {
      LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}),
      LiteralUtil::MakeTuple(
          {&inner_elements[0], &inner_elements[1], &nil_literal}),
  };
  Literal nested_tuple = LiteralUtil::MakeTuple(
      {&tuple_elements[0], &tuple_elements[1], &nil_literal});

  EXPECT_FALSE(ShapeUtil::IsEmptyTuple(nested_tuple.shape()));
  std::vector<Literal> parts = nested_tuple.DecomposeTuple();
  EXPECT_TRUE(ShapeUtil::IsEmptyTuple(nested_tuple.shape()));

  ASSERT_EQ(parts.size(), 3);

  // parts[0]: the s32[2,2] matrix.
  const Literal& matrix = parts[0];
  EXPECT_TRUE(ShapeUtil::Compatible(matrix.shape(),
                                    ShapeUtil::MakeShape(S32, {2, 2})));
  EXPECT_EQ(matrix.Get<int32>({0, 0}), 1);
  EXPECT_EQ(matrix.Get<int32>({0, 1}), 2);
  EXPECT_EQ(matrix.Get<int32>({1, 0}), 3);
  EXPECT_EQ(matrix.Get<int32>({1, 1}), 4);

  // parts[1]: the inner (s32[], f64[2], nil) tuple, still a tuple.
  const Literal& inner = parts[1];
  EXPECT_TRUE(ShapeUtil::Compatible(
      inner.shape(),
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {}),
                                 ShapeUtil::MakeShape(F64, {2}),
                                 ShapeUtil::MakeNil()})));
  EXPECT_EQ(inner.Get<int32>({}, /*shape_index=*/{0}), 42);
  EXPECT_EQ(inner.Get<double>({0}, /*shape_index=*/{1}), 23.0);
  EXPECT_EQ(inner.Get<double>({1}, /*shape_index=*/{1}), 44.0);

  // parts[2]: the trailing nil element.
  EXPECT_TRUE(ShapeUtil::Compatible(parts[2].shape(), ShapeUtil::MakeNil()));
}
1694 
// Decomposing a nil (empty-tuple) literal yields no elements.
TEST_F(LiteralUtilTest, DecomposeEmptyTuple) {
  Literal empty_tuple(ShapeUtil::MakeNil());
  const std::vector<Literal> parts = empty_tuple.DecomposeTuple();
  EXPECT_EQ(parts.size(), 0);
}
1700 
// MoveIntoTuple assembles a tuple by moving each given literal into it. Every
// source literal is left behind as an empty tuple.
TEST_F(LiteralUtilTest, MoveIntoTuple) {
  // Members of the nested tuple that becomes element 2.
  std::vector<Literal> inner_elements;
  inner_elements.push_back(LiteralUtil::CreateR0<int32>(42));
  inner_elements.push_back(LiteralUtil::CreateR1<double>({23.0, 44.0}));

  std::vector<Literal> elements;
  elements.push_back(LiteralUtil::CreateR0<float>(1.0));
  elements.push_back(LiteralUtil::CreateR1<int32>({4, 8}));
  elements.push_back(
      LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]}));

  Literal tuple = Literal::MoveIntoTuple(absl::MakeSpan(elements));
  ASSERT_TRUE(tuple.shape().IsTuple());
  ASSERT_EQ(ShapeUtil::TupleElementCount(tuple.shape()), 3);

  EXPECT_EQ(tuple.Get<float>({}, /*shape_index=*/{0}), 1.0);
  EXPECT_EQ(tuple.Get<int32>({0}, /*shape_index=*/{1}), 4);
  EXPECT_EQ(tuple.Get<int32>({1}, /*shape_index=*/{1}), 8);
  EXPECT_EQ(tuple.Get<int32>({}, /*shape_index=*/{2, 0}), 42);
  EXPECT_EQ(tuple.Get<double>({0}, /*shape_index=*/{2, 1}), 23.0);
  EXPECT_EQ(tuple.Get<double>({1}, /*shape_index=*/{2, 1}), 44.0);

  for (const auto& moved_from : elements) {
    EXPECT_TRUE(ShapeUtil::IsEmptyTuple(moved_from.shape()));
  }
}
1726 
// Moving zero literals into a tuple produces a tuple with no elements.
TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) {
  const Literal empty = Literal::MoveIntoTuple({});
  ASSERT_TRUE(empty.shape().IsTuple());
  EXPECT_EQ(ShapeUtil::TupleElementCount(empty.shape()), 0);
}
1732 
// Move-assigning into a default-constructed (nil-shaped) Literal replaces its
// shape and contents with those of the source.
TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
  Literal target;
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), target.shape()));

  Literal source = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  target = std::move(source);

  EXPECT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), target.shape()));
  // Element (row, col) of the source matrix holds 2 * row + col + 1.
  for (int row = 0; row < 2; ++row) {
    for (int col = 0; col < 2; ++col) {
      EXPECT_EQ(target.Get<float>({row, col}), 2.0 * row + col + 1);
    }
  }
}
1747 
// A copy-constructed LiteralSlice reads the same element values as the slice
// it was copied from.
TEST_F(LiteralUtilTest, LiteralSliceCopy) {
  Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  const auto original_view = LiteralSlice(matrix);
  LiteralSlice copied_view(original_view);

  // Element (row, col) of the matrix holds 2 * row + col + 1.
  for (int row = 0; row < 2; ++row) {
    for (int col = 0; col < 2; ++col) {
      EXPECT_EQ(copied_view.Get<float>({row, col}), 2.0 * row + col + 1);
    }
  }
}
1758 
// Get/Set with a shape_index read and write a single element inside one tuple
// member, and later Gets observe the written value.
TEST_F(LiteralUtilTest, GetSetTuple) {
  Literal elements[] = {
      LiteralUtil::CreateR0<float>(42.0),
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
  };
  auto tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});

  // Tuple member {0} is a scalar.
  const auto scalar_member = [&tuple] {
    return tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0});
  };
  EXPECT_EQ(scalar_member(), 42.0);
  tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
  EXPECT_EQ(scalar_member(), -5.0);

  // Tuple member {1} is the 2x2 matrix; address its element at {1, 0}.
  const auto matrix_member_1_0 = [&tuple] {
    return tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1});
  };
  EXPECT_EQ(matrix_member_1_0(), 3.0);
  tuple.Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
  EXPECT_EQ(matrix_member_1_0(), -4.0);
}
1774 
// Literal::CreateFromShape must zero-initialize every element, both for array
// shapes and for each member of a tuple shape.
TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
  // Scalar.
  Literal scalar_f32 = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
  EXPECT_EQ(scalar_f32.Get<float>({}), 0.0);
  EXPECT_TRUE(scalar_f32.IsAll(0));

  // Vector.
  Literal vector_s32 = Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
  for (int i = 0; i < 3; ++i) {
    EXPECT_EQ(vector_s32.Get<int32>({i}), 0);
  }
  EXPECT_TRUE(vector_s32.IsAll(0));

  // Tuple covering floating, boolean, unsigned and complex element types.
  Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
       ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {}),
       ShapeUtil::MakeShape(C128, {})}));

  EXPECT_EQ(tuple.Get<double>({}, {0}), 0.0);
  for (int i = 0; i < 2; ++i) {
    EXPECT_EQ(tuple.Get<bool>({i}, {1}), false);
  }
  for (int i = 0; i < 2; ++i) {
    EXPECT_EQ(tuple.Get<uint64>({i, 0}, {2}), 0);
  }
  EXPECT_EQ(tuple.Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
  EXPECT_EQ(tuple.Get<complex128>({}, {4}), complex128(0.0, 0.0));
}
1800 
// Serializes literals of many element types, ranks and tuple nestings to a
// LiteralProto and back, and checks each round trip reproduces the original.
TEST_F(LiteralUtilTest, ProtoRoundTrip) {
  auto one_f32 = LiteralUtil::CreateR0<float>(1.0);
  auto two_f32 = LiteralUtil::CreateR0<float>(2.0);
  auto vector_int8 = LiteralUtil::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
  auto vector_uint8 = LiteralUtil::CreateR1<uint8>({128, 0, 2, 56, 127, 255});
  auto vector_c64 = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
  auto vector_c128 =
      LiteralUtil::CreateR1<complex128>({{1.0, 2.0}, {3.0, 4.0}});
  auto vector_bfloat16 = LiteralUtil::CreateR1<bfloat16>(
      {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
  auto vector_half =
      LiteralUtil::CreateR1<half>({half{10.0}, half{20.0}, half{-30.0}});
  auto matrix_pred =
      LiteralUtil::CreateR2<bool>({{true, false, true}, {false, false, true}});
  auto tuple = LiteralUtil::MakeTuple(
      {&one_f32, &vector_half, &matrix_pred, &matrix_pred});
  Literal nil_literal(ShapeUtil::MakeNil());
  auto nested_tuple =
      LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal});

  // Note: ValueOrDie aborts (rather than failing the test) on a bad status.
  auto to_from_proto = [](const Literal& literal) -> Literal {
    return Literal::CreateFromProto(literal.ToProto()).ValueOrDie();
  };

  EXPECT_EQ(one_f32, to_from_proto(one_f32));
  EXPECT_EQ(vector_int8, to_from_proto(vector_int8));
  EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8));
  EXPECT_EQ(vector_c64, to_from_proto(vector_c64));
  EXPECT_EQ(vector_c128, to_from_proto(vector_c128));
  EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16));
  // F16 data is stored in its own proto field (f16s). Round-trip it on its
  // own, like every other array literal here, not only inside `tuple`.
  EXPECT_EQ(vector_half, to_from_proto(vector_half));
  EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred));
  EXPECT_EQ(tuple, to_from_proto(tuple));
  EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple));
  EXPECT_EQ(nil_literal, to_from_proto(nil_literal));

  // Guard against a comparison that ignores values.
  EXPECT_NE(one_f32, two_f32);
  EXPECT_NE(one_f32, to_from_proto(two_f32));
}
1840 
// A proto that declares an f32[3] shape but carries no values is rejected.
TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
  LiteralProto shape_only;
  *shape_only.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();

  const Status parse_status = Literal::CreateFromProto(shape_only).status();
  ASSERT_FALSE(parse_status.ok());
  EXPECT_THAT(parse_status.error_message(),
              HasSubstr("Expected 3 elements in LiteralProto"));
}
1850 
// A proto that carries values but never sets a shape is rejected.
TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
  LiteralProto shapeless;
  for (bool value : {false, true, false}) {
    shapeless.add_preds(value);
  }

  const Status parse_status = Literal::CreateFromProto(shapeless).status();
  ASSERT_FALSE(parse_status.ok());
  EXPECT_THAT(parse_status.error_message(),
              HasSubstr("LiteralProto has no shape"));
}
1861 
// A proto declaring f32[3] whose three values sit in the preds field (not
// f32s) is rejected as if the values were missing.
TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) {
  LiteralProto mismatched;
  *mismatched.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
  for (bool value : {false, true, false}) {
    mismatched.add_preds(value);
  }

  const Status parse_status = Literal::CreateFromProto(mismatched).status();
  ASSERT_FALSE(parse_status.ok());
  EXPECT_THAT(parse_status.error_message(),
              HasSubstr("Expected 3 elements in LiteralProto"));
}
1874 
// An f32[42,2] shape needs 84 values; a proto with only three is rejected.
TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) {
  LiteralProto short_proto;
  *short_proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}).ToProto();
  for (float value : {1.0f, 2.0f, 3.0f}) {
    short_proto.add_f32s(value);
  }

  const Status parse_status = Literal::CreateFromProto(short_proto).status();
  ASSERT_FALSE(parse_status.ok());
  EXPECT_THAT(parse_status.error_message(),
              HasSubstr("Expected 84 elements in LiteralProto"));
}
1887 
// An s32[2] shape with three values supplied is rejected.
TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) {
  LiteralProto oversized;
  *oversized.mutable_shape() = ShapeUtil::MakeShape(S32, {2}).ToProto();
  for (int value : {42, -10, 100}) {
    oversized.add_s32s(value);
  }

  const Status parse_status = Literal::CreateFromProto(oversized).status();
  ASSERT_FALSE(parse_status.ok());
  EXPECT_THAT(parse_status.error_message(),
              HasSubstr("Expected 2 elements in LiteralProto"));
}
1900 
// A pred[2,2] proto with the correct number of values but no layout on its
// shape is rejected.
TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) {
  LiteralProto layoutless;
  *layoutless.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}).ToProto();
  layoutless.mutable_shape()->clear_layout();
  for (bool value : {true, false, true, false}) {
    layoutless.add_preds(value);
  }

  const Status parse_status = Literal::CreateFromProto(layoutless).status();
  ASSERT_FALSE(parse_status.ok());
  EXPECT_THAT(parse_status.error_message(),
              HasSubstr("LiteralProto has no layout"));
}
1914 
// A tuple-shaped proto with fewer tuple literals than its shape declares is
// rejected.
TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) {
  LiteralProto proto;
  *proto.mutable_shape() =
      ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})})
          .ToProto();

  // Supply only the first of the two declared elements.
  LiteralProto* pred_element = proto.add_tuple_literals();
  *pred_element->mutable_shape() =
      ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 0).ToProto();
  for (bool value : {false, true}) {
    pred_element->add_preds(value);
  }

  const Status parse_status = Literal::CreateFromProto(proto).status();
  ASSERT_FALSE(parse_status.ok());
  EXPECT_THAT(parse_status.error_message(),
              HasSubstr("Expected 2 tuple elements"));
}
1932 
// A tuple-shaped proto with more tuple literals than its shape declares is
// rejected.
TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
  LiteralProto proto;
  *proto.mutable_shape() =
      ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})})
          .ToProto();

  // Both declared elements, populated to match their shapes...
  LiteralProto* pred_element = proto.add_tuple_literals();
  *pred_element->mutable_shape() =
      ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 0).ToProto();
  for (bool value : {false, true}) {
    pred_element->add_preds(value);
  }
  LiteralProto* f32_element = proto.add_tuple_literals();
  *f32_element->mutable_shape() =
      ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 1).ToProto();
  f32_element->add_f32s(42.0);

  // ...plus one extra element that the tuple shape does not account for.
  LiteralProto* extra_element = proto.add_tuple_literals();
  *extra_element->mutable_shape() = ShapeUtil::MakeShape(F32, {}).ToProto();
  extra_element->add_f32s(123.0);

  const Status parse_status = Literal::CreateFromProto(proto).status();
  ASSERT_FALSE(parse_status.ok());
  EXPECT_THAT(parse_status.error_message(),
              HasSubstr("Expected 2 tuple elements"));
}
1957 
// After SortSparseElements, ToString lists sparse entries in ascending index
// order regardless of the order in which they were appended.
TEST_F(LiteralUtilTest, SortSparseElements) {
  auto sparse = LiteralUtil::CreateSparse<float>(
      {10, 10, 10}, SparseIndexArray(10, 3), {});

  // Append out of order: the entry with the smallest index goes last.
  sparse.AppendSparseElement<float>({2, 3, 4}, 2.0);
  sparse.AppendSparseElement<float>({3, 4, 5}, 3.0);
  sparse.AppendSparseElement<float>({1, 2, 3}, 1.0);

  sparse.SortSparseElements();
  EXPECT_EQ(sparse.ToString(),
            "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}");
}
1968 
// GetSparseElementAsString(1) formats the second sparse value of literals
// with several element types; the expected strings mirror absl::StrCat.
TEST_F(LiteralUtilTest, GetSparseElementAsString) {
  std::vector<int64> dims = {10, 10, 10};
  SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}});

  const auto bools =
      LiteralUtil::CreateSparse<bool>(dims, indices, {true, false, true});
  EXPECT_EQ(bools.GetSparseElementAsString(1), "false");

  const auto ints = LiteralUtil::CreateSparse<int64>(dims, indices, {1, 2, 3});
  EXPECT_EQ(ints.GetSparseElementAsString(1), absl::StrCat(int64{2}));

  const auto doubles =
      LiteralUtil::CreateSparse<double>(dims, indices, {1.0, 2.0, 3.0});
  EXPECT_EQ(doubles.GetSparseElementAsString(1), absl::StrCat(double{2.0}));

  // Half values are rendered via their float conversion.
  const auto halves = LiteralUtil::CreateSparse<half>(
      dims, indices, {half{1.0}, half{2.0}, half{3.0}});
  EXPECT_EQ(halves.GetSparseElementAsString(1),
            absl::StrCat(static_cast<float>(half{2.0})));

  // Complex values are rendered as "(real, imag)".
  const auto complexes = LiteralUtil::CreateSparse<complex64>(
      dims, indices,
      std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
  EXPECT_EQ(complexes.GetSparseElementAsString(1),
            absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
}
1994 
// Broadcasting s64[2] into s64[2,2] along dimension 0: result[i][j] equals
// operand[i], so each row repeats one operand element.
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
  Literal operand = LiteralUtil::CreateR1<int64>({1, 2});
  const auto result_shape = ShapeUtil::MakeShape(S64, {2, 2});

  TF_ASSERT_OK_AND_ASSIGN(
      Literal broadcast,
      operand.Broadcast(/*result_shape=*/result_shape, /*dimensions=*/{0}));

  const auto expected = LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}});
  EXPECT_EQ(broadcast, expected);
}
2004 
// Broadcasting s64[2] into s64[2,2] along dimension 1: result[i][j] equals
// operand[j], so every row is a copy of the operand.
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
  Literal operand = LiteralUtil::CreateR1<int64>({1, 2});
  const auto result_shape = ShapeUtil::MakeShape(S64, {2, 2});

  TF_ASSERT_OK_AND_ASSIGN(
      Literal broadcast,
      operand.Broadcast(/*result_shape=*/result_shape, /*dimensions=*/{1}));

  const auto expected = LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}});
  EXPECT_EQ(broadcast, expected);
}
2014 
// Broadcasting an s32 scalar (no mapped dimensions) fills every element of
// the s32[2,2] result with that scalar.
TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
  Literal operand = LiteralUtil::CreateR0<int32>(9);
  const auto result_shape = ShapeUtil::MakeShape(S32, {2, 2});

  TF_ASSERT_OK_AND_ASSIGN(
      Literal broadcast,
      operand.Broadcast(/*result_shape=*/result_shape, /*dimensions=*/{}));

  const auto expected = LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}});
  EXPECT_EQ(broadcast, expected);
}
2024 
2025 }  // namespace
2026 }  // namespace xla
2027